API: Fix types for chat completions

Messages were mistakenly being sent as Pydantic objects, but templates
expect dictionaries. Convert them properly before rendering.

In addition, initialize all Optional lists as empty lists, since this
causes the fewest problems when interacting with other parts of the
API code, such as templates.
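
For reference, a minimal sketch of both fixes (illustrative names only,
assuming Pydantic v2 and a dict-consuming template engine; this is not
the project's actual code):

    from typing import List, Optional
    from pydantic import BaseModel, Field

    class Message(BaseModel):
        role: str = "user"
        # Empty list instead of None so template code can iterate safely
        tool_calls: Optional[List[dict]] = Field(default_factory=list)

    def to_template_input(messages: List[Message]) -> List[dict]:
        # Templates expect plain dicts; model_dump() converts each
        # Pydantic object (including nested models) into a dictionary
        return [message.model_dump() for message in messages]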

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-05-17 18:04:39 -04:00
parent 81170eee00
commit 54b8a20a19
2 changed files with 17 additions and 10 deletions


@@ -11,7 +11,7 @@ from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
 class ChatCompletionLogprob(BaseModel):
     token: str
     logprob: float
-    top_logprobs: Optional[List["ChatCompletionLogprob"]] = None
+    top_logprobs: Optional[List["ChatCompletionLogprob"]] = Field(default_factory=list)


 class ChatCompletionLogprobs(BaseModel):
@@ -30,8 +30,10 @@ class ChatCompletionMessagePart(BaseModel):

 class ChatCompletionMessage(BaseModel):
     role: str = "user"
-    content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
-    tool_calls: Optional[List[ToolCall]] = None
+    content: Optional[Union[str, List[ChatCompletionMessagePart]]] = Field(
+        default_factory=list
+    )
+    tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
     tool_calls_json: SkipJsonSchema[Optional[str]] = None
@@ -76,13 +78,15 @@ class ChatCompletionRequest(CommonCompletionRequest):
     # tools follows the OAI schema format, functions is more flexible
     # both are available in the chat template.
-    tools: Optional[List[ToolSpec]] = None
-    functions: Optional[List[Dict]] = None
+    tools: Optional[List[ToolSpec]] = Field(default_factory=list)
+    functions: Optional[List[Dict]] = Field(default_factory=list)

     # Typically collected from Chat Template.
     # Don't include this in the OpenAPI docs
     # TODO: Use these custom parameters
-    tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = None
+    tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = Field(
+        default_factory=list
+    )
     tool_call_end: SkipJsonSchema[Optional[str]] = None
     tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
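
Note: with default_factory=list, omitting one of these fields yields an
empty list rather than None, so consumers can iterate without a None
check. A self-contained illustration (hypothetical model, assuming
Pydantic v2):

    from typing import List, Optional
    from pydantic import BaseModel, Field

    class Example(BaseModel):
        tools: Optional[List[dict]] = Field(default_factory=list)

    req = Example()           # field omitted by the client
    assert req.tools == []    # default_factory supplies [], not None
    for tool in req.tools:    # safe to iterate directly
        ...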


@@ -210,14 +210,14 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars:
 async def format_messages_with_template(
     messages: List[ChatCompletionMessage],
     existing_template_vars: Optional[dict] = None,
-    add_bos_token: bool = True,
-    ban_eos_token: bool = False,
 ):
     """Barebones function to format chat completion messages into a prompt."""

     template_vars = unwrap(existing_template_vars, {})
     mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None

+    # Convert all messages to a dictionary representation
+    message_dicts: List[dict] = []
     for message in messages:
         if isinstance(message.content, list):
             concatenated_content = ""
@@ -238,9 +238,12 @@ async def format_messages_with_template(
             # store the list of dicts rather than the ToolCallProcessor object.
             message.tool_calls = ToolCallProcessor.dump(message.tool_calls)

+        message_dicts.append(message.model_dump())
+
     # Get all special tokens
     special_tokens_dict = model.container.get_special_tokens()
-    template_vars.update({"messages": messages, **special_tokens_dict})
+    template_vars.update({"messages": message_dicts, **special_tokens_dict})

     prompt = await model.container.prompt_template.render(template_vars)

     return prompt, mm_embeddings, template_vars
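
Note: this conversion is the fix described in the commit message.
Mapping-style access in a template fails on model objects but works on
dicts; a sketch of the failure mode (illustrative template, assuming
Jinja2 and Pydantic v2):

    from jinja2 import Template
    from pydantic import BaseModel

    class Message(BaseModel):
        role: str = "user"
        content: str = ""

    template = Template(
        "{% for m in messages %}{{ m.get('content', '') }}{% endfor %}"
    )
    msg = Message(content="hi")

    # template.render(messages=[msg])  # fails: BaseModel has no .get()
    print(template.render(messages=[msg.model_dump()]))  # dicts work: "hi"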
@@ -270,7 +273,7 @@ async def apply_chat_template(
     )

     prompt, mm_embeddings, template_vars = await format_messages_with_template(
-        data.messages, data.template_vars, data.add_bos_token, data.ban_eos_token
+        data.messages, data.template_vars
     )

     # Append response prefix if present