Model: Add vision loading support

Adds the ability to load the vision tower of text + image models. An
explicit config flag is required because there is no way to automatically
determine whether the vision tower should be used.
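
For example (not part of the diff), the flag sits under the model block of
config.yml, mirroring the sample config change below:

model:
  # Load the vision tower alongside the language model
  vision: true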

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-11-11 12:04:40 -05:00
parent cc2516790d
commit 69ac0eb8aa
5 changed files with 42 additions and 5 deletions


@@ -20,6 +20,7 @@ from exllamav2 import (
     ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
+    ExLlamaV2VisionTower,
 )
 from exllamav2.generator import (
     ExLlamaV2Sampler,
@@ -28,6 +29,7 @@ from exllamav2.generator import (
 )
 from itertools import zip_longest
 from loguru import logger
+from PIL import Image
 from typing import List, Optional, Union
 from ruamel.yaml import YAML
@@ -91,6 +93,10 @@ class ExllamaV2Container:
     autosplit_reserve: List[float] = [96 * 1024**2]
     use_tp: bool = False
 
+    # Vision vars
+    use_vision: bool = False
+    vision_model: Optional[ExLlamaV2VisionTower] = None
+
     # Load state
     model_is_loading: bool = False
     model_loaded: bool = False
@@ -144,6 +150,9 @@ class ExllamaV2Container:
         # Apply a model's config overrides while respecting user settings
         kwargs = await self.set_model_overrides(**kwargs)
 
+        # Set vision state
+        self.use_vision = unwrap(kwargs.get("vision"), True)
+
         # Prepare the draft model config if necessary
         draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")
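
For reference, unwrap is the project's null-coalescing helper; a minimal
sketch of the assumed behavior, which determines how a missing "vision" key
is treated:

def unwrap(value, default=None):
    """Assumed behavior: return value unless it is None, else the fallback."""
    return value if value is not None else default

# kwargs.get("vision") is None when the key is absent, so the fallback
# applies; an explicit vision value from the user is passed through unchanged.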
@@ -608,6 +617,14 @@
             input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
             self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)
 
+        # Load vision tower if it exists
+        if self.use_vision:
+            self.vision_model = ExLlamaV2VisionTower(self.config)
+
+            for value in self.vision_model.load_gen(callback_gen=progress_callback):
+                if value:
+                    yield value
+
         self.model = ExLlamaV2(self.config)
         if not self.quiet:
             logger.info("Loading model: " + self.config.model_dir)


@@ -270,6 +270,12 @@ class ModelConfig(BaseConfigModel):
             "NOTE: Only works with chat completion message lists!"
         ),
     )
+    vision: Optional[bool] = Field(
+        False,
+        description=(
+            "Enables vision support if the model supports it. (default: False)"
+        ),
+    )
     num_experts_per_token: Optional[int] = Field(
         None,
         description=(


@@ -33,6 +33,7 @@ class ModelType(Enum):
     MODEL = "model"
     DRAFT = "draft"
     EMBEDDING = "embedding"
+    VISION = "vision"
 
 
 def load_progress(module, modules):
@@ -70,17 +71,26 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     # Create a new container
     container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
 
-    model_type = "draft" if container.draft_config else "model"
+    # Add possible types of models that can be loaded
+    model_type = [ModelType.MODEL]
+
+    if container.use_vision:
+        model_type.insert(0, ModelType.VISION)
+
+    if container.draft_config:
+        model_type.insert(0, ModelType.DRAFT)
+
     load_status = container.load_gen(load_progress, **kwargs)
 
     progress = get_loading_progress_bar()
     progress.start()
 
     try:
+        index = 0
         async for module, modules in load_status:
             if module == 0:
                 loading_task = progress.add_task(
-                    f"[cyan]Loading {model_type} modules", total=modules
+                    f"[cyan]Loading {model_type[index].value} modules", total=modules
                 )
             else:
                 progress.advance(loading_task)
@@ -89,10 +99,10 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
             if module == modules:
                 # Switch to model progress if the draft model is loaded
-                if model_type == "draft":
-                    model_type = "model"
-                else:
+                if index == len(model_type):
                     progress.stop()
+                else:
+                    index += 1
 
     finally:
         progress.stop()
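
The insert(0, ...) calls above keep the progress labels in the same order the
components are loaded; a small sketch of the resulting list when both a draft
model and vision are enabled (simplified, not the actual loader code):

model_type = [ModelType.MODEL]           # language model is always loaded
model_type.insert(0, ModelType.VISION)   # -> [VISION, MODEL]
model_type.insert(0, ModelType.DRAFT)    # -> [DRAFT, VISION, MODEL]

# The loop then labels each component in turn, advancing `index` whenever a
# component reports module == modules.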


@@ -124,6 +124,9 @@ model:
   # NOTE: Only works with chat completion message lists!
   prompt_template:
 
+  # Enables vision support if the model supports it. (default: False)
+  vision: false
+
   # Number of experts to use per token.
   # Fetched from the model's config.json if empty.
   # NOTE: For MoE models only.


@@ -107,6 +107,7 @@ class ModelLoadRequest(BaseModel):
     cache_mode: Optional[str] = None
     chunk_size: Optional[int] = None
     prompt_template: Optional[str] = None
+    vision: Optional[bool] = None
     num_experts_per_token: Optional[int] = None
 
     # Non-config arguments
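
As an illustration, a load request enabling the new field might carry a
payload like the sketch below; only "vision" is new in this commit, the other
keys come from the existing ModelLoadRequest fields above, and the values are
placeholders:

payload = {
    "vision": True,        # load the vision tower alongside the language model
    "cache_mode": "FP16",  # placeholder value
    "chunk_size": 2048,    # placeholder value
}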