OAI: Initial vision support in OAI chat completions

* Support image_url inputs containing URLs or base64 strings, following the OAI vision spec
* Use an async LRU cache for image embeddings
* Add a generic wrapper class for multimodal embeddings
Author: DocShotgun
Date:   2024-11-17 21:23:09 -08:00
Parent: 5fa298e601
Commit: dd41eec8a4
7 changed files with 115 additions and 26 deletions
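
The new common/multimodal.py module is not included in this excerpt. As a rough sketch of the shape the diff below implies (only MultimodalEmbeddingWrapper, add_image_embedding, and the text_alias attribute are taken from the diff; the fields, alias format, cache decorator, and embedding step are assumptions), the wrapper and cached helper could look something like this:

# Illustrative sketch only -- not the actual common/multimodal.py.
from dataclasses import dataclass, field
from typing import Any, List

from async_lru import alru_cache  # one possible async LRU cache


@dataclass
class MultimodalEmbeddingWrapper:
    """Generic container for image embeddings plus the text aliases that
    stand in for them inside the flattened prompt."""

    content: List[Any] = field(default_factory=list)
    text_alias: List[str] = field(default_factory=list)


@alru_cache(maxsize=32)
async def _embed_image(url: str) -> Any:
    """Hypothetical helper: fetch/decode a URL or base64 data URI and run the
    vision encoder. Cached so a repeated image is only embedded once."""
    ...


async def add_image_embedding(
    embeddings: MultimodalEmbeddingWrapper, url: str
) -> MultimodalEmbeddingWrapper:
    # Append the embedding and a textual placeholder for it (alias format is a guess)
    embeddings.content.append(await _embed_image(url))
    embeddings.text_alias.append(f"<image_{len(embeddings.content)}>")
    return embeddings

With an async LRU cache in that position, repeated image URLs or base64 strings would only be embedded once per process, which matches the intent stated in the commit message.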


@@ -17,6 +17,7 @@ from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
 from endpoints.OAI.utils.chat_completion import (
     format_prompt_with_template,
     generate_chat_completion,
+    preprocess_vision_request,
     stream_generate_chat_completion,
 )
 from endpoints.OAI.utils.completion import (
@@ -126,6 +127,8 @@ async def chat_completion_request(
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
+        if model.container.use_vision:
+            data.messages, embeddings = await preprocess_vision_request(data.messages)
         prompt = await format_prompt_with_template(data)
 
     # Set an empty JSON schema if the request wants a JSON response
@@ -136,12 +139,14 @@
     if data.stream and not disable_request_streaming:
         return EventSourceResponse(
-            stream_generate_chat_completion(prompt, data, request, model_path),
+            stream_generate_chat_completion(
+                prompt, embeddings, data, request, model_path
+            ),
             ping=maxsize,
         )
     else:
         generate_task = asyncio.create_task(
-            generate_chat_completion(prompt, data, request, model_path)
+            generate_chat_completion(prompt, embeddings, data, request, model_path)
         )
 
         response = await run_with_request_disconnect(
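
For reference, the user-message shape this route now accepts follows the OAI vision spec named in the commit message; a minimal request body (model name and image sources are placeholders) looks like:

request_body = {
    "model": "my-vision-model",  # placeholder
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                # either a plain URL...
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                # ...or an inline base64 data URI
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."}},
            ],
        }
    ],
}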


@@ -3,9 +3,10 @@
 import asyncio
 import pathlib
 from asyncio import CancelledError
-from typing import List, Optional
+from typing import Dict, List, Optional
 import json
 
+from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger
@@ -279,7 +280,11 @@ async def format_prompt_with_template(
 
 async def stream_generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     """Generator for the generation process."""
 
     abort_event = asyncio.Event()
@@ -298,6 +303,7 @@ async def stream_generate_chat_completion(
                     n,
                     gen_queue,
                     prompt,
+                    embeddings,
                     request.state.id,
                     abort_event,
                     **task_gen_params.model_dump(exclude={"prompt"}),
@@ -372,7 +378,11 @@
 
 
 async def generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     gen_tasks: List[asyncio.Task] = []
@@ -381,7 +391,10 @@
         gen_tasks.append(
             asyncio.create_task(
                 model.container.generate(
-                    prompt, request.state.id, **data.model_dump(exclude={"prompt"})
+                    prompt,
+                    embeddings,
+                    request.state.id,
+                    **data.model_dump(exclude={"prompt"}),
                 )
             )
         )
@@ -459,3 +472,22 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
             tool_call["function"]["arguments"]
         )
     return [ToolCall(**tool_call) for tool_call in tool_calls]
+
+
+async def preprocess_vision_request(messages: List[Dict]):
+    embeddings = MultimodalEmbeddingWrapper()
+
+    for message in messages:
+        if isinstance(message["content"], list):
+            concatenated_content = ""
+            for content in message["content"]:
+                if content["type"] == "text":
+                    concatenated_content += content["text"]
+                elif content["type"] == "image_url":
+                    embeddings = await add_image_embedding(
+                        embeddings, content["image_url"]["url"]
+                    )
+                    concatenated_content += embeddings.text_alias[-1]
+            message["content"] = concatenated_content
+
+    return messages, embeddings
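
In effect, preprocess_vision_request flattens each list-style message into a single string, substituting a text alias for every image while the image embeddings travel alongside in the wrapper. A rough usage illustration (the alias text is a placeholder; its exact form comes from MultimodalEmbeddingWrapper, which is not shown in this excerpt):

messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this image: "},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }
]
# Inside an async context:
messages, embeddings = await preprocess_vision_request(messages)
# messages[0]["content"] is now a plain string such as "Describe this image: <image_1>"
# (placeholder alias), and `embeddings` carries the matching image embedding
# for the generation backend.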


@@ -7,6 +7,7 @@ Also serves as a common module for completions and chat completions.
 
 import asyncio
 import pathlib
 from asyncio import CancelledError
+from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import HTTPException, Request
 from typing import List, Union
@@ -87,6 +88,7 @@ async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
     prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
     request_id: str,
     abort_event: asyncio.Event,
     **kwargs,
@@ -95,7 +97,7 @@
 
     try:
         new_generation = model.container.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            prompt, embeddings, request_id, abort_event, **kwargs
        )
         async for generation in new_generation:
             generation["index"] = task_idx