Fix Tool Call JSON Serialization Error (#302)

* Fix Tool Call JSON Serialization Error

* Incorporate changes from PR 292

kingbri note: Adjusts how the tool call JSON is formed and incorporates
finish reasons. Both authors were added as co-authors because this
commit contains edits from the original PR.

Co-Authored-by: David Allada <dallada1@vt.edu>
Co-Authored-by: Benjamin Oldenburg <benjamin.oldenburg@ordis.co.th>
Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>

* API: Cleanup tool call JSON parsing

Split pre- and post-processing of tool calls into its own class. This
cleans up the chat_completion utility module and also fixes the
JSON serialization bug.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
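
For orientation, here is a rough sketch of the shape the new class could take. The `ToolCallProcessor` name and the `from_json`/`to_json` entry points come from the diff below; everything else (the Pydantic models and method bodies) is an illustrative assumption inferred from the removed `postprocess_tool_call` helper, not the actual code in `endpoints/OAI/utils/tools.py`:

```python
import json
from typing import List

from pydantic import BaseModel


class Function(BaseModel):
    # Hypothetical stand-in for the project's Function type
    name: str
    arguments: str  # OpenAI-style tool calls keep arguments as a JSON string


class ToolCall(BaseModel):
    # Hypothetical stand-in for endpoints.OAI.types.tools.ToolCall
    id: str
    function: Function
    type: str = "function"


class ToolCallProcessor:
    """Illustrative sketch of pre/post-processing for tool call JSON."""

    @staticmethod
    def from_json(call_str: str) -> List[ToolCall]:
        # Parse the model's raw tool call output into typed objects,
        # re-encoding each arguments object as the JSON string the
        # OpenAI schema expects
        tool_calls = json.loads(call_str)
        for tool_call in tool_calls:
            arguments = tool_call["function"]["arguments"]
            if not isinstance(arguments, str):
                tool_call["function"]["arguments"] = json.dumps(arguments)
        return [ToolCall(**tool_call) for tool_call in tool_calls]

    @staticmethod
    def to_json(tool_calls: List[ToolCall]) -> str:
        # Dump the Pydantic models to plain dicts before serializing;
        # passing the models straight to json.dumps raises a TypeError
        return json.dumps(
            [tool_call.model_dump() for tool_call in tool_calls], indent=2
        )
```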

---------

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
Co-authored-by: David Allada <dallada1@vt.edu>
Co-authored-by: kingbri <8082010+kingbri1@users.noreply.github.com>

@@ -30,7 +30,7 @@ from endpoints.OAI.types.chat_completion import (
 )
 from endpoints.OAI.types.common import UsageStats
 from endpoints.OAI.utils.completion import _stream_collector
-from endpoints.OAI.types.tools import ToolCall
+from endpoints.OAI.utils.tools import ToolCallProcessor
 
 
 def _create_response(
@@ -49,7 +49,7 @@ def _create_response(
 
         tool_calls = generation["tool_calls"]
         if tool_calls:
-            message.tool_calls = postprocess_tool_call(tool_calls)
+            message.tool_calls = ToolCallProcessor.from_json(tool_calls)
 
         logprob_response = None
 
@@ -74,9 +74,14 @@ def _create_response(
             logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
 
         # Finish reason will always be present in a completion
+        # If a tool call is present, mark the finish reason as such
+        finish_reason = generation.get("finish_reason")
+        if message.tool_calls:
+            finish_reason = "tool_calls"
+
         choice = ChatCompletionRespChoice(
             index=index,
-            finish_reason=generation.get("finish_reason"),
+            finish_reason=finish_reason,
             stop_str=generation.get("stop_str"),
             message=message,
             logprobs=logprob_response,
@@ -120,18 +125,19 @@ def _create_stream_chunk(
             total_tokens=prompt_tokens + completion_tokens,
         )
 
     elif "finish_reason" in generation:
-        choice = ChatCompletionStreamChoice(
-            index=index,
-            finish_reason=generation.get("finish_reason"),
-        )
+        # Get the finish reason from the generation
+        finish_reason = generation.get("finish_reason")
+        choice = ChatCompletionStreamChoice(index=index, finish_reason=finish_reason)
 
         # lets check if we have tool calls since we are at the end of the generation
+        # Mark finish_reason as tool_calls since this is the last chunk
         if "tool_calls" in generation:
             tool_calls = generation["tool_calls"]
             message = ChatCompletionMessage(
-                tool_calls=postprocess_tool_call(tool_calls)
+                tool_calls=ToolCallProcessor.from_json(tool_calls)
             )
             choice.delta = message
+            choice.finish_reason = "tool_calls"
 
         choices.append(choice)
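
To make the streaming behavior concrete: with the change above, the final chunk of a stream that produced tool calls carries both the parsed tool calls in its delta and a `tool_calls` finish reason. A client would see a last choice roughly shaped like this sketch (field names follow the OpenAI chat completion schema; the values are invented):

```python
# Illustrative shape of the final streamed choice when the model emitted
# a tool call; field names follow the OpenAI schema, values are invented
final_choice = {
    "index": 0,
    "finish_reason": "tool_calls",  # set on the last chunk, per the diff above
    "delta": {
        "tool_calls": [
            {
                "id": "call_0",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    # arguments remain a JSON-encoded string
                    "arguments": "{\"city\": \"Tokyo\"}",
                },
            }
        ]
    },
}
```
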
@@ -224,7 +230,7 @@ async def format_messages_with_template(
             message.content = concatenated_content
 
         if message.tool_calls:
-            message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
+            message.tool_calls_json = ToolCallProcessor.to_json(message.tool_calls)
 
     special_tokens_dict = model.container.get_special_tokens(
         add_bos_token, ban_eos_token
@@ -482,12 +488,3 @@ async def generate_tool_calls(
             generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
 
     return generations
-
-
-def postprocess_tool_call(call_str: str) -> List[ToolCall]:
-    tool_calls = json.loads(call_str)
-    for tool_call in tool_calls:
-        tool_call["function"]["arguments"] = json.dumps(
-            tool_call["function"]["arguments"]
-        )
-    return [ToolCall(**tool_call) for tool_call in tool_calls]
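
For reference, the serialization half of the bug is visible in the `format_messages_with_template` hunk above: the old code passed a list of `ToolCall` Pydantic models straight to `json.dumps`, which cannot encode arbitrary objects. A minimal reproduction under that assumption, reusing the same illustrative model definitions as the sketch above:

```python
import json

from pydantic import BaseModel


class Function(BaseModel):
    # Same hypothetical models as in the sketch above
    name: str
    arguments: str


class ToolCall(BaseModel):
    id: str
    function: Function
    type: str = "function"


call = ToolCall(id="call_0", function=Function(name="noop", arguments="{}"))

try:
    # Old code path: serializing the models directly fails
    json.dumps([call], indent=2)
except TypeError as exc:
    print(exc)  # Object of type ToolCall is not JSON serializable

# Converting to plain dicts first (as a to_json helper can do) works
print(json.dumps([call.model_dump()], indent=2))
```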