From 1c9891bf04f1128b1b6526ad8d035905050cbe1a Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sun, 15 Jun 2025 19:22:51 +0200 Subject: [PATCH] Exl3: Add vision capability --- backends/exllamav2/vision.py | 4 ++-- backends/exllamav3/model.py | 31 +++++++++++++++++++++++++++++++ backends/exllamav3/vision.py | 27 +++++++++++++++++++++++++++ common/multimodal.py | 25 ++++++++++++++++++++----- 4 files changed, 80 insertions(+), 7 deletions(-) create mode 100644 backends/exllamav3/vision.py diff --git a/backends/exllamav2/vision.py b/backends/exllamav2/vision.py index 8432eb5..90106e3 100644 --- a/backends/exllamav2/vision.py +++ b/backends/exllamav2/vision.py @@ -14,7 +14,7 @@ if dependencies.exllamav2: # Fetch the return type on runtime @alru_cache(20) -async def get_image_embedding(url: str) -> "ExLlamaV2MMEmbedding": +async def get_image_embedding_exl2(url: str) -> "ExLlamaV2MMEmbedding": image = await get_image(url) return model.container.vision_model.get_image_embeddings( model=model.container.model, @@ -25,4 +25,4 @@ async def get_image_embedding(url: str) -> "ExLlamaV2MMEmbedding": def clear_image_embedding_cache(): - get_image_embedding.cache_clear() + get_image_embedding_exl2.cache_clear() diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 24c08f3..c8761ff 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -69,6 +69,7 @@ class ExllamaV3Container(BaseModelContainer): config: Optional[Config] = None draft_config: Optional[Config] = None generator: Optional[AsyncGenerator] = None + vision_model: Optional[Model] = None # Class-specific vars gpu_split: Optional[List[float]] = None @@ -112,6 +113,19 @@ class ExllamaV3Container(BaseModelContainer): self.model = Model.from_config(self.config) self.tokenizer = Tokenizer.from_config(self.config) + # Prepare vision model if requested in config + self.use_vision = kwargs.get("vision") + if self.use_vision and "vision" in self.config.model_classes: + self.vision_model = Model.from_config(self.config, component="vision") + else: + logger.warning( + "The provided model does not have vision capabilities that are " + "supported by ExllamaV3. " + "Vision input is disabled." + ) + self.vision_model = None + self.use_vision = False + # Fallback to 4096 since exl3 can't fetch from HF's config.json self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096) @@ -418,6 +432,14 @@ class ExllamaV3Container(BaseModelContainer): @torch.inference_mode() def load_model_sync(self, progress_callback=None): + if self.use_vision: + for value in self.vision_model.load_gen( + reserve_per_device=self.autosplit_reserve, + callback=progress_callback + ): + if value: + yield value + if self.use_draft_model: for value in self.draft_model.load_gen( reserve_per_device=self.autosplit_reserve, @@ -527,6 +549,9 @@ class ExllamaV3Container(BaseModelContainer): A list of integer token IDs. """ + mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings") + mm_embeddings_content = mm_embeddings.content if mm_embeddings else [] + return ( self.tokenizer.encode( text, @@ -534,6 +559,7 @@ class ExllamaV3Container(BaseModelContainer): kwargs.get("add_bos_token"), self.hf_model.add_bos_token() ), encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True), + embeddings=mm_embeddings_content ) .flatten() .tolist() @@ -802,6 +828,9 @@ class ExllamaV3Container(BaseModelContainer): stop_conditions = params.stop add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token()) + # Get multimodal embeddings if present + mm_embeddings_content = mm_embeddings.content if mm_embeddings else [] + # Fetch EOS tokens from generation_config if they exist eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id] @@ -812,6 +841,7 @@ class ExllamaV3Container(BaseModelContainer): prompt, add_bos=add_bos_token, encode_special_tokens=True, + embeddings=mm_embeddings_content, ) for prompt in prompts ] @@ -855,6 +885,7 @@ class ExllamaV3Container(BaseModelContainer): max_new_tokens=max_tokens, stop_conditions=stop_conditions, banned_strings=params.banned_strings, + embeddings=mm_embeddings_content, ) generated_tokens = 0 diff --git a/backends/exllamav3/vision.py b/backends/exllamav3/vision.py new file mode 100644 index 0000000..c59d881 --- /dev/null +++ b/backends/exllamav3/vision.py @@ -0,0 +1,27 @@ +"""Vision utilities for ExLlamaV2.""" + +from async_lru import alru_cache + +from common import model +from common.optional_dependencies import dependencies +from common.image_util import get_image + +# Since this is used outside the Exl3 backend, the dependency +# may be optional +if dependencies.exllamav3: + from exllamav3.tokenizer import MMEmbedding + + +# Fetch the return type on runtime +@alru_cache(20) +async def get_image_embedding_exl3(url: str) -> "MMEmbedding": + image = await get_image(url) + return model.container.vision_model.get_image_embeddings( + tokenizer=model.container.tokenizer, + image=image, + text_alias=None, + ) + + +def clear_image_embedding_cache(): + get_image_embedding_exl3.cache_clear() diff --git a/common/multimodal.py b/common/multimodal.py index dee865f..8b21587 100644 --- a/common/multimodal.py +++ b/common/multimodal.py @@ -1,4 +1,5 @@ -from backends.exllamav2.vision import get_image_embedding +from backends.exllamav2.vision import get_image_embedding_exl2 +from backends.exllamav3.vision import get_image_embedding_exl3 from common import model from loguru import logger from pydantic import BaseModel, Field @@ -8,7 +9,8 @@ from common.optional_dependencies import dependencies if dependencies.exllamav2: from exllamav2 import ExLlamaV2VisionTower - +if dependencies.exllamav3: + from exllamav3 import Model class MultimodalEmbeddingWrapper(BaseModel): """Common multimodal embedding wrapper""" @@ -20,12 +22,25 @@ class MultimodalEmbeddingWrapper(BaseModel): async def add(self, url: str): # Determine the type of vision embedding to use if not self.type: - if isinstance(model.container.vision_model, ExLlamaV2VisionTower): + if ( + dependencies.exllamav2 and + isinstance(model.container.vision_model, ExLlamaV2VisionTower) + ): self.type = "ExLlamaV2MMEmbedding" + elif ( + dependencies.exllamav3 and + isinstance(model.container.vision_model, Model) + ): + self.type = "MMEmbedding" + # Create the embedding if self.type == "ExLlamaV2MMEmbedding": - embedding = await get_image_embedding(url) + embedding = await get_image_embedding_exl2(url) + self.content.append(embedding) + self.text_alias.append(embedding.text_alias) + elif self.type == "MMEmbedding": + embedding = await get_image_embedding_exl3(url) self.content.append(embedding) self.text_alias.append(embedding.text_alias) else: - logger.error("No valid vision model to create embedding") + logger.error("No valid vision model to create embedding") \ No newline at end of file