diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index df8cacf..3c6634f 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -6,6 +6,8 @@ import gc
 import math
 import pathlib
 import traceback
+from backends.exllamav2.vision import clear_image_embedding_cache
+from common.multimodal import MultimodalEmbeddingWrapper
 import torch
 import uuid
 from copy import deepcopy
@@ -816,6 +818,9 @@ class ExllamaV2Container:
         # Delete references held in the grammar module
         clear_grammar_func_cache()
 
+        # Clear the image embedding cache
+        clear_image_embedding_cache()
+
         # Unload LoRAs
         if self.generator and self.generator.generator.current_loras:
             for lora in self.generator.generator.current_loras:
@@ -908,12 +913,17 @@ class ExllamaV2Container:
         return dict(zip_longest(top_tokens, cleaned_values))
 
     async def generate(
-        self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
+        self,
+        prompt: str,
+        embeddings: MultimodalEmbeddingWrapper,
+        request_id: str,
+        abort_event: asyncio.Event = None,
+        **kwargs,
     ):
         """Generate a response to a prompt."""
         generations = []
         async for generation in self.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            prompt, embeddings, request_id, abort_event, **kwargs
         ):
             generations.append(generation)
 
@@ -979,6 +989,7 @@ class ExllamaV2Container:
     async def generate_gen(
         self,
         prompt: str,
+        embeddings: MultimodalEmbeddingWrapper,
        request_id: str,
         abort_event: Optional[asyncio.Event] = None,
         **kwargs,
@@ -1246,7 +1257,10 @@ class ExllamaV2Container:
         # Encode both positive and negative prompts
         input_ids = [
             self.tokenizer.encode(
-                prompt, add_bos=add_bos_token, encode_special_tokens=True
+                prompt,
+                add_bos=add_bos_token,
+                encode_special_tokens=True,
+                embeddings=embeddings.content,
             )
             for prompt in prompts
         ]
@@ -1297,6 +1311,7 @@ class ExllamaV2Container:
             banned_strings=banned_strings,
             token_healing=token_healing,
             identifier=job_id,
+            embeddings=embeddings.content,
         )
 
         # Save generated tokens and full response
diff --git a/backends/exllamav2/vision.py b/backends/exllamav2/vision.py
index c49bf33..d207d3e 100644
--- a/backends/exllamav2/vision.py
+++ b/backends/exllamav2/vision.py
@@ -4,18 +4,14 @@ import io
 import base64
 import re
 from PIL import Image
+from common import model
 import aiohttp
 from common.networking import (
     handle_request_error,
 )
 from fastapi import HTTPException
-from exllamav2 import (
-    ExLlamaV2,
-    ExLlamaV2Tokenizer,
-    ExLlamaV2VisionTower,
-    ExLlamaV2MMEmbedding,
-)
-from functools import lru_cache
+from exllamav2.generator import ExLlamaV2MMEmbedding
+from async_lru import alru_cache
 
 
 async def get_image(url: str) -> Image:
@@ -50,14 +46,16 @@ async def get_image(url: str) -> Image:
     return Image.open(io.BytesIO(bytes_image))
 
 
-@lru_cache(20)
-async def get_image_embedding(
-    model: ExLlamaV2,
-    tokenizer: ExLlamaV2Tokenizer,
-    vision_model: ExLlamaV2VisionTower,
-    url: str,
-) -> ExLlamaV2MMEmbedding:
+@alru_cache(20)
+async def get_image_embedding(url: str) -> ExLlamaV2MMEmbedding:
     image = await get_image(url)
-    return vision_model.get_image_embeddings(
-        model=model, tokenizer=tokenizer, image=image
+    return model.container.vision_model.get_image_embeddings(
+        model=model.container.model,
+        tokenizer=model.container.tokenizer,
+        image=image,
+        text_alias=None,
     )
+
+
+def clear_image_embedding_cache():
+    get_image_embedding.cache_clear()
diff --git a/common/multimodal.py b/common/multimodal.py
new file mode 100644
index 0000000..74d4964
--- /dev/null
+++ b/common/multimodal.py
@@ -0,0 +1,36 @@
+from typing import List
+from backends.exllamav2.vision import get_image_embedding
+from common import model
+from pydantic import BaseModel
+from loguru import logger
+
+from common.optional_dependencies import dependencies
+
+if dependencies.exllamav2:
+    from exllamav2 import ExLlamaV2VisionTower
+
+
+class MultimodalEmbeddingWrapper(BaseModel):
+    """Common multimodal embedding wrapper"""
+
+    type: str = None
+    content: List = []
+    text_alias: List[str] = []
+
+
+async def add_image_embedding(
+    embeddings: MultimodalEmbeddingWrapper, url: str
+) -> MultimodalEmbeddingWrapper:
+    # Determine the type of vision embedding to use
+    if not embeddings.type:
+        if isinstance(model.container.vision_model, ExLlamaV2VisionTower):
+            embeddings.type = "ExLlamaV2MMEmbedding"
+
+    if embeddings.type == "ExLlamaV2MMEmbedding":
+        embedding = await get_image_embedding(url)
+        embeddings.content.append(embedding)
+        embeddings.text_alias.append(embedding.text_alias)
+    else:
+        logger.error("No valid vision model to create embedding")
+
+    return embeddings
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index b6a44c9..c018f5c 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -17,6 +17,7 @@ from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
 from endpoints.OAI.utils.chat_completion import (
     format_prompt_with_template,
     generate_chat_completion,
+    preprocess_vision_request,
     stream_generate_chat_completion,
 )
 from endpoints.OAI.utils.completion import (
@@ -126,6 +127,8 @@ async def chat_completion_request(
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
+        if model.container.use_vision:
+            data.messages, embeddings = await preprocess_vision_request(data.messages)
         prompt = await format_prompt_with_template(data)
 
     # Set an empty JSON schema if the request wants a JSON response
@@ -136,12 +139,14 @@ async def chat_completion_request(
 
     if data.stream and not disable_request_streaming:
         return EventSourceResponse(
-            stream_generate_chat_completion(prompt, data, request, model_path),
+            stream_generate_chat_completion(
+                prompt, embeddings, data, request, model_path
+            ),
             ping=maxsize,
         )
     else:
         generate_task = asyncio.create_task(
-            generate_chat_completion(prompt, data, request, model_path)
+            generate_chat_completion(prompt, embeddings, data, request, model_path)
         )
 
         response = await run_with_request_disconnect(
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 3b5c07f..a59f425 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -3,9 +3,10 @@
 import asyncio
 import pathlib
 from asyncio import CancelledError
-from typing import List, Optional
+from typing import Dict, List, Optional
 
 import json
+from common.multimodal import MultimodalEmbeddingWrapper, add_image_embedding
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger
@@ -279,7 +280,11 @@ async def format_prompt_with_template(
 
 
 async def stream_generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     """Generator for the generation process."""
     abort_event = asyncio.Event()
@@ -298,6 +303,7 @@ async def stream_generate_chat_completion(
                         n,
                         gen_queue,
                         prompt,
+                        embeddings,
                         request.state.id,
                         abort_event,
                         **task_gen_params.model_dump(exclude={"prompt"}),
@@ -372,7 +378,11 @@ async def stream_generate_chat_completion(
 
 
 async def generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
+    prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
+    data: ChatCompletionRequest,
+    request: Request,
+    model_path: pathlib.Path,
 ):
     gen_tasks: List[asyncio.Task] = []
 
@@ -381,7 +391,10 @@ async def generate_chat_completion(
             gen_tasks.append(
                 asyncio.create_task(
                     model.container.generate(
-                        prompt, request.state.id, **data.model_dump(exclude={"prompt"})
+                        prompt,
+                        embeddings,
+                        request.state.id,
+                        **data.model_dump(exclude={"prompt"}),
                     )
                 )
             )
@@ -459,3 +472,22 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
             tool_call["function"]["arguments"]
         )
     return [ToolCall(**tool_call) for tool_call in tool_calls]
+
+
+async def preprocess_vision_request(messages: List[Dict]):
+    embeddings = MultimodalEmbeddingWrapper()
+    for message in messages:
+        if isinstance(message["content"], list):
+            concatenated_content = ""
+            for content in message["content"]:
+                if content["type"] == "text":
+                    concatenated_content += content["text"]
+                elif content["type"] == "image_url":
+                    embeddings = await add_image_embedding(
+                        embeddings, content["image_url"]["url"]
+                    )
+                    concatenated_content += embeddings.text_alias[-1]
+
+            message["content"] = concatenated_content
+
+    return messages, embeddings
diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py
index e939525..65ff0d3 100644
--- a/endpoints/OAI/utils/completion.py
+++ b/endpoints/OAI/utils/completion.py
@@ -7,6 +7,7 @@ Also serves as a common module for completions and chat completions.
 import asyncio
 import pathlib
 from asyncio import CancelledError
+from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import HTTPException, Request
 from typing import List, Union
 
@@ -87,6 +88,7 @@ async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
     prompt: str,
+    embeddings: MultimodalEmbeddingWrapper,
     request_id: str,
     abort_event: asyncio.Event,
     **kwargs,
@@ -95,7 +97,7 @@
 
     try:
         new_generation = model.container.generate_gen(
-            prompt, request_id, abort_event, **kwargs
+            prompt, embeddings, request_id, abort_event, **kwargs
         )
         async for generation in new_generation:
             generation["index"] = task_idx
diff --git a/pyproject.toml b/pyproject.toml
index 81f8bf2..ca4b511 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,6 +29,7 @@ dependencies = [
     "lm-format-enforcer >= 0.9.6",
     "aiofiles",
     "aiohttp",
+    "async_lru",
     "huggingface_hub",
     "psutil",
     "httptools>=0.5.0",
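
For reference, a minimal sketch (not part of the patch) of the OpenAI-style message shape that the new preprocess_vision_request helper flattens. The URL and wording below are hypothetical examples; the placeholder that ends up in the prompt is the text alias of the cached ExLlamaV2MMEmbedding, generated at runtime rather than by this snippet.

# Illustrative input for preprocess_vision_request (hypothetical values).
messages = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "What is in this image?"},
            {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
        ],
    }
]

# After preprocessing, each list-style "content" is collapsed into a single
# string: text parts are concatenated and every image_url part is replaced by
# the text alias of its cached image embedding. The aliases are then resolved
# back to the embeddings when the prompt is encoded, since generate_gen passes
# embeddings.content to tokenizer.encode and to the generator job.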