From a2a14ea1148e0872fd513708c40582d56b563874 Mon Sep 17 00:00:00 2001
From: Benjamin Oldenburg
Date: Sat, 15 Mar 2025 02:01:33 +0700
Subject: [PATCH] Fix Tool Call JSON Serialization Error (#302)

* Fix Tool Call JSON Serialization Error

* Incorporate changes from PR 292

kingbri note: Adjusts the tool JSON formation and incorporates finish
reasons. Added both authors as co-authors due to edits on this commit
from the original PR.

Co-Authored-by: David Allada
Co-Authored-by: Benjamin Oldenburg
Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>

* API: Cleanup tool call JSON parsing

Split pre- and post-processing of tool calls into its own class. This
cleans up the chat_completion utility module and also fixes the JSON
serialization bug.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>

---------

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
Co-authored-by: David Allada
Co-authored-by: kingbri <8082010+kingbri1@users.noreply.github.com>
---
 endpoints/OAI/utils/chat_completion.py | 33 ++++++++++++--------------
 1 file changed, 15 insertions(+), 18 deletions(-)

diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index a646924..9393e90 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -30,7 +30,7 @@ from endpoints.OAI.types.chat_completion import (
 )
 from endpoints.OAI.types.common import UsageStats
 from endpoints.OAI.utils.completion import _stream_collector
-from endpoints.OAI.types.tools import ToolCall
+from endpoints.OAI.utils.tools import ToolCallProcessor
 
 
 def _create_response(
@@ -49,7 +49,7 @@ def _create_response(
 
         tool_calls = generation["tool_calls"]
         if tool_calls:
-            message.tool_calls = postprocess_tool_call(tool_calls)
+            message.tool_calls = ToolCallProcessor.from_json(tool_calls)
 
         logprob_response = None
 
@@ -74,9 +74,14 @@
 
             logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
 
+        # Finish reason will always be present in a completion
+        # If a tool call is present, mark the finish reason as such
+        if message.tool_calls:
+            finish_reason = "tool_calls"
+
         choice = ChatCompletionRespChoice(
             index=index,
-            finish_reason=generation.get("finish_reason"),
+            finish_reason=finish_reason,
             stop_str=generation.get("stop_str"),
             message=message,
             logprobs=logprob_response,
@@ -120,18 +125,19 @@
                 total_tokens=prompt_tokens + completion_tokens,
             )
     elif "finish_reason" in generation:
-        choice = ChatCompletionStreamChoice(
-            index=index,
-            finish_reason=generation.get("finish_reason"),
-        )
+        # Get the finish reason from the generation
+        finish_reason = generation.get("finish_reason")
+        choice = ChatCompletionStreamChoice(index=index, finish_reason=finish_reason)
 
         # lets check if we have tool calls since we are at the end of the generation
+        # Mark finish_reason as tool_calls since this is the last chunk
         if "tool_calls" in generation:
             tool_calls = generation["tool_calls"]
             message = ChatCompletionMessage(
-                tool_calls=postprocess_tool_call(tool_calls)
+                tool_calls=ToolCallProcessor.from_json(tool_calls)
             )
             choice.delta = message
+            choice.finish_reason = "tool_calls"
 
         choices.append(choice)
 
@@ -224,7 +230,7 @@ async def format_messages_with_template(
             message.content = concatenated_content
 
         if message.tool_calls:
-            message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
+            message.tool_calls_json = ToolCallProcessor.to_json(message.tool_calls)
 
     special_tokens_dict = model.container.get_special_tokens(
         add_bos_token, ban_eos_token
@@ -482,12 +488,3 @@ async def generate_tool_calls(
             generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
 
     return generations
-
-
-def postprocess_tool_call(call_str: str) -> List[ToolCall]:
-    tool_calls = json.loads(call_str)
-    for tool_call in tool_calls:
-        tool_call["function"]["arguments"] = json.dumps(
-            tool_call["function"]["arguments"]
-        )
-    return [ToolCall(**tool_call) for tool_call in tool_calls]
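
Note: the ToolCallProcessor class itself is outside this diff; per the new
import it lives in endpoints/OAI/utils/tools.py. Below is a minimal,
hypothetical sketch of that module, reconstructed from the call sites above.
The from_json body mirrors the removed postprocess_tool_call helper, while
the to_json body is an assumption: the original
json.dumps(message.tool_calls, indent=2) fails because the ToolCall objects
are not directly JSON-serializable, so the sketch dumps them to plain dicts
first, assuming ToolCall is a Pydantic v2 model.

# Hypothetical reconstruction of endpoints/OAI/utils/tools.py.
# from_json mirrors the deleted postprocess_tool_call helper; to_json is
# an assumption based on the call site in format_messages_with_template.
import json
from typing import List

from endpoints.OAI.types.tools import ToolCall


class ToolCallProcessor:
    """Pre- and post-processing helpers for tool call JSON."""

    @staticmethod
    def from_json(call_str: str) -> List[ToolCall]:
        """Parse a generated tool-call string into ToolCall models.

        The model emits arguments as a parsed object, but the OpenAI
        schema expects the arguments field to be a JSON string, so it
        is re-serialized here before validation.
        """
        tool_calls = json.loads(call_str)
        for tool_call in tool_calls:
            tool_call["function"]["arguments"] = json.dumps(
                tool_call["function"]["arguments"]
            )
        return [ToolCall(**tool_call) for tool_call in tool_calls]

    @staticmethod
    def to_json(tool_calls: List[ToolCall]) -> str:
        """Serialize ToolCall models for injection into the prompt.

        Assumes ToolCall is a Pydantic v2 model: dumping to plain dicts
        first avoids the TypeError that json.dumps raised when handed
        the model objects directly (the bug this patch fixes).
        """
        dumped = [tool_call.model_dump() for tool_call in tool_calls]
        return json.dumps(dumped, indent=2)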