diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index a646924..9393e90 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -30,7 +30,7 @@ from endpoints.OAI.types.chat_completion import (
 )
 from endpoints.OAI.types.common import UsageStats
 from endpoints.OAI.utils.completion import _stream_collector
-from endpoints.OAI.types.tools import ToolCall
+from endpoints.OAI.utils.tools import ToolCallProcessor
 
 
 def _create_response(
@@ -49,7 +49,7 @@ def _create_response(
 
         tool_calls = generation["tool_calls"]
         if tool_calls:
-            message.tool_calls = postprocess_tool_call(tool_calls)
+            message.tool_calls = ToolCallProcessor.from_json(tool_calls)
 
         logprob_response = None
 
@@ -74,9 +74,14 @@ def _create_response(
 
             logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
 
+        # Finish reason will always be present in a completion
+        finish_reason = generation.get("finish_reason")
+        # If a tool call is present, mark the finish reason as such
+        if message.tool_calls:
+            finish_reason = "tool_calls"
         choice = ChatCompletionRespChoice(
             index=index,
-            finish_reason=generation.get("finish_reason"),
+            finish_reason=finish_reason,
             stop_str=generation.get("stop_str"),
             message=message,
             logprobs=logprob_response,
@@ -120,18 +125,19 @@ def _create_stream_chunk(
 
             total_tokens=prompt_tokens + completion_tokens,
         )
     elif "finish_reason" in generation:
-        choice = ChatCompletionStreamChoice(
-            index=index,
-            finish_reason=generation.get("finish_reason"),
-        )
+        # Get the finish reason from the generation
+        finish_reason = generation.get("finish_reason")
+        choice = ChatCompletionStreamChoice(index=index, finish_reason=finish_reason)
 
         # Let's check if we have tool calls since we are at the end of the generation
+        # Mark finish_reason as tool_calls since this is the last chunk
         if "tool_calls" in generation:
             tool_calls = generation["tool_calls"]
             message = ChatCompletionMessage(
-                tool_calls=postprocess_tool_call(tool_calls)
+                tool_calls=ToolCallProcessor.from_json(tool_calls)
             )
             choice.delta = message
+            choice.finish_reason = "tool_calls"
 
     choices.append(choice)
@@ -224,7 +230,7 @@ async def format_messages_with_template(
             message.content = concatenated_content
 
         if message.tool_calls:
-            message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
+            message.tool_calls_json = ToolCallProcessor.to_json(message.tool_calls)
 
     special_tokens_dict = model.container.get_special_tokens(
         add_bos_token, ban_eos_token
@@ -482,12 +488,3 @@ async def generate_tool_calls(
             generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
 
     return generations
-
-
-def postprocess_tool_call(call_str: str) -> List[ToolCall]:
-    tool_calls = json.loads(call_str)
-    for tool_call in tool_calls:
-        tool_call["function"]["arguments"] = json.dumps(
-            tool_call["function"]["arguments"]
-        )
-    return [ToolCall(**tool_call) for tool_call in tool_calls]
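
Note on `ToolCallProcessor`: the new import points at `endpoints/OAI/utils/tools.py`, but that file is not part of this diff, so the class itself is not shown here. Below is a minimal sketch of the interface the call sites imply, inferred from the removed `postprocess_tool_call` helper (`from_json` feeds `message.tool_calls`; `to_json` replaces the `json.dumps(..., indent=2)` call in `format_messages_with_template`). The method bodies and the pydantic `model_dump()` call are assumptions, not code from this PR.

```python
import json
from typing import List

from endpoints.OAI.types.tools import ToolCall


class ToolCallProcessor:
    """Hypothetical sketch: converts between raw tool-call JSON and ToolCall models."""

    @staticmethod
    def from_json(call_str: str) -> List[ToolCall]:
        # Mirrors the removed postprocess_tool_call: parse the model's raw
        # JSON output, then re-serialize each function's arguments dict so
        # the OpenAI spec's string-typed "arguments" field is satisfied
        tool_calls = json.loads(call_str)
        for tool_call in tool_calls:
            tool_call["function"]["arguments"] = json.dumps(
                tool_call["function"]["arguments"]
            )
        return [ToolCall(**tool_call) for tool_call in tool_calls]

    @staticmethod
    def to_json(tool_calls: List[ToolCall]) -> str:
        # Inverse direction for prompt templating: dump the ToolCall models
        # back to a JSON string for the template's tool_calls_json field
        return json.dumps(
            [tool_call.model_dump() for tool_call in tool_calls], indent=2
        )
```

Pairing both conversion directions in one helper keeps the OpenAI-spec quirk (arguments as a JSON-encoded string nested inside JSON) in a single place instead of scattering `json.loads`/`json.dumps` calls across the endpoint code.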