Fix Tool Call JSON Serialization Error (#302)
* Fix Tool Call JSON Serialization Error * Incorporate changes from PR 292 kingbri note: Adjusts the tool JSON formation and incorporates finish reasons. Added both authors as co-authors due to edits on this commit from the original PR. Co-Authored-by: David Allada <dallada1@vt.edu> Co-Authored-by: Benjamin Oldenburg <benjamin.oldenburg@ordis.co.th> Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com> * API: Cleanup tool call JSON parsing Split pre and post-processing of tool calls to its own class. This cleans up the chat_completion utility module and also fixes the JSON serialization bug. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com> --------- Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com> Co-authored-by: David Allada <dallada1@vt.edu> Co-authored-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
de77955428
commit
a2a14ea114
1 changed files with 15 additions and 18 deletions
|
|
@ -30,7 +30,7 @@ from endpoints.OAI.types.chat_completion import (
|
|||
)
|
||||
from endpoints.OAI.types.common import UsageStats
|
||||
from endpoints.OAI.utils.completion import _stream_collector
|
||||
from endpoints.OAI.types.tools import ToolCall
|
||||
from endpoints.OAI.utils.tools import ToolCallProcessor
|
||||
|
||||
|
||||
def _create_response(
|
||||
|
|
@ -49,7 +49,7 @@ def _create_response(
|
|||
|
||||
tool_calls = generation["tool_calls"]
|
||||
if tool_calls:
|
||||
message.tool_calls = postprocess_tool_call(tool_calls)
|
||||
message.tool_calls = ToolCallProcessor.from_json(tool_calls)
|
||||
|
||||
logprob_response = None
|
||||
|
||||
|
|
@ -74,9 +74,14 @@ def _create_response(
|
|||
|
||||
logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
|
||||
|
||||
# Finish reason will always be present in a completion
|
||||
# If a tool call is present, mark the finish reason as such
|
||||
if message.tool_calls:
|
||||
finish_reason = "tool_calls"
|
||||
|
||||
choice = ChatCompletionRespChoice(
|
||||
index=index,
|
||||
finish_reason=generation.get("finish_reason"),
|
||||
finish_reason=finish_reason,
|
||||
stop_str=generation.get("stop_str"),
|
||||
message=message,
|
||||
logprobs=logprob_response,
|
||||
|
|
@ -120,18 +125,19 @@ def _create_stream_chunk(
|
|||
total_tokens=prompt_tokens + completion_tokens,
|
||||
)
|
||||
elif "finish_reason" in generation:
|
||||
choice = ChatCompletionStreamChoice(
|
||||
index=index,
|
||||
finish_reason=generation.get("finish_reason"),
|
||||
)
|
||||
# Get the finish reason from the generation
|
||||
finish_reason = generation.get("finish_reason")
|
||||
choice = ChatCompletionStreamChoice(index=index, finish_reason=finish_reason)
|
||||
|
||||
# lets check if we have tool calls since we are at the end of the generation
|
||||
# Mark finish_reason as tool_calls since this is the last chunk
|
||||
if "tool_calls" in generation:
|
||||
tool_calls = generation["tool_calls"]
|
||||
message = ChatCompletionMessage(
|
||||
tool_calls=postprocess_tool_call(tool_calls)
|
||||
tool_calls=ToolCallProcessor.from_json(tool_calls)
|
||||
)
|
||||
choice.delta = message
|
||||
choice.finish_reason = "tool_calls"
|
||||
|
||||
choices.append(choice)
|
||||
|
||||
|
|
@ -224,7 +230,7 @@ async def format_messages_with_template(
|
|||
message.content = concatenated_content
|
||||
|
||||
if message.tool_calls:
|
||||
message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
|
||||
message.tool_calls_json = ToolCallProcessor.to_json(message.tool_calls)
|
||||
|
||||
special_tokens_dict = model.container.get_special_tokens(
|
||||
add_bos_token, ban_eos_token
|
||||
|
|
@ -482,12 +488,3 @@ async def generate_tool_calls(
|
|||
generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
|
||||
|
||||
return generations
|
||||
|
||||
|
||||
def postprocess_tool_call(call_str: str) -> List[ToolCall]:
|
||||
tool_calls = json.loads(call_str)
|
||||
for tool_call in tool_calls:
|
||||
tool_call["function"]["arguments"] = json.dumps(
|
||||
tool_call["function"]["arguments"]
|
||||
)
|
||||
return [ToolCall(**tool_call) for tool_call in tool_calls]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue