diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index acb35f9..8403a87 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -132,6 +132,7 @@ async def chat_completion_request(
     else:
         if model.container.use_vision:
             data.messages, embeddings = await preprocess_vision_request(data.messages)
+
         prompt = await format_prompt_with_template(data)
 
         # Set an empty JSON schema if the request wants a JSON response
diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 30ec769..86a2247 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -1,7 +1,7 @@
 from pydantic import BaseModel, Field
 from pydantic.json_schema import SkipJsonSchema
 from time import time
-from typing import Union, List, Optional, Dict
+from typing import Literal, Union, List, Optional, Dict
 from uuid import uuid4
 
 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
@@ -18,10 +18,21 @@ class ChatCompletionLogprobs(BaseModel):
     content: List[ChatCompletionLogprob] = Field(default_factory=list)
 
 
+class ChatCompletionImageUrl(BaseModel):
+    url: str
+
+
+class ChatCompletionMessagePart(BaseModel):
+    type: Literal["text", "image_url"] = "text"
+    text: Optional[str] = None
+    image_url: Optional[ChatCompletionImageUrl] = None
+
+
 class ChatCompletionMessage(BaseModel):
-    role: Optional[str] = None
-    content: Optional[str] = None
+    role: str = "user"
+    content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
     tool_calls: Optional[List[ToolCall]] = None
+    tool_calls_json: SkipJsonSchema[Optional[str]] = None
 
 
 class ChatCompletionRespChoice(BaseModel):
@@ -51,7 +62,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     # WIP this can probably be tightened, or maybe match the OAI lib type
     # in openai\types\chat\chat_completion_message_param.py
-    messages: Union[str, List[Dict]]
+    messages: List[ChatCompletionMessage] = Field(default_factory=list)
     prompt_template: Optional[str] = None
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index a59f425..c14a8dc 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -1,17 +1,16 @@
 """Chat completion utilities for OAI server."""
 import asyncio
+import json
 import pathlib
 from asyncio import CancelledError
-from typing import Dict, List, Optional
-import json
-
-from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
+from typing import List, Optional
 
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger
 
 from common import model
+from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
 from common.networking import (
     get_generator_error,
     handle_request_disconnect,
@@ -214,21 +213,21 @@ async def format_prompt_with_template(
         unwrap(data.ban_eos_token, False),
     )
 
-    # Deal with list in messages.content
-    # Just replace the content list with the very first text message
+    # Convert list to text-based content
+    # Use the first instance of text inside the part list
     for message in data.messages:
-        if isinstance(message["content"], list):
-            message["content"] = next(
+        if isinstance(message.content, list):
+            message.content = next(
                 (
-                    content["text"]
-                    for content in message["content"]
-                    if content["type"] == "text"
+                    content.text
+                    for content in message.content
+                    if content.type == "text"
                 ),
                 "",
             )
 
-        if 
"tool_calls" in message: - message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2) + if message.tool_calls: + message.tool_calls_json = json.dumps(message.tool_calls, indent=2) # Overwrite any protected vars with their values data.template_vars.update( @@ -474,20 +473,21 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]: return [ToolCall(**tool_call) for tool_call in tool_calls] -async def preprocess_vision_request(messages: List[Dict]): +# TODO: Combine this with the existing preprocessor in format_prompt_with_template +async def preprocess_vision_request(messages: List[ChatCompletionMessage]): embeddings = MultimodalEmbeddingWrapper() for message in messages: - if isinstance(message["content"], list): + if isinstance(message.content, list): concatenated_content = "" - for content in message["content"]: - if content["type"] == "text": - concatenated_content += content["text"] - elif content["type"] == "image_url": + for content in message.content: + if content.type == "text": + concatenated_content += content.text + elif content.type == "image_url": embeddings = await add_image_embedding( - embeddings, content["image_url"]["url"] + embeddings, content.image_url.url ) concatenated_content += embeddings.text_alias[-1] - message["content"] = concatenated_content + message.content = concatenated_content return messages, embeddings