OAI: Fix usage chunk return
Place the logic into its proper utility function and clean up the code formatting. Also, OAI's docs specify that a [DONE] message must be returned when everything is finished.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent: b149d3398d
commit: c1b61441f4

1 changed file with 30 additions and 19 deletions
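For reference, OpenAI's streaming API terminates every stream with a literal "data: [DONE]" line, and when the request sets stream_options.include_usage the last data chunk before [DONE] carries a usage object alongside an empty choices array. Below is a minimal client-side sketch of consuming such a stream; the URL, port, and payload are illustrative assumptions, not part of this commit:

# Client-side sketch (assumed URL and payload; not part of this commit).
import json

import requests

payload = {
    "model": "example-model",
    "messages": [{"role": "user", "content": "Hello"}],
    "stream": True,
    "stream_options": {"include_usage": True},
}

with requests.post(
    "http://localhost:5000/v1/chat/completions", json=payload, stream=True
) as resp:
    for raw in resp.iter_lines():
        if not raw:
            continue
        line = raw.decode("utf-8").removeprefix("data: ")
        if line == "[DONE]":
            # Per OAI's docs, the stream is finished
            break
        chunk = json.loads(line)
        if chunk.get("usage"):
            # Final usage chunk: empty choices, token counts in "usage"
            print("usage:", chunk["usage"])
        else:
            print("chunk:", chunk["id"], len(chunk["choices"]), "choice(s)")

The chunk fields read here (id, choices, model, usage) correspond to the ChatCompletionStreamChunk fields populated in the diff below.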
@@ -93,22 +93,37 @@ def _create_stream_chunk(
     const_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
+    is_usage_chunk: bool = False,
 ):
     """Create a chat completion stream chunk from the provided text."""
     index = generation.get("index")
     logprob_response = None
-    if "finish_reason" in generation:
+    choices = []
+    usage_stats = None
+
+    if is_usage_chunk:
+        prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
+        completion_tokens = unwrap(generation.get("generated_tokens"), 0)
+
+        usage_stats = UsageStats(
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=prompt_tokens + completion_tokens,
+        )
+    elif "finish_reason" in generation:
         choice = ChatCompletionStreamChoice(
             index=index,
             finish_reason=generation.get("finish_reason"),
         )
+
+        choices.append(choice)
     else:
         message = ChatCompletionMessage(
             role="assistant", content=unwrap(generation.get("text"), "")
         )
 
         logprob_response = None
 
         token_probs = unwrap(generation.get("token_probs"), {})
         if token_probs:
             logprobs = unwrap(generation.get("logprobs"), {})
@@ -132,8 +147,13 @@ def _create_stream_chunk(
             logprobs=logprob_response,
         )
 
+        choices.append(choice)
+
     chunk = ChatCompletionStreamChunk(
-        id=const_id, choices=[choice], model=unwrap(model_name, "")
+        id=const_id,
+        choices=choices,
+        model=unwrap(model_name, ""),
+        usage=usage_stats,
     )
 
     return chunk
@@ -246,7 +266,6 @@ async def stream_generate_chat_completion(
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
-    need_usage = data.stream_options and data.stream_options.include_usage
 
     try:
         gen_params = data.to_gen_params()
@@ -276,26 +295,18 @@ async def stream_generate_chat_completion(
                 raise generation
 
             response = _create_stream_chunk(const_id, generation, model_path.name)
-            yield response.model_dump_json(exclude=None if need_usage else "usage")
+            yield response.model_dump_json()
 
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
-                if need_usage:
-                    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
-                    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
-
-                    response = ChatCompletionStreamChunk(
-                        id=const_id,
-                        choices=[],
-                        model=unwrap(model_path.name, ""),
-                        usage=UsageStats(
-                            prompt_tokens=prompt_tokens,
-                            completion_tokens=completion_tokens,
-                            total_tokens=prompt_tokens + completion_tokens,
-                        ),
-                    )
+                # Send a usage chunk
+                if data.stream_options and data.stream_options.include_usage:
+                    usage_chunk = _create_stream_chunk(
+                        const_id, generation, model_path.name, is_usage_chunk=True
+                    )
+                    yield usage_chunk.model_dump_json()
 
-                    yield response.model_dump_json()
+                yield "[DONE]"
                 break
     except CancelledError:
         # Get out if the request gets disconnected