Model + API: Migrate to use BaseSamplerParams

kwargs is pretty ugly when figuring out which arguments to use. The
base request falls back to defaults anyway, so pass the params
object in as-is.
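
Roughly the shape of the change, as a standalone sketch (SamplerParams,
generate_old, and generate_new are stand-ins here, not the real
container API):

    from pydantic import BaseModel

    class SamplerParams(BaseModel):
        # Stand-in for BaseSamplerRequest; defaults live on the model itself
        prompt: str
        max_tokens: int = 150
        temperature: float = 1.0

    # Old shape: the request is flattened into kwargs and the callee
    # has to dig the values back out (and re-apply defaults)
    def generate_old(prompt: str, request_id: str, **kwargs):
        return kwargs.get("max_tokens", 150)

    # New shape: the params object is handed over as-is
    def generate_new(request_id: str, prompt: str, params: SamplerParams):
        return params.max_tokens

    params = SamplerParams(prompt="hi")
    generate_old(params.prompt, "req-1", **params.model_dump(exclude={"prompt"}))
    generate_new("req-1", params.prompt, params)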

However, since Python's typing can't transform types the way
TypeScript can, the type hints can still show None for some params
even though a value is always present.
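
For example, a field declared like the logprobs sampler value always
gets an int from its factory, yet the Optional annotation still tells
a checker that None is possible (sketch of the same pydantic pattern):

    from typing import Optional
    from pydantic import BaseModel, Field

    class Params(BaseModel):
        # The factory always supplies an int, but the annotation still
        # admits None, so a checker treats the field as "int | None"
        logprobs: Optional[int] = Field(default_factory=lambda: 0, ge=0)

    p = Params()
    print(p.logprobs + 1)  # fine at runtime; a type checker may still flag it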

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-04-16 00:50:05 -04:00
parent dcb36e9ab2
commit 3084ef9fa1
5 changed files with 113 additions and 121 deletions


@@ -32,10 +32,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
     # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False
     stream_options: Optional[ChatCompletionStreamOptions] = None
-    logprobs: Optional[int] = Field(
-        default_factory=lambda: get_default_sampler_value("logprobs", 0),
-        ge=0,
-    )
     response_format: Optional[CompletionResponseFormat] = Field(
         default_factory=CompletionResponseFormat
     )


@@ -333,11 +333,11 @@ async def stream_generate_chat_completion(
                 _stream_collector(
                     n,
                     gen_queue,
-                    prompt,
                     request.state.id,
+                    prompt,
+                    task_gen_params,
                     abort_event,
-                    embeddings=embeddings,
-                    **task_gen_params.model_dump(exclude={"prompt"}),
+                    mm_embeddings=embeddings,
                 )
             )
         )
@@ -422,10 +422,10 @@ async def generate_chat_completion(
         gen_tasks.append(
             asyncio.create_task(
                 model.container.generate(
-                    prompt,
                     request.state.id,
-                    embeddings=embeddings,
-                    **data.model_dump(exclude={"prompt"}),
+                    prompt,
+                    data,
+                    mm_embeddings=embeddings,
                 )
             )
         )
@@ -465,7 +465,6 @@ async def generate_tool_calls(
     # FIXME: May not be necessary depending on how the codebase evolves
     tool_data = data.model_copy(deep=True)
     tool_data.json_schema = tool_data.tool_call_schema
-    gen_params = tool_data.model_dump()

     for idx, gen in enumerate(generations):
         if gen["stop_str"] in tool_data.tool_call_start:
@@ -488,10 +487,10 @@ async def generate_tool_calls(
         gen_tasks.append(
             asyncio.create_task(
                 model.container.generate(
-                    pre_tool_prompt,
                     request.state.id,
+                    pre_tool_prompt,
+                    tool_data,
                     embeddings=mm_embeddings,
-                    **gen_params,
                 )
             )
         )


@@ -8,12 +8,12 @@ import asyncio
 import pathlib
 from asyncio import CancelledError
 from fastapi import HTTPException, Request
-from typing import List, Union
 from loguru import logger
+from typing import List, Optional, Union
 from common import model
 from common.auth import get_key_permission
 from common.multimodal import MultimodalEmbeddingWrapper
 from common.networking import (
     get_generator_error,
     handle_request_disconnect,
@@ -86,16 +86,21 @@ def _create_response(
 async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
-    prompt: str,
     request_id: str,
+    prompt: str,
+    params: CompletionRequest,
     abort_event: asyncio.Event,
-    **kwargs,
+    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
 ):
     """Collects a stream and places results in a common queue"""

     try:
         new_generation = model.container.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            request_id,
+            prompt,
+            params,
+            abort_event,
+            mm_embeddings,
         )
         async for generation in new_generation:
             generation["index"] = task_idx
@@ -195,10 +200,10 @@ async def stream_generate_completion(
                 _stream_collector(
                     n,
                     gen_queue,
-                    data.prompt,
                     request.state.id,
+                    data.prompt,
+                    task_gen_params,
                     abort_event,
-                    **task_gen_params.model_dump(exclude={"prompt"}),
                 )
             )
@@ -256,9 +261,9 @@ async def generate_completion(
         gen_tasks.append(
             asyncio.create_task(
                 model.container.generate(
-                    data.prompt,
                     request.state.id,
-                    **task_gen_params.model_dump(exclude={"prompt"}),
+                    data.prompt,
+                    task_gen_params,
                 )
             )
         )