API: Fix types for chat completions
Messages were mistakenly being sent as Pydantic objects, but templates expect dictionaries. Properly convert these before rendering. In addition, initialize all Optional lists as empty lists, since this causes the fewest problems when interacting with other parts of the API code, such as templates. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
81170eee00
commit
54b8a20a19
2 changed files with 17 additions and 10 deletions
|
|
@ -11,7 +11,7 @@ from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
|
|||
class ChatCompletionLogprob(BaseModel):
|
||||
token: str
|
||||
logprob: float
|
||||
top_logprobs: Optional[List["ChatCompletionLogprob"]] = None
|
||||
top_logprobs: Optional[List["ChatCompletionLogprob"]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ChatCompletionLogprobs(BaseModel):
|
||||
|
|
@ -30,8 +30,10 @@ class ChatCompletionMessagePart(BaseModel):
|
|||
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
role: str = "user"
|
||||
content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
|
||||
tool_calls: Optional[List[ToolCall]] = None
|
||||
content: Optional[Union[str, List[ChatCompletionMessagePart]]] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default_factory=list)
|
||||
tool_calls_json: SkipJsonSchema[Optional[str]] = None
|
||||
|
||||
|
||||
|
|
@ -76,13 +78,15 @@ class ChatCompletionRequest(CommonCompletionRequest):
|
|||
# tools follows the OAI schema format; functions is more flexible
|
||||
# both are available in the chat template.
|
||||
|
||||
tools: Optional[List[ToolSpec]] = None
|
||||
functions: Optional[List[Dict]] = None
|
||||
tools: Optional[List[ToolSpec]] = Field(default_factory=list)
|
||||
functions: Optional[List[Dict]] = Field(default_factory=list)
|
||||
|
||||
# Typically collected from Chat Template.
|
||||
# Don't include this in the OpenAPI docs
|
||||
# TODO: Use these custom parameters
|
||||
tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = None
|
||||
tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = Field(
|
||||
default_factory=list
|
||||
)
|
||||
tool_call_end: SkipJsonSchema[Optional[str]] = None
|
||||
tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
|
||||
|
||||
|
|
|
|||
|
|
@ -210,14 +210,14 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars:
|
|||
async def format_messages_with_template(
|
||||
messages: List[ChatCompletionMessage],
|
||||
existing_template_vars: Optional[dict] = None,
|
||||
add_bos_token: bool = True,
|
||||
ban_eos_token: bool = False,
|
||||
):
|
||||
"""Barebones function to format chat completion messages into a prompt."""
|
||||
|
||||
template_vars = unwrap(existing_template_vars, {})
|
||||
mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None
|
||||
|
||||
# Convert all messages to a dictionary representation
|
||||
message_dicts: List[dict] = []
|
||||
for message in messages:
|
||||
if isinstance(message.content, list):
|
||||
concatenated_content = ""
|
||||
|
|
@ -238,9 +238,12 @@ async def format_messages_with_template(
|
|||
# store the list of dicts rather than the ToolCallProcessor object.
|
||||
message.tool_calls = ToolCallProcessor.dump(message.tool_calls)
|
||||
|
||||
message_dicts.append(message.model_dump())
|
||||
|
||||
# Get all special tokens
|
||||
special_tokens_dict = model.container.get_special_tokens()
|
||||
|
||||
template_vars.update({"messages": messages, **special_tokens_dict})
|
||||
template_vars.update({"messages": message_dicts, **special_tokens_dict})
|
||||
|
||||
prompt = await model.container.prompt_template.render(template_vars)
|
||||
return prompt, mm_embeddings, template_vars
|
||||
|
|
@ -270,7 +273,7 @@ async def apply_chat_template(
|
|||
)
|
||||
|
||||
prompt, mm_embeddings, template_vars = await format_messages_with_template(
|
||||
data.messages, data.template_vars, data.add_bos_token, data.ban_eos_token
|
||||
data.messages, data.template_vars
|
||||
)
|
||||
|
||||
# Append response prefix if present
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue