OAI: Fix usage chunk return

Place the logic in its proper utility function and clean up
the code formatting.

Also, OAI's docs specify that a [DONE] message must be sent once everything
is finished.
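
For context, OpenAI-compatible streams deliver each chunk as an SSE event and
terminate with a literal [DONE] payload. A minimal consumer-side sketch, not
part of this commit (collect_stream and the events iterable are hypothetical,
and each payload is assumed to be the data portion of one SSE event with the
"data: " prefix already stripped):

import json

def collect_stream(events):
    """Gather streamed chunk payloads until the [DONE] sentinel arrives."""
    chunks = []
    for payload in events:
        if payload == "[DONE]":
            break  # the server is done; stop reading
        chunks.append(json.loads(payload))
    return chunks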

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-07-12 14:35:48 -04:00
parent b149d3398d
commit c1b61441f4

@@ -93,22 +93,37 @@ def _create_stream_chunk(
     const_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
+    is_usage_chunk: bool = False,
 ):
     """Create a chat completion stream chunk from the provided text."""
     index = generation.get("index")
-    logprob_response = None
+    choices = []
+    usage_stats = None
-    if "finish_reason" in generation:
+    if is_usage_chunk:
+        prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+        usage_stats = UsageStats(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+    elif "finish_reason" in generation:
         choice = ChatCompletionStreamChoice(
             index=index,
             finish_reason=generation.get("finish_reason"),
         )
+        choices.append(choice)
     else:
         message = ChatCompletionMessage(
             role="assistant", content=unwrap(generation.get("text"), "")
         )
+        logprob_response = None
         token_probs = unwrap(generation.get("token_probs"), {})
         if token_probs:
             logprobs = unwrap(generation.get("logprobs"), {})
@@ -132,8 +147,13 @@ def _create_stream_chunk(
             logprobs=logprob_response,
         )
+        choices.append(choice)
     chunk = ChatCompletionStreamChunk(
-        id=const_id, choices=[choice], model=unwrap(model_name, "")
+        id=const_id,
+        choices=choices,
+        model=unwrap(model_name, ""),
+        usage=usage_stats,
     )
     return chunk
@@ -246,7 +266,6 @@ async def stream_generate_chat_completion(
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
-    need_usage = data.stream_options and data.stream_options.include_usage
     try:
         gen_params = data.to_gen_params()
@@ -276,26 +295,18 @@ async def stream_generate_chat_completion(
                 raise generation
             response = _create_stream_chunk(const_id, generation, model_path.name)
-            yield response.model_dump_json(exclude=None if need_usage else "usage")
+            yield response.model_dump_json()
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
-                if need_usage:
-                    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-                    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
-                    response = ChatCompletionStreamChunk(
-                        id=const_id,
-                        choices=[],
-                        model=unwrap(model_path.name, ""),
-                        usage=UsageStats(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        ),
+                # Send a usage chunk
+                if data.stream_options and data.stream_options.include_usage:
+                    usage_chunk = _create_stream_chunk(
+                        const_id, generation, model_path.name, is_usage_chunk=True
                     )
+                    yield usage_chunk.model_dump_json()
-                    yield response.model_dump_json()
+                yield "[DONE]"
                 break
     except CancelledError:
         # Get out if the request gets disconnected