From 54b8a20a19f3654a21c6877f3e45fe2cf2773527 Mon Sep 17 00:00:00 2001 From: kingbri <8082010+kingbri1@users.noreply.github.com> Date: Sat, 17 May 2025 18:04:39 -0400 Subject: [PATCH] API: Fix types for chat completions Messages were mistakenly being sent as Pydantic objects, but templates expect dictionaries. Properly convert these before rendering. In addition, initialize all Optional lists as an empty list since this will cause the least problems when interacting with other parts of API code, such as templates. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com> --- endpoints/OAI/types/chat_completion.py | 16 ++++++++++------ endpoints/OAI/utils/chat_completion.py | 11 +++++++---- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index 15988cc..0bd91e0 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -11,7 +11,7 @@ from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema class ChatCompletionLogprob(BaseModel): token: str logprob: float - top_logprobs: Optional[List["ChatCompletionLogprob"]] = None + top_logprobs: Optional[List["ChatCompletionLogprob"]] = Field(default_factory=list) class ChatCompletionLogprobs(BaseModel): @@ -30,8 +30,10 @@ class ChatCompletionMessagePart(BaseModel): class ChatCompletionMessage(BaseModel): role: str = "user" - content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None - tool_calls: Optional[List[ToolCall]] = None + content: Optional[Union[str, List[ChatCompletionMessagePart]]] = Field( + default_factory=list + ) + tool_calls: Optional[List[ToolCall]] = Field(default_factory=list) tool_calls_json: SkipJsonSchema[Optional[str]] = None @@ -76,13 +78,15 @@ class ChatCompletionRequest(CommonCompletionRequest): # tools is follows the format OAI schema, functions is more flexible # both are available in the chat template. 
- tools: Optional[List[ToolSpec]] = None - functions: Optional[List[Dict]] = None + tools: Optional[List[ToolSpec]] = Field(default_factory=list) + functions: Optional[List[Dict]] = Field(default_factory=list) # Typically collected from Chat Template. # Don't include this in the OpenAPI docs # TODO: Use these custom parameters - tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = None + tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = Field( + default_factory=list + ) tool_call_end: SkipJsonSchema[Optional[str]] = None tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 11d4088..5975173 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -210,14 +210,14 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars: async def format_messages_with_template( messages: List[ChatCompletionMessage], existing_template_vars: Optional[dict] = None, - add_bos_token: bool = True, - ban_eos_token: bool = False, ): """Barebones function to format chat completion messages into a prompt.""" template_vars = unwrap(existing_template_vars, {}) mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None + # Convert all messages to a dictionary representation + message_dicts: List[dict] = [] for message in messages: if isinstance(message.content, list): concatenated_content = "" @@ -238,9 +238,12 @@ async def format_messages_with_template( # store the list of dicts rather than the ToolCallProcessor object. 
message.tool_calls = ToolCallProcessor.dump(message.tool_calls) + message_dicts.append(message.model_dump()) + + # Get all special tokens special_tokens_dict = model.container.get_special_tokens() - template_vars.update({"messages": messages, **special_tokens_dict}) + template_vars.update({"messages": message_dicts, **special_tokens_dict}) prompt = await model.container.prompt_template.render(template_vars) return prompt, mm_embeddings, template_vars @@ -270,7 +273,7 @@ async def apply_chat_template( ) prompt, mm_embeddings, template_vars = await format_messages_with_template( - data.messages, data.template_vars, data.add_bos_token, data.ban_eos_token + data.messages, data.template_vars ) # Append response prefix if present