Model: Add vision loading support
Adds the ability to load the vision parts of text + image models. Requires an explicit flag in the config because there isn't a way to automatically determine whether the vision tower should be used.

Signed-off-by: kingbri <bdashore3@proton.me>
parent cc2516790d
commit 69ac0eb8aa

5 changed files with 42 additions and 5 deletions
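Because vision support cannot be detected from the model files, the caller has to opt in explicitly. A minimal sketch of what that opt-in looks like at the container boundary (the call site below is illustrative, not part of this commit; the names follow the diff that follows):

    # Hypothetical call: `vision` must be passed explicitly, e.g. sourced from
    # config.yml or a model load request; it is never inferred automatically.
    container = await ExllamaV2Container.create(model_path.resolve(), False, vision=True)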
@@ -20,6 +20,7 @@ from exllamav2 import (
     ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
+    ExLlamaV2VisionTower,
 )
 from exllamav2.generator import (
     ExLlamaV2Sampler,
@@ -28,6 +29,7 @@ from exllamav2.generator import (
 )
 from itertools import zip_longest
 from loguru import logger
+from PIL import Image
 from typing import List, Optional, Union

 from ruamel.yaml import YAML
@@ -91,6 +93,10 @@ class ExllamaV2Container:
     autosplit_reserve: List[float] = [96 * 1024**2]
     use_tp: bool = False

+    # Vision vars
+    use_vision: bool = False
+    vision_model: Optional[ExLlamaV2VisionTower] = None
+
     # Load state
     model_is_loading: bool = False
     model_loaded: bool = False
@@ -144,6 +150,9 @@ class ExllamaV2Container:
         # Apply a model's config overrides while respecting user settings
         kwargs = await self.set_model_overrides(**kwargs)

+        # Set vision state
+        self.use_vision = unwrap(kwargs.get("vision"), True)
+
         # Prepare the draft model config if necessary
         draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")
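For reference, `unwrap` here is assumed to be the usual "value or default" helper, so the flag only falls back to its default when the caller did not pass `vision` at all:

    def unwrap(value, default=None):
        # Assumed behavior: keep the caller's value unless it is None.
        return value if value is not None else default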
@@ -608,6 +617,14 @@ class ExllamaV2Container:
             input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
             self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)

+        # Load vision tower if it exists
+        if self.use_vision:
+            self.vision_model = ExLlamaV2VisionTower(self.config)
+
+            for value in self.vision_model.load_gen(callback_gen=progress_callback):
+                if value:
+                    yield value
+
         self.model = ExLlamaV2(self.config)
         if not self.quiet:
             logger.info("Loading model: " + self.config.model_dir)
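A minimal sketch of how a caller consumes this loading generator, assuming the progress callback returns a (module, modules) tuple, which is how load_model_gen in the loader diff further down relies on it:

    async def watch_load(container):
        # The draft model's modules (if any) are yielded first, then the vision
        # tower when use_vision is set, then the main model.
        async for module, modules in container.load_gen(lambda m, ms: (m, ms)):
            print(f"Loaded module {module}/{modules}")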
@@ -270,6 +270,12 @@ class ModelConfig(BaseConfigModel):
             "NOTE: Only works with chat completion message lists!"
         ),
     )
+    vision: Optional[bool] = Field(
+        False,
+        description=(
+            "Enables vision support if the model supports it. (default: False)"
+        ),
+    )
     num_experts_per_token: Optional[int] = Field(
         None,
         description=(
@@ -33,6 +33,7 @@ class ModelType(Enum):
     MODEL = "model"
     DRAFT = "draft"
     EMBEDDING = "embedding"
+    VISION = "vision"


 def load_progress(module, modules):
@@ -70,17 +71,26 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     # Create a new container
     container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)

-    model_type = "draft" if container.draft_config else "model"
+    # Add possible types of models that can be loaded
+    model_type = [ModelType.MODEL]
+
+    if container.use_vision:
+        model_type.insert(0, ModelType.VISION)
+
+    if container.draft_config:
+        model_type.insert(0, ModelType.DRAFT)
+
     load_status = container.load_gen(load_progress, **kwargs)

     progress = get_loading_progress_bar()
     progress.start()

     try:
+        index = 0
         async for module, modules in load_status:
             if module == 0:
                 loading_task = progress.add_task(
-                    f"[cyan]Loading {model_type} modules", total=modules
+                    f"[cyan]Loading {model_type[index].value} modules", total=modules
                 )
             else:
                 progress.advance(loading_task)
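To make the ordering concrete: the list is built back-to-front so it mirrors the order the container actually loads things in (draft model, then vision tower, then the main model), and index walks through it as each part finishes:

    model_type = [ModelType.MODEL]
    model_type.insert(0, ModelType.VISION)  # vision enabled -> [VISION, MODEL]
    model_type.insert(0, ModelType.DRAFT)   # draft present  -> [DRAFT, VISION, MODEL]
    # The progress bar then reads "Loading draft modules", "Loading vision modules",
    # and "Loading model modules" as index advances from 0 to 2.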
@@ -89,10 +99,10 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):

             if module == modules:
-                # Switch to model progress if the draft model is loaded
-                if model_type == "draft":
-                    model_type = "model"
-                else:
+                if index == len(model_type):
                     progress.stop()
+                else:
+                    index += 1
     finally:
         progress.stop()
@@ -124,6 +124,9 @@ model:
   # NOTE: Only works with chat completion message lists!
   prompt_template:

+  # Enables vision support if the model supports it. (default: False)
+  vision: false
+
   # Number of experts to use per token.
   # Fetched from the model's config.json if empty.
   # NOTE: For MoE models only.
@@ -107,6 +107,7 @@ class ModelLoadRequest(BaseModel):
     cache_mode: Optional[str] = None
     chunk_size: Optional[int] = None
     prompt_template: Optional[str] = None
+    vision: Optional[bool] = None
     num_experts_per_token: Optional[int] = None

     # Non-config arguments
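As a usage sketch, a client could opt into vision when loading a model over the API. The vision field mirrors ModelLoadRequest above; the endpoint path, port, model_name field, and admin-key header are assumptions about the surrounding server, not part of this diff:

    import httpx

    # Hypothetical request against a locally running instance
    response = httpx.post(
        "http://127.0.0.1:5000/v1/model/load",
        headers={"x-admin-key": "example-admin-key"},
        json={"model_name": "my-vision-model", "vision": True},
    )
    print(response.status_code)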