OAI: support stream_options argument

Volodymyr Kuznetsov 2024-07-08 13:42:54 -07:00
parent 073e9fa6f0
commit b149d3398d
3 changed files with 24 additions and 1 deletion
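This mirrors the OpenAI streaming API: a client opts into a trailing usage chunk through stream_options. A minimal client-side sketch with the openai Python package pointed at a local server; the base URL, port, API key, and model name below are placeholders, not part of this commit:

# Sketch of a client opting into the usage chunk; base_url, api_key, and model are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:5000/v1", api_key="unused")

stream = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
    stream_options={"include_usage": True},  # the argument this commit adds support for
)

for chunk in stream:
    if chunk.usage is not None:  # only the final chunk should carry token totals
        print(chunk.usage.prompt_tokens, chunk.usage.completion_tokens, chunk.usage.total_tokens)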

View file

@@ -64,3 +64,4 @@ class ChatCompletionStreamChunk(BaseModel):
     created: int = Field(default_factory=lambda: int(time()))
     model: str
     object: str = "chat.completion.chunk"
+    usage: Optional[UsageStats] = None
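The field defaults to None and the handler excludes it from ordinary content chunks, so only the closing chunk carries a UsageStats payload. A standalone sketch of the effect on serialization; the id/choices types are simplified and UsageStats is reconstructed from the field names used later in this commit:

# Standalone sketch; id/choices types are simplified, UsageStats reconstructed from this commit.
from time import time
from typing import List, Optional

from pydantic import BaseModel, Field


class UsageStats(BaseModel):
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


class ChatCompletionStreamChunk(BaseModel):
    id: str
    choices: List[dict] = []
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    object: str = "chat.completion.chunk"
    usage: Optional[UsageStats] = None


final_chunk = ChatCompletionStreamChunk(
    id="chatcmpl-123",
    model="my-model",
    usage=UsageStats(prompt_tokens=12, completion_tokens=34, total_tokens=46),
)
print(final_chunk.model_dump_json())  # the JSON now carries a "usage" object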

View file

@@ -18,6 +18,10 @@ class CompletionResponseFormat(BaseModel):
     type: str = "text"
 
 
+class ChatCompletionStreamOptions(BaseModel):
+    include_usage: Optional[bool] = False
+
+
 class CommonCompletionRequest(BaseSamplerRequest):
     """Represents a common completion request."""
@@ -27,6 +31,7 @@ class CommonCompletionRequest(BaseSamplerRequest):
     # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False
+    stream_options: Optional[ChatCompletionStreamOptions] = None
     logprobs: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("logprobs", 0)
     )
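Because include_usage defaults to False and stream_options itself defaults to None, behavior only changes for clients that explicitly ask for usage. A quick standalone check of the new options model; the class definition is copied from this hunk, the rest is illustration:

# Standalone check of the new options model; definition copied from this hunk.
from typing import Optional

from pydantic import BaseModel


class ChatCompletionStreamOptions(BaseModel):
    include_usage: Optional[bool] = False


print(ChatCompletionStreamOptions.model_validate({"include_usage": True}).include_usage)  # True
print(ChatCompletionStreamOptions().include_usage)  # False when the client omits the field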

View file

@@ -246,6 +246,7 @@ async def stream_generate_chat_completion(
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
+    need_usage = data.stream_options and data.stream_options.include_usage
 
     try:
         gen_params = data.to_gen_params()
@@ -275,10 +276,26 @@ async def stream_generate_chat_completion(
                 raise generation
 
             response = _create_stream_chunk(const_id, generation, model_path.name)
-            yield response.model_dump_json()
+            yield response.model_dump_json(exclude=None if need_usage else "usage")
 
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
+                if need_usage:
+                    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+                    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+
+                    response = ChatCompletionStreamChunk(
+                        id=const_id,
+                        choices=[],
+                        model=unwrap(model_path.name, ""),
+                        usage=UsageStats(
+                            prompt_tokens=prompt_tokens,
+                            completion_tokens=completion_tokens,
+                            total_tokens=prompt_tokens + completion_tokens,
+                        ),
+                    )
+                    yield response.model_dump_json()
+
                 break
     except CancelledError:
         # Get out if the request gets disconnected
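On the wire, the totals arrive as one extra chunk with an empty choices list just before the stream ends. A rough consumer sketch over the raw stream; the endpoint URL, the "data: " prefix, and the "[DONE]" sentinel are assumptions about the server's SSE framing rather than something shown in this diff:

# Rough SSE consumer; URL, "data: " prefix, and "[DONE]" sentinel are assumptions.
import json

import requests

payload = {
    "model": "my-model",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
    "stream_options": {"include_usage": True},
}

with requests.post(
    "http://127.0.0.1:5000/v1/chat/completions", json=payload, stream=True, timeout=600
) as resp:
    for raw_line in resp.iter_lines():
        if not raw_line:
            continue
        line = raw_line.decode()
        if not line.startswith("data: ") or line.endswith("[DONE]"):
            continue
        chunk = json.loads(line[len("data: "):])
        if chunk.get("usage"):  # only the final chunk carries token totals
            print(chunk["usage"])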