Model: Add vision loading support

Adds the ability to load vision parts of text + image models. Requires
an explicit flag in config because there isn't a way to automatically
determine whether the vision tower should be used.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-11-11 12:04:40 -05:00
parent cc2516790d
commit 69ac0eb8aa
5 changed files with 42 additions and 5 deletions

View file

@ -20,6 +20,7 @@ from exllamav2 import (
ExLlamaV2Cache_TP,
ExLlamaV2Tokenizer,
ExLlamaV2Lora,
ExLlamaV2VisionTower,
)
from exllamav2.generator import (
ExLlamaV2Sampler,
@ -28,6 +29,7 @@ from exllamav2.generator import (
)
from itertools import zip_longest
from loguru import logger
from PIL import Image
from typing import List, Optional, Union
from ruamel.yaml import YAML
@ -91,6 +93,10 @@ class ExllamaV2Container:
autosplit_reserve: List[float] = [96 * 1024**2]
use_tp: bool = False
# Vision vars
use_vision: bool = False
vision_model: Optional[ExLlamaV2VisionTower] = None
# Load state
model_is_loading: bool = False
model_loaded: bool = False
@ -144,6 +150,9 @@ class ExllamaV2Container:
# Apply a model's config overrides while respecting user settings
kwargs = await self.set_model_overrides(**kwargs)
# Set vision state
self.use_vision = unwrap(kwargs.get("vision"), True)
# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
@ -608,6 +617,14 @@ class ExllamaV2Container:
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)
# Load vision tower if it exists
if self.use_vision:
self.vision_model = ExLlamaV2VisionTower(self.config)
for value in self.vision_model.load_gen(callback_gen=progress_callback):
if value:
yield value
self.model = ExLlamaV2(self.config)
if not self.quiet:
logger.info("Loading model: " + self.config.model_dir)

View file

@ -270,6 +270,12 @@ class ModelConfig(BaseConfigModel):
"NOTE: Only works with chat completion message lists!"
),
)
vision: Optional[bool] = Field(
False,
description=(
"Enables vision support if the model supports it. (default: False)"
),
)
num_experts_per_token: Optional[int] = Field(
None,
description=(

View file

@ -33,6 +33,7 @@ class ModelType(Enum):
MODEL = "model"
DRAFT = "draft"
EMBEDDING = "embedding"
VISION = "vision"
def load_progress(module, modules):
@ -70,17 +71,26 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
# Create a new container
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
model_type = "draft" if container.draft_config else "model"
# Add possible types of models that can be loaded
model_type = [ModelType.MODEL]
if container.use_vision:
model_type.insert(0, ModelType.VISION)
if container.draft_config:
model_type.insert(0, ModelType.DRAFT)
load_status = container.load_gen(load_progress, **kwargs)
progress = get_loading_progress_bar()
progress.start()
try:
index = 0
async for module, modules in load_status:
if module == 0:
loading_task = progress.add_task(
f"[cyan]Loading {model_type} modules", total=modules
f"[cyan]Loading {model_type[index].value} modules", total=modules
)
else:
progress.advance(loading_task)
@ -89,10 +99,10 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
if module == modules:
# Switch to model progress if the draft model is loaded
if model_type == "draft":
model_type = "model"
else:
if index == len(model_type):
progress.stop()
else:
index += 1
finally:
progress.stop()

View file

@ -124,6 +124,9 @@ model:
# NOTE: Only works with chat completion message lists!
prompt_template:
# Enables vision support if the model supports it. (default: False)
vision: false
# Number of experts to use per token.
# Fetched from the model's config.json if empty.
# NOTE: For MoE models only.

View file

@ -107,6 +107,7 @@ class ModelLoadRequest(BaseModel):
cache_mode: Optional[str] = None
chunk_size: Optional[int] = None
prompt_template: Optional[str] = None
vision: Optional[bool] = None
num_experts_per_token: Optional[int] = None
# Non-config arguments