Model + API: Migrate to use BaseSamplerParams
kwargs is pretty ugly when figuring out which arguments to use. The base requests falls back to defaults anyways, so pass in the params object as is. However, since Python's typing isn't like TypeScript where types can be transformed, the type hinting has a possiblity of None showing up despite there always being a value for some params. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
dcb36e9ab2
commit
3084ef9fa1
5 changed files with 113 additions and 121 deletions
|
|
@ -32,10 +32,6 @@ class CommonCompletionRequest(BaseSamplerRequest):
|
|||
# Generation info (remainder is in BaseSamplerRequest superclass)
|
||||
stream: Optional[bool] = False
|
||||
stream_options: Optional[ChatCompletionStreamOptions] = None
|
||||
logprobs: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("logprobs", 0),
|
||||
ge=0,
|
||||
)
|
||||
response_format: Optional[CompletionResponseFormat] = Field(
|
||||
default_factory=CompletionResponseFormat
|
||||
)
|
||||
|
|
|
|||
|
|
@ -333,11 +333,11 @@ async def stream_generate_chat_completion(
|
|||
_stream_collector(
|
||||
n,
|
||||
gen_queue,
|
||||
prompt,
|
||||
request.state.id,
|
||||
prompt,
|
||||
task_gen_params,
|
||||
abort_event,
|
||||
embeddings=embeddings,
|
||||
**task_gen_params.model_dump(exclude={"prompt"}),
|
||||
mm_embeddings=embeddings,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -422,10 +422,10 @@ async def generate_chat_completion(
|
|||
gen_tasks.append(
|
||||
asyncio.create_task(
|
||||
model.container.generate(
|
||||
prompt,
|
||||
request.state.id,
|
||||
embeddings=embeddings,
|
||||
**data.model_dump(exclude={"prompt"}),
|
||||
prompt,
|
||||
data,
|
||||
mm_embeddings=embeddings,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
@ -465,7 +465,6 @@ async def generate_tool_calls(
|
|||
# FIXME: May not be necessary depending on how the codebase evolves
|
||||
tool_data = data.model_copy(deep=True)
|
||||
tool_data.json_schema = tool_data.tool_call_schema
|
||||
gen_params = tool_data.model_dump()
|
||||
|
||||
for idx, gen in enumerate(generations):
|
||||
if gen["stop_str"] in tool_data.tool_call_start:
|
||||
|
|
@ -488,10 +487,10 @@ async def generate_tool_calls(
|
|||
gen_tasks.append(
|
||||
asyncio.create_task(
|
||||
model.container.generate(
|
||||
pre_tool_prompt,
|
||||
request.state.id,
|
||||
pre_tool_prompt,
|
||||
tool_data,
|
||||
embeddings=mm_embeddings,
|
||||
**gen_params,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
|
|||
|
|
@ -8,12 +8,12 @@ import asyncio
|
|||
import pathlib
|
||||
from asyncio import CancelledError
|
||||
from fastapi import HTTPException, Request
|
||||
from typing import List, Union
|
||||
|
||||
from loguru import logger
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from common import model
|
||||
from common.auth import get_key_permission
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.networking import (
|
||||
get_generator_error,
|
||||
handle_request_disconnect,
|
||||
|
|
@ -86,16 +86,21 @@ def _create_response(
|
|||
async def _stream_collector(
|
||||
task_idx: int,
|
||||
gen_queue: asyncio.Queue,
|
||||
prompt: str,
|
||||
request_id: str,
|
||||
prompt: str,
|
||||
params: CompletionRequest,
|
||||
abort_event: asyncio.Event,
|
||||
**kwargs,
|
||||
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None,
|
||||
):
|
||||
"""Collects a stream and places results in a common queue"""
|
||||
|
||||
try:
|
||||
new_generation = model.container.generate_gen(
|
||||
prompt, request_id, abort_event, **kwargs
|
||||
request_id,
|
||||
prompt,
|
||||
params,
|
||||
abort_event,
|
||||
mm_embeddings,
|
||||
)
|
||||
async for generation in new_generation:
|
||||
generation["index"] = task_idx
|
||||
|
|
@ -195,10 +200,10 @@ async def stream_generate_completion(
|
|||
_stream_collector(
|
||||
n,
|
||||
gen_queue,
|
||||
data.prompt,
|
||||
request.state.id,
|
||||
data.prompt,
|
||||
task_gen_params,
|
||||
abort_event,
|
||||
**task_gen_params.model_dump(exclude={"prompt"}),
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -256,9 +261,9 @@ async def generate_completion(
|
|||
gen_tasks.append(
|
||||
asyncio.create_task(
|
||||
model.container.generate(
|
||||
data.prompt,
|
||||
request.state.id,
|
||||
**task_gen_params.model_dump(exclude={"prompt"}),
|
||||
data.prompt,
|
||||
task_gen_params,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue