diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 15988cc..0bd91e0 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -11,7 +11,7 @@ from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
 class ChatCompletionLogprob(BaseModel):
     token: str
     logprob: float
-    top_logprobs: Optional[List["ChatCompletionLogprob"]] = None
+    top_logprobs: Optional[List["ChatCompletionLogprob"]] = Field(default_factory=list)
 
 
 class ChatCompletionLogprobs(BaseModel):
@@ -30,8 +30,10 @@ class ChatCompletionMessagePart(BaseModel):
 
 class ChatCompletionMessage(BaseModel):
     role: str = "user"
-    content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
-    tool_calls: Optional[List[ToolCall]] = None
+    content: Optional[Union[str, List[ChatCompletionMessagePart]]] = Field(
+        default_factory=list
+    )
+    tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
     tool_calls_json: SkipJsonSchema[Optional[str]] = None
 
 
@@ -76,13 +78,15 @@ class ChatCompletionRequest(CommonCompletionRequest):
 
     # tools is follows the format OAI schema, functions is more flexible
     # both are available in the chat template.
-    tools: Optional[List[ToolSpec]] = None
-    functions: Optional[List[Dict]] = None
+    tools: Optional[List[ToolSpec]] = Field(default_factory=list)
+    functions: Optional[List[Dict]] = Field(default_factory=list)
 
     # Typically collected from Chat Template.
     # Don't include this in the OpenAPI docs
     # TODO: Use these custom parameters
-    tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = None
+    tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = Field(
+        default_factory=list
+    )
     tool_call_end: SkipJsonSchema[Optional[str]] = None
     tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
 
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 11d4088..5975173 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -210,14 +210,14 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars:
 async def format_messages_with_template(
     messages: List[ChatCompletionMessage],
     existing_template_vars: Optional[dict] = None,
-    add_bos_token: bool = True,
-    ban_eos_token: bool = False,
 ):
     """Barebones function to format chat completion messages into a prompt."""
 
     template_vars = unwrap(existing_template_vars, {})
     mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None
 
+    # Convert all messages to a dictionary representation
+    message_dicts: List[dict] = []
     for message in messages:
         if isinstance(message.content, list):
             concatenated_content = ""
@@ -238,9 +238,12 @@ async def format_messages_with_template(
             # store the list of dicts rather than the ToolCallProcessor object.
             message.tool_calls = ToolCallProcessor.dump(message.tool_calls)
 
+        message_dicts.append(message.model_dump())
+
+    # Get all special tokens
     special_tokens_dict = model.container.get_special_tokens()
 
-    template_vars.update({"messages": messages, **special_tokens_dict})
+    template_vars.update({"messages": message_dicts, **special_tokens_dict})
 
     prompt = await model.container.prompt_template.render(template_vars)
     return prompt, mm_embeddings, template_vars
@@ -270,7 +273,7 @@ async def apply_chat_template(
     )
 
     prompt, mm_embeddings, template_vars = await format_messages_with_template(
-        data.messages, data.template_vars, data.add_bos_token, data.ban_eos_token
+        data.messages, data.template_vars
     )
 
     # Append response prefix if present
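
The type changes in the first file all swap `None` defaults for `Field(default_factory=list)`, the idiomatic Pydantic way to give every instance its own fresh empty list; consumers such as chat templates can then iterate the field without guarding against `None`. A minimal sketch of the behavior, using a hypothetical `Message` model rather than the repo's `ChatCompletionMessage`:

# Minimal sketch (hypothetical Message model, not the repo's class) of the
# Field(default_factory=list) pattern applied throughout the patch.
from typing import List, Optional

from pydantic import BaseModel, Field


class Message(BaseModel):
    role: str = "user"
    # Each instance gets its own fresh list instead of None, so consumers
    # (e.g. Jinja chat templates) can iterate without null checks.
    tool_calls: Optional[List[dict]] = Field(default_factory=list)


first, second = Message(), Message()
first.tool_calls.append({"id": "call_1"})

assert first.tool_calls == [{"id": "call_1"}]
assert second.tool_calls == []  # untouched instance keeps an empty list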
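
The second file's change feeds the template plain dicts instead of `ChatCompletionMessage` objects: `message.model_dump()` recursively converts each model to builtin types, which dict-style Jinja chat templates (and JSON filters such as `tojson`) handle directly. A rough stand-in for the rendering step, assuming a plain synchronous Jinja2 environment rather than the repo's async `prompt_template.render`:

# Illustrative stand-in for the repo's chat-template rendering; the real
# prompt_template.render is async and repo-specific.
from jinja2 import Environment

env = Environment()
template = env.from_string(
    "{% for message in messages %}"
    "{{ message['role'] }}: {{ message['content'] }}\n"
    "{% endfor %}"
)

# Shaped like the output of message.model_dump() in the patch: plain dicts
# with builtin values, safe for dict lookups and JSON serialization.
message_dicts = [
    {"role": "user", "content": "Hello"},
    {"role": "assistant", "content": "Hi there."},
]

print(template.render(messages=message_dicts), end="")
# user: Hello
# assistant: Hi there.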