API: Fix tool call serialization
Reduce the checks around tool call start tokens so they render in the template, and remove the line that converts message.tool_calls to a dict, since that conversion breaks the rest of the chain by disconnecting the types; model_dump on the message itself already accomplishes it.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
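A minimal sketch of why the manual dict conversion is redundant, assuming Pydantic v2 models; the ToolCall and ChatCompletionMessage fields below are illustrative stand-ins, not the project's actual class definitions:

# Illustrative models only; field names are assumptions, not tabbyAPI's classes.
from typing import List, Optional

from pydantic import BaseModel


class ToolCall(BaseModel):
    id: str
    function: dict


class ChatCompletionMessage(BaseModel):
    role: str
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None


message = ChatCompletionMessage(
    role="assistant",
    tool_calls=[ToolCall(id="call_1", function={"name": "get_weather"})],
)

# model_dump() already converts the nested ToolCall models into plain dicts,
# while message.tool_calls itself stays a typed List[ToolCall].
print(message.model_dump(exclude_none=True))
# {'role': 'assistant', 'tool_calls': [{'id': 'call_1', 'function': {'name': 'get_weather'}}]}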
This commit is contained in:
parent d23fefbecd
commit b6a26da50c
2 changed files with 3 additions and 8 deletions
@@ -56,7 +56,7 @@ class ChatCompletionStreamChoice(BaseModel):
 
 # Inherited from common request
 class ChatCompletionRequest(CommonCompletionRequest):
-    messages: List[ChatCompletionMessage] = Field(default_factory=list)
+    messages: List[ChatCompletionMessage]
     prompt_template: Optional[str] = None
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = Field(
@@ -207,12 +207,11 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars:
     if isinstance(data.stop, str):
         data.stop = [data.stop] + template_metadata.stop_strings
     else:
-        data.stop += template_metadata.stop_strings
+        data.stop.extend(template_metadata.stop_strings)
 
     # Tool call start strings
     if template_metadata.tool_starts:
-        if data.tool_call_start is None:
-            data.tool_call_start = template_metadata.tool_starts
+        data.tool_call_start.extend(template_metadata.tool_starts)
 
         # Append to stop strings to halt for a tool call generation
         data.stop.extend(template_metadata.tool_starts)
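The hunk above drops the None check: template tool start tokens are now always merged into the request's tool_call_start list instead of only being assigned when the client left it unset (the new code presumably relies on tool_call_start defaulting to a list elsewhere in the request model). A rough before/after sketch with plain lists and made-up values:

# Made-up values purely to show the control-flow difference.
template_tool_starts = ["<tool_call>"]

# Old behavior: template starts were only used when nothing was set,
# so a client-provided value caused them to be skipped entirely.
tool_call_start = ["<client_start>"]
if tool_call_start is None:
    tool_call_start = template_tool_starts
print(tool_call_start)  # -> ['<client_start>']

# New behavior: template starts are appended to whatever is already there.
tool_call_start = ["<client_start>"]
tool_call_start.extend(template_tool_starts)
print(tool_call_start)  # -> ['<client_start>', '<tool_call>']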
@@ -245,10 +244,6 @@ async def format_messages_with_template(
         if message.tool_calls:
             message.tool_calls_json = ToolCallProcessor.to_json(message.tool_calls)
 
-            # The tools variable is inspectable in the template, so
-            # store the list of dicts rather than the ToolCallProcessor object.
-            message.tool_calls = ToolCallProcessor.dump(message.tool_calls)
-
         message_dicts.append(message.model_dump(exclude_none=True))
 
     # Get all special tokens