diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index fd23a28..091043a 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1003,7 +1003,6 @@ class ExllamaV2Container(BaseModelContainer): params: BaseSamplerRequest, gen_settings: ExLlamaV2Sampler.Settings, grammar_handler: ExLlamaV2Grammar, - banned_strings: List[str], ): # Apply settings gen_settings.temperature = params.temperature @@ -1111,16 +1110,6 @@ class ExllamaV2Container(BaseModelContainer): params.grammar_string, self.model, self.tokenizer ) - # Set banned strings - banned_strings = params.banned_strings - if banned_strings and len(grammar_handler.filters) > 0: - logger.warning( - "Disabling banned_strings because " - "they cannot be used with grammar filters." - ) - - banned_strings = [] - # Speculative Ngram self.generator.speculative_ngram = params.speculative_ngram @@ -1226,15 +1215,23 @@ class ExllamaV2Container(BaseModelContainer): prompts = [prompt] gen_settings = ExLlamaV2Sampler.Settings() grammar_handler = ExLlamaV2Grammar() - banned_strings = [] self.assign_gen_params( params, gen_settings, grammar_handler, - banned_strings, ) + # Set banned strings + banned_strings = params.banned_strings + if banned_strings and len(grammar_handler.filters) > 0: + logger.warning( + "Disabling banned_strings because " + "they cannot be used with grammar filters." + ) + + banned_strings = [] + # Set CFG scale and negative prompt cfg_scale = params.cfg_scale negative_prompt = None diff --git a/backends/exllamav2/vision.py b/backends/exllamav2/vision.py index 7db0b09..90106e3 100644 --- a/backends/exllamav2/vision.py +++ b/backends/exllamav2/vision.py @@ -1,19 +1,10 @@ """Vision utilities for ExLlamaV2.""" -import aiohttp -import base64 -import io -import re from async_lru import alru_cache -from fastapi import HTTPException -from PIL import Image from common import model -from common.networking import ( - handle_request_error, -) from common.optional_dependencies import dependencies -from common.tabby_config import config +from common.image_util import get_image # Since this is used outside the Exl2 backend, the dependency # may be optional @@ -21,49 +12,9 @@ if dependencies.exllamav2: from exllamav2.generator import ExLlamaV2MMEmbedding -async def get_image(url: str) -> Image: - if url.startswith("data:image"): - # Handle base64 image - match = re.match(r"^data:image\/[a-zA-Z0-9]+;base64,(.*)$", url) - if match: - base64_image = match.group(1) - bytes_image = base64.b64decode(base64_image) - else: - error_message = handle_request_error( - "Failed to read base64 image input.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - else: - # Handle image URL - if config.network.disable_fetch_requests: - error_message = handle_request_error( - f"Failed to fetch image from {url} as fetch requests are disabled.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status == 200: - bytes_image = await response.read() - else: - error_message = handle_request_error( - f"Failed to fetch image from {url}.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - return Image.open(io.BytesIO(bytes_image)) - - # Fetch the return type on runtime @alru_cache(20) -async def get_image_embedding(url: str) -> "ExLlamaV2MMEmbedding": +async def get_image_embedding_exl2(url: str) -> "ExLlamaV2MMEmbedding": image = await get_image(url) return model.container.vision_model.get_image_embeddings( model=model.container.model, @@ -74,4 +25,4 @@ async def get_image_embedding(url: str) -> "ExLlamaV2MMEmbedding": def clear_image_embedding_cache(): - get_image_embedding.cache_clear() + get_image_embedding_exl2.cache_clear() diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 9e69074..ee2dfaa 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -69,6 +69,7 @@ class ExllamaV3Container(BaseModelContainer): config: Optional[Config] = None draft_config: Optional[Config] = None generator: Optional[AsyncGenerator] = None + vision_model: Optional[Model] = None # Class-specific vars gpu_split: Optional[List[float]] = None @@ -99,7 +100,7 @@ class ExllamaV3Container(BaseModelContainer): self = cls() # Make sure ExllamaV3 is up to date - check_package_version("exllamav3", "0.0.3") + check_package_version("exllamav3", "0.0.4") logger.warning( "ExllamaV3 is currently in an alpha state. " @@ -112,6 +113,19 @@ class ExllamaV3Container(BaseModelContainer): self.model = Model.from_config(self.config) self.tokenizer = Tokenizer.from_config(self.config) + # Prepare vision model if requested in config + self.use_vision = kwargs.get("vision") + if self.use_vision and "vision" in self.config.model_classes: + self.vision_model = Model.from_config(self.config, component="vision") + else: + logger.warning( + "The provided model does not have vision capabilities that are " + "supported by ExllamaV3. " + "Vision input is disabled." + ) + self.vision_model = None + self.use_vision = False + # Fallback to 4096 since exl3 can't fetch from HF's config.json self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096) @@ -418,6 +432,14 @@ class ExllamaV3Container(BaseModelContainer): @torch.inference_mode() def load_model_sync(self, progress_callback=None): + if self.use_vision: + for value in self.vision_model.load_gen( + reserve_per_device=self.autosplit_reserve, + callback=progress_callback, + ): + if value: + yield value + if self.use_draft_model: for value in self.draft_model.load_gen( reserve_per_device=self.autosplit_reserve, @@ -527,6 +549,9 @@ class ExllamaV3Container(BaseModelContainer): A list of integer token IDs. """ + mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings") + mm_embeddings_content = mm_embeddings.content if mm_embeddings else [] + return ( self.tokenizer.encode( text, @@ -534,6 +559,7 @@ class ExllamaV3Container(BaseModelContainer): kwargs.get("add_bos_token"), self.hf_model.add_bos_token() ), encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True), + embeddings=mm_embeddings_content, ) .flatten() .tolist() @@ -802,6 +828,9 @@ class ExllamaV3Container(BaseModelContainer): stop_conditions = params.stop add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token()) + # Get multimodal embeddings if present + mm_embeddings_content = mm_embeddings.content if mm_embeddings else [] + # Fetch EOS tokens from generation_config if they exist eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id] @@ -812,6 +841,7 @@ class ExllamaV3Container(BaseModelContainer): prompt, add_bos=add_bos_token, encode_special_tokens=True, + embeddings=mm_embeddings_content, ) for prompt in prompts ] @@ -855,6 +885,7 @@ class ExllamaV3Container(BaseModelContainer): max_new_tokens=max_tokens, stop_conditions=stop_conditions, banned_strings=params.banned_strings, + embeddings=mm_embeddings_content, ) generated_tokens = 0 diff --git a/backends/exllamav3/vision.py b/backends/exllamav3/vision.py new file mode 100644 index 0000000..c59d881 --- /dev/null +++ b/backends/exllamav3/vision.py @@ -0,0 +1,27 @@ +"""Vision utilities for ExLlamaV2.""" + +from async_lru import alru_cache + +from common import model +from common.optional_dependencies import dependencies +from common.image_util import get_image + +# Since this is used outside the Exl3 backend, the dependency +# may be optional +if dependencies.exllamav3: + from exllamav3.tokenizer import MMEmbedding + + +# Fetch the return type on runtime +@alru_cache(20) +async def get_image_embedding_exl3(url: str) -> "MMEmbedding": + image = await get_image(url) + return model.container.vision_model.get_image_embeddings( + tokenizer=model.container.tokenizer, + image=image, + text_alias=None, + ) + + +def clear_image_embedding_cache(): + get_image_embedding_exl3.cache_clear() diff --git a/common/image_util.py b/common/image_util.py new file mode 100644 index 0000000..9790cfe --- /dev/null +++ b/common/image_util.py @@ -0,0 +1,52 @@ +import aiohttp +import base64 +import io +import re + +from fastapi import HTTPException +from PIL import Image + +from common.networking import ( + handle_request_error, +) +from common.tabby_config import config + + +async def get_image(url: str) -> Image: + if url.startswith("data:image"): + # Handle base64 image + match = re.match(r"^data:image\/[a-zA-Z0-9]+;base64,(.*)$", url) + if match: + base64_image = match.group(1) + bytes_image = base64.b64decode(base64_image) + else: + error_message = handle_request_error( + "Failed to read base64 image input.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + else: + # Handle image URL + if config.network.disable_fetch_requests: + error_message = handle_request_error( + f"Failed to fetch image from {url} as fetch requests are disabled.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + async with aiohttp.ClientSession() as session: + async with session.get(url) as response: + if response.status == 200: + bytes_image = await response.read() + else: + error_message = handle_request_error( + f"Failed to fetch image from {url}.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + return Image.open(io.BytesIO(bytes_image)) diff --git a/common/multimodal.py b/common/multimodal.py index dee865f..b92386f 100644 --- a/common/multimodal.py +++ b/common/multimodal.py @@ -1,4 +1,5 @@ -from backends.exllamav2.vision import get_image_embedding +from backends.exllamav2.vision import get_image_embedding_exl2 +from backends.exllamav3.vision import get_image_embedding_exl3 from common import model from loguru import logger from pydantic import BaseModel, Field @@ -8,6 +9,8 @@ from common.optional_dependencies import dependencies if dependencies.exllamav2: from exllamav2 import ExLlamaV2VisionTower +if dependencies.exllamav3: + from exllamav3 import Model class MultimodalEmbeddingWrapper(BaseModel): @@ -20,11 +23,22 @@ class MultimodalEmbeddingWrapper(BaseModel): async def add(self, url: str): # Determine the type of vision embedding to use if not self.type: - if isinstance(model.container.vision_model, ExLlamaV2VisionTower): + if dependencies.exllamav2 and isinstance( + model.container.vision_model, ExLlamaV2VisionTower + ): self.type = "ExLlamaV2MMEmbedding" + elif dependencies.exllamav3 and isinstance( + model.container.vision_model, Model + ): + self.type = "MMEmbedding" + # Create the embedding if self.type == "ExLlamaV2MMEmbedding": - embedding = await get_image_embedding(url) + embedding = await get_image_embedding_exl2(url) + self.content.append(embedding) + self.text_alias.append(embedding.text_alias) + elif self.type == "MMEmbedding": + embedding = await get_image_embedding_exl3(url) self.content.append(embedding) self.text_alias.append(embedding.text_alias) else: diff --git a/pyproject.toml b/pyproject.toml index ca69823..6a3a6ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -78,14 +78,14 @@ cu121 = [ "exllamav2 @ https://github.com/turboderp-org/exllamav2/releases/download/v0.3.1/exllamav2-0.3.1+cu128.torch2.7.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Exl3 - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.3/exllamav3-0.0.3+cu128.torch2.7.0-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.3/exllamav3-0.0.3+cu128.torch2.7.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.3/exllamav3-0.0.3+cu128.torch2.7.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.3/exllamav3-0.0.3+cu128.torch2.7.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.3/exllamav3-0.0.3+cu128.torch2.7.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.3/exllamav3-0.0.3+cu128.torch2.7.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.3/exllamav3-0.0.3+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", - "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.3/exllamav3-0.0.3+cu128.torch2.7.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.4/exllamav3-0.0.4+cu128.torch2.7.0-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.4/exllamav3-0.0.4+cu128.torch2.7.0-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.4/exllamav3-0.0.4+cu128.torch2.7.0-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.4/exllamav3-0.0.4+cu128.torch2.7.0-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.4/exllamav3-0.0.4+cu128.torch2.7.0-cp313-cp313-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.13'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.4/exllamav3-0.0.4+cu128.torch2.7.0-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.4/exllamav3-0.0.4+cu128.torch2.7.0-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'", + "exllamav3 @ https://github.com/turboderp-org/exllamav3/releases/download/v0.0.4/exllamav3-0.0.4+cu128.torch2.7.0-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'", # Windows FA2 from https://github.com/kingbri1/flash-attention/releases "flash_attn @ https://github.com/kingbri1/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu128torch2.7.0cxx11abiFALSE-cp313-cp313-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.13'",