diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index 0bd91e0..561d9bc 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -30,10 +30,8 @@ class ChatCompletionMessagePart(BaseModel): class ChatCompletionMessage(BaseModel): role: str = "user" - content: Optional[Union[str, List[ChatCompletionMessagePart]]] = Field( - default_factory=list - ) - tool_calls: Optional[List[ToolCall]] = Field(default_factory=list) + content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None + tool_calls: Optional[List[ToolCall]] = None tool_calls_json: SkipJsonSchema[Optional[str]] = None @@ -58,12 +56,6 @@ class ChatCompletionStreamChoice(BaseModel): # Inherited from common request class ChatCompletionRequest(CommonCompletionRequest): - # Messages - # Take in a string as well even though it's not part of the OAI spec - # support messages.content as a list of dict - - # WIP this can probably be tightened, or maybe match the OAI lib type - # in openai\types\chat\chat_completion_message_param.py messages: List[ChatCompletionMessage] = Field(default_factory=list) prompt_template: Optional[str] = None add_generation_prompt: Optional[bool] = True @@ -78,8 +70,8 @@ class ChatCompletionRequest(CommonCompletionRequest): # tools is follows the format OAI schema, functions is more flexible # both are available in the chat template. - tools: Optional[List[ToolSpec]] = Field(default_factory=list) - functions: Optional[List[Dict]] = Field(default_factory=list) + tools: Optional[List[ToolSpec]] = None + functions: Optional[List[Dict]] = None # Typically collected from Chat Template. # Don't include this in the OpenAPI docs diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 5975173..c1e6681 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -238,7 +238,7 @@ async def format_messages_with_template( # store the list of dicts rather than the ToolCallProcessor object. message.tool_calls = ToolCallProcessor.dump(message.tool_calls) - message_dicts.append(message.model_dump()) + message_dicts.append(message.model_dump(exclude_none=True)) # Get all special tokens special_tokens_dict = model.container.get_special_tokens()