OAI: Initial vision support in OAI chat completions
* Support image_url inputs containing URLs or base64 strings following OAI vision spec
* Use async lru cache for image embeddings
* Add generic wrapper class for multimodal embeddings
parent 5fa298e601
commit dd41eec8a4
7 changed files with 115 additions and 26 deletions
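
For reference, the request shape this commit starts accepting follows the OAI vision spec consumed by the new preprocess_vision_request helper: a message's content may be a list mixing text parts and image_url parts, where the url is either a remote URL or a base64 data URI. A minimal illustrative payload (the model name and URLs are placeholders):

# Illustrative chat completions request body exercising the new image_url support.
request_body = {
    "model": "some-vision-model",  # placeholder
    "messages": [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this image?"},
                # Remote URL form
                {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
                # Base64 data URI form
                {"type": "image_url", "image_url": {"url": "data:image/png;base64,iVBORw0KGgo..."}},
            ],
        }
    ],
}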
@@ -17,6 +17,7 @@ from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
 from endpoints.OAI.utils.chat_completion import (
     format_prompt_with_template,
     generate_chat_completion,
+    preprocess_vision_request,
     stream_generate_chat_completion,
 )
 from endpoints.OAI.utils.completion import (
@@ -126,6 +127,8 @@ async def chat_completion_request(
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
+        if model.container.use_vision:
+            data.messages, embeddings = await preprocess_vision_request(data.messages)
         prompt = await format_prompt_with_template(data)
 
         # Set an empty JSON schema if the request wants a JSON response
@@ -136,12 +139,14 @@ async def chat_completion_request(
 
     if data.stream and not disable_request_streaming:
         return EventSourceResponse(
-            stream_generate_chat_completion(prompt, data, request, model_path),
+            stream_generate_chat_completion(
+                prompt, embeddings, data, request, model_path
+            ),
             ping=maxsize,
         )
     else:
         generate_task = asyncio.create_task(
-            generate_chat_completion(prompt, data, request, model_path)
+            generate_chat_completion(prompt, embeddings, data, request, model_path)
         )
 
         response = await run_with_request_disconnect(
@@ -3,9 +3,10 @@
 import asyncio
 import pathlib
 from asyncio import CancelledError
-from typing import List, Optional
+from typing import Dict, List, Optional
 import json
 
+from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger
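
The MultimodalEmbeddingWrapper and add_image_embedding imported here live in common/multimodal, which is not part of the hunks shown in this excerpt. From their usage further down (add_image_embedding(embeddings, url) returns the wrapper, and embeddings.text_alias[-1] supplies a placeholder string for the prompt) and the commit message's mention of an async LRU cache, a rough sketch of the interface might look as follows; field names other than text_alias, the alias format, and the stubbed embedding step are assumptions:

from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class MultimodalEmbeddingWrapper:
    """Generic container pairing image embeddings with the text aliases
    that stand in for them in the flattened prompt."""

    content: List[Any] = field(default_factory=list)  # assumed field name
    text_alias: List[str] = field(default_factory=list)


async def add_image_embedding(
    embeddings: MultimodalEmbeddingWrapper, url: str
) -> MultimodalEmbeddingWrapper:
    # The real implementation presumably fetches the URL (or decodes a base64
    # data URI), runs it through the loaded vision encoder, and memoizes the
    # result behind an async LRU cache, per the commit message. This stub only
    # mirrors the call shape used by preprocess_vision_request.
    embeddings.content.append(url)  # placeholder instead of a real embedding
    embeddings.text_alias.append(f"<image_{len(embeddings.content)}>")  # assumed alias format
    return embeddings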
@@ -279,7 +280,11 @@ async def format_prompt_with_template(
 
 
 async def stream_generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     """Generator for the generation process."""
     abort_event = asyncio.Event()
@@ -298,6 +303,7 @@ async def stream_generate_chat_completion(
                 n,
                 gen_queue,
                 prompt,
+                embeddings,
                 request.state.id,
                 abort_event,
                 **task_gen_params.model_dump(exclude={"prompt"}),
@@ -372,7 +378,11 @@ async def stream_generate_chat_completion(
 
 
 async def generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     gen_tasks: List[asyncio.Task] = []
 
@@ -381,7 +391,10 @@ async def generate_chat_completion(
         gen_tasks.append(
             asyncio.create_task(
                 model.container.generate(
-                    prompt, request.state.id, **data.model_dump(exclude={"prompt"})
+                    prompt,
+                    embeddings,
+                    request.state.id,
+                    **data.model_dump(exclude={"prompt"}),
                 )
             )
         )
@@ -459,3 +472,22 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
             tool_call["function"]["arguments"]
         )
     return [ToolCall(**tool_call) for tool_call in tool_calls]
+
+
+async def preprocess_vision_request(messages: List[Dict]):
+    embeddings = MultimodalEmbeddingWrapper()
+    for message in messages:
+        if isinstance(message["content"], list):
+            concatenated_content = ""
+            for content in message["content"]:
+                if content["type"] == "text":
+                    concatenated_content += content["text"]
+                elif content["type"] == "image_url":
+                    embeddings = await add_image_embedding(
+                        embeddings, content["image_url"]["url"]
+                    )
+                    concatenated_content += embeddings.text_alias[-1]
+
+            message["content"] = concatenated_content
+
+    return messages, embeddings
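
In effect, preprocess_vision_request flattens each multipart message into a plain string that format_prompt_with_template can render, replacing every image_url part with the wrapper's text alias while the image embeddings travel separately in the returned wrapper. A small illustrative before/after (the exact alias string comes from MultimodalEmbeddingWrapper, which is not shown in this diff):

messages_in = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Describe this: "},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }
]

# messages, embeddings = await preprocess_vision_request(messages_in)
# messages would then look roughly like:
messages_out = [
    {"role": "user", "content": "Describe this: <image_1>"}  # alias text is illustrative
]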
@@ -7,6 +7,7 @@ Also serves as a common module for completions and chat completions.
 import asyncio
 import pathlib
 from asyncio import CancelledError
+from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import HTTPException, Request
 from typing import List, Union
 
@@ -87,6 +88,7 @@ async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
     prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
     request_id: str,
     abort_event: asyncio.Event,
     **kwargs,
@@ -95,7 +97,7 @@ async def _stream_collector(
 
     try:
         new_generation = model.container.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            prompt, embeddings, request_id, abort_event, **kwargs
         )
         async for generation in new_generation:
             generation["index"] = task_idx