OAI: Strictly type chat completions

Previously, the messages were a list of dicts. These are untyped
and don't provide strict type hints. Add types for chat completion
messages and reformat the existing code.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-11-19 23:15:47 -05:00
parent 0fadb1e5e8
commit 8ffc636dce
3 changed files with 37 additions and 25 deletions
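To make the change concrete, here is a rough, self-contained sketch (assuming Pydantic v2) of the message models this commit introduces. The class bodies mirror the types diff below; the validation example at the end is illustrative only and not part of the commit.

# Illustrative sketch: class bodies mirror the diff below; tool_calls is
# omitted because ToolCall is defined in another module not shown here.
from typing import List, Literal, Optional, Union

from pydantic import BaseModel


class ChatCompletionImageUrl(BaseModel):
    url: str


class ChatCompletionMessagePart(BaseModel):
    type: Literal["text", "image_url"] = "text"
    text: Optional[str] = None
    image_url: Optional[ChatCompletionImageUrl] = None


class ChatCompletionMessage(BaseModel):
    role: str = "user"
    content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None


# A raw OpenAI-style dict now validates into a typed object up front,
# so downstream code can use attribute access instead of key lookups.
message = ChatCompletionMessage.model_validate(
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }
)
assert isinstance(message.content[1], ChatCompletionMessagePart)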


@@ -132,6 +132,7 @@ async def chat_completion_request(
    else:
        if model.container.use_vision:
            data.messages, embeddings = await preprocess_vision_request(data.messages)
        prompt = await format_prompt_with_template(data)
    # Set an empty JSON schema if the request wants a JSON response


@@ -1,7 +1,7 @@
from pydantic import BaseModel, Field
from pydantic.json_schema import SkipJsonSchema
from time import time
-from typing import Union, List, Optional, Dict
+from typing import Literal, Union, List, Optional, Dict
from uuid import uuid4
from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
@@ -18,10 +18,21 @@ class ChatCompletionLogprobs(BaseModel):
    content: List[ChatCompletionLogprob] = Field(default_factory=list)
+class ChatCompletionImageUrl(BaseModel):
+    url: str
+class ChatCompletionMessagePart(BaseModel):
+    type: Literal["text", "image_url"] = "text"
+    text: Optional[str] = None
+    image_url: Optional[ChatCompletionImageUrl] = None
class ChatCompletionMessage(BaseModel):
-    role: Optional[str] = None
-    content: Optional[str] = None
+    role: str = "user"
+    content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_calls_json: SkipJsonSchema[Optional[str]] = None
class ChatCompletionRespChoice(BaseModel):
@@ -51,7 +62,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
    # WIP this can probably be tightened, or maybe match the OAI lib type
    # in openai\types\chat\chat_completion_message_param.py
-    messages: Union[str, List[Dict]]
+    messages: List[ChatCompletionMessage] = Field(default_factory=list)
    prompt_template: Optional[str] = None
    add_generation_prompt: Optional[bool] = True
    template_vars: Optional[dict] = {}
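A hedged sketch of how the retyped messages field behaves at the request boundary. SimpleChatRequest is a hypothetical stand-in, since CommonCompletionRequest is not part of this diff, and it reuses the ChatCompletionMessage sketch shown after the commit header above.

# SimpleChatRequest stands in for ChatCompletionRequest; the real class
# inherits CommonCompletionRequest, which this diff does not show.
# ChatCompletionMessage refers to the sketch in the earlier example.
from typing import List

from pydantic import BaseModel, Field, ValidationError


class SimpleChatRequest(BaseModel):
    messages: List[ChatCompletionMessage] = Field(default_factory=list)


# Dicts from the JSON body are coerced into typed message objects...
req = SimpleChatRequest.model_validate(
    {"messages": [{"role": "system", "content": "You are helpful."}]}
)
print(type(req.messages[0]).__name__)  # ChatCompletionMessage

# ...while the old bare-string `messages` shape is now rejected.
try:
    SimpleChatRequest.model_validate({"messages": "just a prompt string"})
except ValidationError as err:
    print(err.error_count(), "validation error(s)")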


@@ -1,17 +1,16 @@
"""Chat completion utilities for OAI server."""
import asyncio
+import json
import pathlib
from asyncio import CancelledError
-from typing import Dict, List, Optional
-import json
-from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
+from typing import List, Optional
from fastapi import HTTPException, Request
from jinja2 import TemplateError
from loguru import logger
from common import model
+from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
from common.networking import (
    get_generator_error,
    handle_request_disconnect,
@@ -214,21 +213,21 @@ async def format_prompt_with_template(
        unwrap(data.ban_eos_token, False),
    )
-    # Deal with list in messages.content
-    # Just replace the content list with the very first text message
+    # Convert list to text-based content
+    # Use the first instance of text inside the part list
    for message in data.messages:
-        if isinstance(message["content"], list):
-            message["content"] = next(
+        if isinstance(message.content, list):
+            message.content = next(
                (
-                    content["text"]
-                    for content in message["content"]
-                    if content["type"] == "text"
+                    content.text
+                    for content in message.content
+                    if content.type == "text"
                ),
                "",
            )
-        if "tool_calls" in message:
-            message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2)
+        if message.tool_calls:
+            message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
    # Overwrite any protected vars with their values
    data.template_vars.update(
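To illustrate the rewritten flattening step, here is a minimal sketch (reusing the message models sketched after the commit header) of how a part-list collapses to its first text part before templating.

# Minimal illustration of the flattening above; the models are the sketch
# from the earlier example, and the next(...) call mirrors the diff.
msg = ChatCompletionMessage(
    content=[
        ChatCompletionMessagePart(
            type="image_url",
            image_url=ChatCompletionImageUrl(url="https://example.com/cat.png"),
        ),
        ChatCompletionMessagePart(type="text", text="What is in this picture?"),
        ChatCompletionMessagePart(type="text", text="Answer briefly."),
    ]
)

if isinstance(msg.content, list):
    # Only the first text part survives; image parts and any later
    # text parts are dropped, with "" as the fallback.
    msg.content = next(
        (part.text for part in msg.content if part.type == "text"),
        "",
    )

print(msg.content)  # "What is in this picture?"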
@@ -474,20 +473,21 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
    return [ToolCall(**tool_call) for tool_call in tool_calls]
-async def preprocess_vision_request(messages: List[Dict]):
+# TODO: Combine this with the existing preprocessor in format_prompt_with_template
+async def preprocess_vision_request(messages: List[ChatCompletionMessage]):
    embeddings = MultimodalEmbeddingWrapper()
    for message in messages:
-        if isinstance(message["content"], list):
+        if isinstance(message.content, list):
            concatenated_content = ""
-            for content in message["content"]:
-                if content["type"] == "text":
-                    concatenated_content += content["text"]
-                elif content["type"] == "image_url":
+            for content in message.content:
+                if content.type == "text":
+                    concatenated_content += content.text
+                elif content.type == "image_url":
                    embeddings = await add_image_embedding(
-                        embeddings, content["image_url"]["url"]
+                        embeddings, content.image_url.url
                    )
                    concatenated_content += embeddings.text_alias[-1]
-            message["content"] = concatenated_content
+            message.content = concatenated_content
    return messages, embeddings
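For completeness, a toy sketch of what the retyped preprocessor does to a vision message. FakeEmbeddingWrapper and fake_add_image_embedding are hypothetical stand-ins for MultimodalEmbeddingWrapper and add_image_embedding in common.multimodal, which actually fetch and embed the image; the message models are the earlier sketch.

# Toy stand-ins only: the real helpers in common.multimodal do the actual
# image fetching and embedding; here we only show the text-alias behavior.
import asyncio
from typing import List


class FakeEmbeddingWrapper:
    def __init__(self):
        self.text_alias: List[str] = []


async def fake_add_image_embedding(wrapper: FakeEmbeddingWrapper, url: str):
    wrapper.text_alias.append(f"<image_{len(wrapper.text_alias) + 1}>")
    return wrapper


async def demo():
    message = ChatCompletionMessage(
        content=[
            ChatCompletionMessagePart(type="text", text="Compare these: "),
            ChatCompletionMessagePart(
                type="image_url",
                image_url=ChatCompletionImageUrl(url="https://example.com/a.png"),
            ),
        ]
    )

    embeddings = FakeEmbeddingWrapper()
    if isinstance(message.content, list):
        concatenated = ""
        for part in message.content:
            if part.type == "text":
                concatenated += part.text
            elif part.type == "image_url":
                embeddings = await fake_add_image_embedding(embeddings, part.image_url.url)
                # The image part is replaced inline by its text alias, so the
                # prompt template only ever sees a plain string.
                concatenated += embeddings.text_alias[-1]
        message.content = concatenated

    print(message.content)  # "Compare these: <image_1>"


asyncio.run(demo())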