Exl3: Add vision capability
parent 4605c0f6bd
commit 1c9891bf04
4 changed files with 80 additions and 7 deletions

backends/exllamav2/vision.py
@@ -14,7 +14,7 @@ if dependencies.exllamav2:
 
 # Fetch the return type on runtime
 @alru_cache(20)
-async def get_image_embedding(url: str) -> "ExLlamaV2MMEmbedding":
+async def get_image_embedding_exl2(url: str) -> "ExLlamaV2MMEmbedding":
     image = await get_image(url)
     return model.container.vision_model.get_image_embeddings(
         model=model.container.model,
@@ -25,4 +25,4 @@ async def get_image_embedding(url: str) -> "ExLlamaV2MMEmbedding":
 
 
 def clear_image_embedding_cache():
-    get_image_embedding.cache_clear()
+    get_image_embedding_exl2.cache_clear()
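Callers now reach the exl2 helper under its backend-qualified name. A minimal usage sketch, assuming an exl2 model with a loaded vision tower is active in model.container (the URL is illustrative):

    import asyncio

    from backends.exllamav2.vision import (
        clear_image_embedding_cache,
        get_image_embedding_exl2,
    )

    async def main():
        # Repeat calls with the same URL are served from the alru_cache
        # (bounded at 20 entries by the decorator above)
        embedding = await get_image_embedding_exl2("https://example.com/cat.png")
        print(embedding.text_alias)

        # Drop cached embeddings, e.g. when the model is unloaded
        clear_image_embedding_cache()

    asyncio.run(main())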
backends/exllamav3/model.py
@@ -69,6 +69,7 @@ class ExllamaV3Container(BaseModelContainer):
     config: Optional[Config] = None
     draft_config: Optional[Config] = None
     generator: Optional[AsyncGenerator] = None
+    vision_model: Optional[Model] = None
 
     # Class-specific vars
     gpu_split: Optional[List[float]] = None
@@ -112,6 +113,19 @@ class ExllamaV3Container(BaseModelContainer):
         self.model = Model.from_config(self.config)
         self.tokenizer = Tokenizer.from_config(self.config)
 
+        # Prepare vision model if requested in config
+        self.use_vision = kwargs.get("vision")
+        if self.use_vision and "vision" in self.config.model_classes:
+            self.vision_model = Model.from_config(self.config, component="vision")
+        else:
+            logger.warning(
+                "The provided model does not have vision capabilities that are "
+                "supported by ExllamaV3. "
+                "Vision input is disabled."
+            )
+            self.vision_model = None
+            self.use_vision = False
+
         # Fallback to 4096 since exl3 can't fetch from HF's config.json
         self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
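The constructor only enables vision when the user asked for it and the checkpoint actually declares a vision component. A condensed sketch of that gate as a standalone helper (the helper name is hypothetical; Config, Model, and model_classes are taken from the hunk above):

    from typing import Optional, Tuple

    def resolve_vision_model(config, requested: bool) -> Tuple[Optional["Model"], bool]:
        # Enable vision only if requested AND the checkpoint ships the component
        if requested and "vision" in config.model_classes:
            return Model.from_config(config, component="vision"), True
        # Otherwise leave vision disabled
        return None, False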
@@ -418,6 +432,14 @@ class ExllamaV3Container(BaseModelContainer):
 
     @torch.inference_mode()
     def load_model_sync(self, progress_callback=None):
+        if self.use_vision:
+            for value in self.vision_model.load_gen(
+                reserve_per_device=self.autosplit_reserve,
+                callback=progress_callback
+            ):
+                if value:
+                    yield value
+
         if self.use_draft_model:
             for value in self.draft_model.load_gen(
                 reserve_per_device=self.autosplit_reserve,
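load_model_sync is a generator: the vision model now loads first, streaming non-empty progress values from load_gen while weights are placed on devices, before the draft and main models follow. A consumer sketch, assuming a constructed container (the wrapper function is illustrative):

    def load_with_progress(container):
        # Drive the generator; non-empty progress values arrive as they happen
        for value in container.load_model_sync():
            print(f"Load progress: {value}")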
@@ -527,6 +549,9 @@ class ExllamaV3Container(BaseModelContainer):
             A list of integer token IDs.
         """
 
+        mm_embeddings: MultimodalEmbeddingWrapper = kwargs.get("embeddings")
+        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
+
         return (
             self.tokenizer.encode(
                 text,
@@ -534,6 +559,7 @@ class ExllamaV3Container(BaseModelContainer):
                     kwargs.get("add_bos_token"), self.hf_model.add_bos_token()
                 ),
                 encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
+                embeddings=mm_embeddings_content
             )
             .flatten()
             .tolist()
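Prompt encoding now forwards multimodal content to the tokenizer so that each image's alias tokens are spliced into the sequence. An illustrative call from the API layer (the encode wrapper's own name sits outside this hunk, so encode_tokens is an assumption; the URL is illustrative):

    async def encode_with_image(container, url: str):
        wrapper = MultimodalEmbeddingWrapper()
        await wrapper.add(url)  # fetches and caches the image embedding

        # encode_tokens is assumed; the kwargs match the hunk above
        return container.encode_tokens(
            f"What does this image show? {wrapper.text_alias[0]}",
            embeddings=wrapper,
            add_bos_token=True,
        )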
@@ -802,6 +828,9 @@ class ExllamaV3Container(BaseModelContainer):
         stop_conditions = params.stop
         add_bos_token = unwrap(params.add_bos_token, self.hf_model.add_bos_token())
 
+        # Get multimodal embeddings if present
+        mm_embeddings_content = mm_embeddings.content if mm_embeddings else []
+
         # Fetch EOS tokens from generation_config if they exist
         eos_tokens = self.hf_model.eos_tokens() or [self.tokenizer.eos_token_id]
 
@@ -812,6 +841,7 @@ class ExllamaV3Container(BaseModelContainer):
                 prompt,
                 add_bos=add_bos_token,
                 encode_special_tokens=True,
+                embeddings=mm_embeddings_content,
             )
             for prompt in prompts
         ]
@@ -855,6 +885,7 @@ class ExllamaV3Container(BaseModelContainer):
             max_new_tokens=max_tokens,
             stop_conditions=stop_conditions,
             banned_strings=params.banned_strings,
+            embeddings=mm_embeddings_content,
         )
 
         generated_tokens = 0
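The same mm_embeddings_content list is handed to both the prompt encoding and the generation job, so the alias tokens spliced in at encode time resolve back to real image embeddings during generation. Schematically (the job constructor and surrounding code fall outside these hunks, so this is a sketch built from names in them):

    mm_embeddings_content = mm_embeddings.content if mm_embeddings else []

    # Encode: alias tokens stand in for each image
    input_ids = [
        self.tokenizer.encode(
            prompt,
            add_bos=add_bos_token,
            encode_special_tokens=True,
            embeddings=mm_embeddings_content,
        )
        for prompt in prompts
    ]

    # Generate: the job looks up those aliases in the same embeddings list
    job = make_generation_job(  # hypothetical stand-in for the real constructor
        max_new_tokens=max_tokens,
        stop_conditions=stop_conditions,
        banned_strings=params.banned_strings,
        embeddings=mm_embeddings_content,
    )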
backends/exllamav3/vision.py (new file, 27 lines)
@@ -0,0 +1,27 @@
+"""Vision utilities for ExLlamaV3."""
+
+from async_lru import alru_cache
+
+from common import model
+from common.optional_dependencies import dependencies
+from common.image_util import get_image
+
+# Since this is used outside the Exl3 backend, the dependency
+# may be optional
+if dependencies.exllamav3:
+    from exllamav3.tokenizer import MMEmbedding
+
+
+# Fetch the return type on runtime
+@alru_cache(20)
+async def get_image_embedding_exl3(url: str) -> "MMEmbedding":
+    image = await get_image(url)
+    return model.container.vision_model.get_image_embeddings(
+        tokenizer=model.container.tokenizer,
+        image=image,
+        text_alias=None,
+    )
+
+
+def clear_image_embedding_cache():
+    get_image_embedding_exl3.cache_clear()
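Usage mirrors the exl2 helper: fetch an embedding by URL (cached by alru_cache), then splice its text alias into the prompt. A sketch assuming an exl3 model with a loaded vision component in model.container (URL illustrative):

    from backends.exllamav3.vision import get_image_embedding_exl3

    async def build_prompt(url: str) -> str:
        emb = await get_image_embedding_exl3(url)
        # text_alias is the placeholder the tokenizer later resolves
        return f"Describe this image: {emb.text_alias}"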
common/multimodal.py
@@ -1,4 +1,5 @@
-from backends.exllamav2.vision import get_image_embedding
+from backends.exllamav2.vision import get_image_embedding_exl2
+from backends.exllamav3.vision import get_image_embedding_exl3
 from common import model
 from loguru import logger
 from pydantic import BaseModel, Field
@@ -8,7 +9,8 @@ from common.optional_dependencies import dependencies
 
 if dependencies.exllamav2:
     from exllamav2 import ExLlamaV2VisionTower
 
+if dependencies.exllamav3:
+    from exllamav3 import Model
 
 class MultimodalEmbeddingWrapper(BaseModel):
     """Common multimodal embedding wrapper"""
@@ -20,12 +22,25 @@ class MultimodalEmbeddingWrapper(BaseModel):
     async def add(self, url: str):
         # Determine the type of vision embedding to use
         if not self.type:
-            if isinstance(model.container.vision_model, ExLlamaV2VisionTower):
+            if (
+                dependencies.exllamav2 and
+                isinstance(model.container.vision_model, ExLlamaV2VisionTower)
+            ):
                 self.type = "ExLlamaV2MMEmbedding"
+            elif (
+                dependencies.exllamav3 and
+                isinstance(model.container.vision_model, Model)
+            ):
+                self.type = "MMEmbedding"
 
         # Create the embedding
         if self.type == "ExLlamaV2MMEmbedding":
-            embedding = await get_image_embedding(url)
+            embedding = await get_image_embedding_exl2(url)
             self.content.append(embedding)
             self.text_alias.append(embedding.text_alias)
+        elif self.type == "MMEmbedding":
+            embedding = await get_image_embedding_exl3(url)
+            self.content.append(embedding)
+            self.text_alias.append(embedding.text_alias)
         else:
             logger.error("No valid vision model to create embedding")
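The wrapper's dispatch is sticky: the first add() call detects the backend from the live vision model's class and sets self.type, and subsequent calls reuse it. Usage sketch (URLs illustrative, inside an async context):

    wrapper = MultimodalEmbeddingWrapper()
    await wrapper.add("https://example.com/a.png")  # detects exl2 vs exl3
    await wrapper.add("https://example.com/b.png")  # reuses the detected type

    # One alias per image, in order, ready to splice into the prompt text
    prompt = " ".join(wrapper.text_alias) + " Compare these two images."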