OAI: Strictly type chat completions
Previously, messages were passed around as a list of dicts, which are untyped and provide no strict hinting. Add Pydantic types for chat completion messages and reformat the existing code to use them.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 0fadb1e5e8
commit 8ffc636dce

3 changed files with 37 additions and 25 deletions
@@ -132,6 +132,7 @@ async def chat_completion_request(
     else:
         if model.container.use_vision:
             data.messages, embeddings = await preprocess_vision_request(data.messages)
 
         prompt = await format_prompt_with_template(data)
 
     # Set an empty JSON schema if the request wants a JSON response
@@ -1,7 +1,7 @@
 from pydantic import BaseModel, Field
 from pydantic.json_schema import SkipJsonSchema
 from time import time
-from typing import Union, List, Optional, Dict
+from typing import Literal, Union, List, Optional, Dict
 from uuid import uuid4
 
 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
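The new `Literal` import backs the `type` field on the message part model added below. As a rough illustration (a standalone sketch with a throwaway `Part` class, not the project's code, assuming Pydantic v2 as implied by the `SkipJsonSchema` import), a `Literal` field restricts which strings a part's `type` may take:

```python
# Standalone sketch; `Part` is a throwaway model, not the project's class.
from typing import Literal, Optional

from pydantic import BaseModel, ValidationError


class Part(BaseModel):
    # Only "text" or "image_url" pass validation; anything else is rejected.
    type: Literal["text", "image_url"] = "text"
    text: Optional[str] = None


Part(type="text", text="hello")  # accepted

try:
    Part(type="audio")  # fails the Literal constraint
except ValidationError as err:
    print(err.errors()[0]["msg"])
```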
@@ -18,10 +18,21 @@ class ChatCompletionLogprobs(BaseModel):
     content: List[ChatCompletionLogprob] = Field(default_factory=list)
 
 
+class ChatCompletionImageUrl(BaseModel):
+    url: str
+
+
+class ChatCompletionMessagePart(BaseModel):
+    type: Literal["text", "image_url"] = "text"
+    text: Optional[str] = None
+    image_url: Optional[ChatCompletionImageUrl] = None
+
+
 class ChatCompletionMessage(BaseModel):
-    role: Optional[str] = None
-    content: Optional[str] = None
+    role: str = "user"
+    content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
     tool_calls: Optional[List[ToolCall]] = None
+    tool_calls_json: SkipJsonSchema[Optional[str]] = None
 
 
 class ChatCompletionRespChoice(BaseModel):
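For illustration, here is a self-contained sketch (restating the models above minus `ToolCall`, whose definition is outside this diff) of how an OpenAI-style message dict now coerces into typed objects rather than staying a raw dict:

```python
# Standalone sketch of the new typed message shape; ToolCall is omitted
# because its schema is not part of this diff.
from typing import List, Literal, Optional, Union

from pydantic import BaseModel


class ChatCompletionImageUrl(BaseModel):
    url: str


class ChatCompletionMessagePart(BaseModel):
    type: Literal["text", "image_url"] = "text"
    text: Optional[str] = None
    image_url: Optional[ChatCompletionImageUrl] = None


class ChatCompletionMessage(BaseModel):
    role: str = "user"
    content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None


# Plain string content stays a string; a part list becomes typed parts.
m1 = ChatCompletionMessage(role="user", content="Hi there")
assert m1.content == "Hi there"

m2 = ChatCompletionMessage(
    role="user",
    content=[
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
    ],
)
assert isinstance(m2.content[1], ChatCompletionMessagePart)
assert m2.content[1].image_url.url == "https://example.com/cat.png"
```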
@@ -51,7 +62,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
 
     # WIP this can probably be tightened, or maybe match the OAI lib type
     # in openai\types\chat\chat_completion_message_param.py
-    messages: Union[str, List[Dict]]
+    messages: List[ChatCompletionMessage] = Field(default_factory=list)
     prompt_template: Optional[str] = None
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
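To show what the stricter `messages` field buys at the request boundary, here is a cut-down stand-in (the real `ChatCompletionRequest` inherits from `CommonCompletionRequest`, which this diff does not touch; `DemoMessage` and `DemoChatRequest` are hypothetical names): malformed bodies now fail validation up front instead of surfacing later as `KeyError`s on raw dicts. Assumes Pydantic v2.

```python
# Hypothetical stand-ins mirroring the shape of the typed request field.
from typing import List, Optional

from pydantic import BaseModel, Field, ValidationError


class DemoMessage(BaseModel):
    role: str = "user"
    content: Optional[str] = None


class DemoChatRequest(BaseModel):
    messages: List[DemoMessage] = Field(default_factory=list)


ok = DemoChatRequest.model_validate(
    {"messages": [{"role": "system", "content": "Be terse."}]}
)
assert ok.messages[0].role == "system"

try:
    # Per the diff, `messages` is now a typed list, so a bare string no
    # longer validates as it did with Union[str, List[Dict]].
    DemoChatRequest.model_validate({"messages": "hello"})
except ValidationError as err:
    print(f"rejected: {err.error_count()} error(s)")
```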
@@ -1,17 +1,16 @@
 """Chat completion utilities for OAI server."""
 
 import asyncio
+import json
 import pathlib
 from asyncio import CancelledError
-from typing import Dict, List, Optional
-import json
-
-from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
+from typing import List, Optional
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger
 
 from common import model
+from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
 from common.networking import (
     get_generator_error,
     handle_request_disconnect,
@@ -214,21 +213,21 @@ async def format_prompt_with_template(
         unwrap(data.ban_eos_token, False),
     )
 
-    # Deal with list in messages.content
-    # Just replace the content list with the very first text message
+    # Convert list to text-based content
+    # Use the first instance of text inside the part list
     for message in data.messages:
-        if isinstance(message["content"], list):
-            message["content"] = next(
+        if isinstance(message.content, list):
+            message.content = next(
                 (
-                    content["text"]
-                    for content in message["content"]
-                    if content["type"] == "text"
+                    content.text
+                    for content in message.content
+                    if content.type == "text"
                 ),
                 "",
             )
 
-        if "tool_calls" in message:
-            message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2)
+        if message.tool_calls:
+            message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
 
     # Overwrite any protected vars with their values
     data.template_vars.update(
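The loop above flattens a multi-part `content` list down to its first text part before templating, and serializes any tool calls into `tool_calls_json` for the template. A minimal sketch of the `next(...)` collapse, using throwaway types rather than the project's models:

```python
# Throwaway types to illustrate the first-text-part collapse.
from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class Part:
    type: str = "text"
    text: Optional[str] = None


content: Union[str, List[Part]] = [
    Part(type="image_url"),
    Part(type="text", text="first text part"),
    Part(type="text", text="second text part"),
]

if isinstance(content, list):
    # next() returns the first text part, or "" if the list has none.
    content = next((p.text for p in content if p.type == "text"), "")

assert content == "first text part"
```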
@@ -474,20 +473,21 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
     return [ToolCall(**tool_call) for tool_call in tool_calls]
 
 
-async def preprocess_vision_request(messages: List[Dict]):
+# TODO: Combine this with the existing preprocessor in format_prompt_with_template
+async def preprocess_vision_request(messages: List[ChatCompletionMessage]):
     embeddings = MultimodalEmbeddingWrapper()
     for message in messages:
-        if isinstance(message["content"], list):
+        if isinstance(message.content, list):
             concatenated_content = ""
-            for content in message["content"]:
-                if content["type"] == "text":
-                    concatenated_content += content["text"]
-                elif content["type"] == "image_url":
+            for content in message.content:
+                if content.type == "text":
+                    concatenated_content += content.text
+                elif content.type == "image_url":
                     embeddings = await add_image_embedding(
-                        embeddings, content["image_url"]["url"]
+                        embeddings, content.image_url.url
                     )
                     concatenated_content += embeddings.text_alias[-1]
 
-            message["content"] = concatenated_content
+            message.content = concatenated_content
 
     return messages, embeddings
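`preprocess_vision_request` walks the typed parts, concatenates text parts, and replaces each image part with a text alias produced by the embedding wrapper. Below is a toy sketch of that flattening shape only; `FakeEmbeddings` and `fake_add_image_embedding` are invented stand-ins, since the real `MultimodalEmbeddingWrapper` and `add_image_embedding` live outside this diff.

```python
# Toy stand-ins to show how image parts become inline text aliases.
import asyncio
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Part:
    type: str = "text"
    text: Optional[str] = None
    url: Optional[str] = None


@dataclass
class FakeEmbeddings:
    text_alias: List[str] = field(default_factory=list)


async def fake_add_image_embedding(emb: FakeEmbeddings, url: str) -> FakeEmbeddings:
    # The real helper fetches and embeds the image; here we only mint an alias.
    emb.text_alias.append(f"<image_{len(emb.text_alias) + 1}>")
    return emb


async def demo() -> None:
    parts = [
        Part(type="text", text="Describe "),
        Part(type="image_url", url="https://example.com/cat.png"),
        Part(type="text", text=" briefly."),
    ]
    emb = FakeEmbeddings()
    flattened = ""
    for part in parts:
        if part.type == "text":
            flattened += part.text
        elif part.type == "image_url":
            emb = await fake_add_image_embedding(emb, part.url)
            flattened += emb.text_alias[-1]
    # The message content ends up as plain text with image aliases inlined.
    print(flattened)  # Describe <image_1> briefly.


asyncio.run(demo())
```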