OAI: support stream_options argument
parent 073e9fa6f0
commit b149d3398d
3 changed files with 24 additions and 1 deletion
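The new argument mirrors `stream_options` from OpenAI's Chat Completions API: when `include_usage` is true, the stream ends with one extra chunk whose `choices` array is empty and whose `usage` field carries the token counts. A minimal sketch of the request shape this enables (endpoint, model name, and prompt are illustrative placeholders, not taken from the commit):

    import requests

    resp = requests.post(
        "http://localhost:5000/v1/chat/completions",  # placeholder endpoint
        json={
            "model": "my-model",  # placeholder model name
            "messages": [{"role": "user", "content": "Hello"}],
            "stream": True,
            "stream_options": {"include_usage": True},
        },
        stream=True,  # read the SSE chunks incrementally
    )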
@@ -64,3 +64,4 @@ class ChatCompletionStreamChunk(BaseModel):
     created: int = Field(default_factory=lambda: int(time()))
     model: str
     object: str = "chat.completion.chunk"
+    usage: Optional[UsageStats] = None
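The new `usage` field defaults to None, so ordinary delta chunks are unaffected; it is only populated in the closing chunk built later in this commit. With `include_usage` set, that final chunk would serialize to roughly this shape (id, timestamp, and counts are illustrative):

    {"id": "chatcmpl-...", "choices": [], "created": 1700000000,
     "model": "my-model", "object": "chat.completion.chunk",
     "usage": {"prompt_tokens": 12, "completion_tokens": 34, "total_tokens": 46}}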
@@ -18,6 +18,10 @@ class CompletionResponseFormat(BaseModel):
     type: str = "text"
 
 
+class ChatCompletionStreamOptions(BaseModel):
+    include_usage: Optional[bool] = False
+
+
 class CommonCompletionRequest(BaseSamplerRequest):
     """Represents a common completion request."""
 
@@ -27,6 +31,7 @@ class CommonCompletionRequest(BaseSamplerRequest):
 
     # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False
+    stream_options: Optional[ChatCompletionStreamOptions] = None
     logprobs: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("logprobs", 0)
     )
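Both model additions are plain nested pydantic models, so a JSON body carrying a `stream_options` object validates straight into the request. A self-contained sketch of that behavior (the reduced `Request` stand-in and sample payload are assumptions; the options class matches the diff):

    from typing import Optional
    from pydantic import BaseModel

    class ChatCompletionStreamOptions(BaseModel):
        include_usage: Optional[bool] = False

    class Request(BaseModel):  # stand-in for CommonCompletionRequest
        stream: Optional[bool] = False
        stream_options: Optional[ChatCompletionStreamOptions] = None

    req = Request.model_validate(
        {"stream": True, "stream_options": {"include_usage": True}}
    )
    assert req.stream_options.include_usage is True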
@@ -246,6 +246,7 @@ async def stream_generate_chat_completion(
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
+    need_usage = data.stream_options and data.stream_options.include_usage
 
     try:
         gen_params = data.to_gen_params()
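Since `stream_options` defaults to None, the `and` chain short-circuits and `need_usage` never raises on a missing object:

    # Behavior of the need_usage expression (cases, not literal code paths):
    # stream_options omitted from the request    -> need_usage is None  (falsy)
    # include_usage left at its False default    -> need_usage is False
    # "stream_options": {"include_usage": true}  -> need_usage is True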
@@ -275,10 +276,26 @@
                 raise generation
 
             response = _create_stream_chunk(const_id, generation, model_path.name)
-            yield response.model_dump_json()
+            yield response.model_dump_json(exclude=None if need_usage else "usage")
 
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
+                if need_usage:
+                    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+                    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+
+                    response = ChatCompletionStreamChunk(
+                        id=const_id,
+                        choices=[],
+                        model=unwrap(model_path.name, ""),
+                        usage=UsageStats(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
+                    )
+
+                    yield response.model_dump_json()
                 break
     except CancelledError:
         # Get out if the request gets disconnected
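The reworked yield relies on pydantic v2's `exclude` argument to keep the `usage` key out of intermediate chunks. A sketch of that serialization behavior, using the set form of `exclude` that pydantic documents (constructor values are illustrative):

    chunk = ChatCompletionStreamChunk(id="chatcmpl-1", choices=[], model="my-model")
    chunk.model_dump_json(exclude={"usage"})  # intermediate chunk: no usage key
    chunk.model_dump_json(exclude=None)       # final form: usage included (null until populated)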