OAI: support stream_options argument

This commit is contained in:
Volodymyr Kuznetsov 2024-07-08 13:42:54 -07:00
parent 073e9fa6f0
commit b149d3398d
3 changed files with 24 additions and 1 deletion

View file

@ -246,6 +246,7 @@ async def stream_generate_chat_completion(
gen_queue = asyncio.Queue()
gen_tasks: List[asyncio.Task] = []
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
need_usage = data.stream_options and data.stream_options.include_usage
try:
gen_params = data.to_gen_params()
@ -275,10 +276,26 @@ async def stream_generate_chat_completion(
raise generation
response = _create_stream_chunk(const_id, generation, model_path.name)
yield response.model_dump_json()
yield response.model_dump_json(exclude=None if need_usage else "usage")
# Check if all tasks are completed
if all(task.done() for task in gen_tasks) and gen_queue.empty():
if need_usage:
prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
completion_tokens = unwrap(generation.get("generated_tokens"), 0)
response = ChatCompletionStreamChunk(
id=const_id,
choices=[],
model=unwrap(model_path.name, ""),
usage=UsageStats(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
yield response.model_dump_json()
break
except CancelledError:
# Get out if the request gets disconnected