Model: Add vision loading support
Adds the ability to load vision parts of text + image models. Requires an explicit flag in config because there isn't a way to automatically determine whether the vision tower should be used.

Signed-off-by: kingbri <bdashore3@proton.me>
parent cc2516790d
commit 69ac0eb8aa
5 changed files with 42 additions and 5 deletions
@@ -20,6 +20,7 @@ from exllamav2 import (
     ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
+    ExLlamaV2VisionTower,
 )
 from exllamav2.generator import (
     ExLlamaV2Sampler,

@@ -28,6 +29,7 @@ from exllamav2.generator import (
 )
 from itertools import zip_longest
 from loguru import logger
+from PIL import Image
 from typing import List, Optional, Union

 from ruamel.yaml import YAML

@@ -91,6 +93,10 @@ class ExllamaV2Container:
     autosplit_reserve: List[float] = [96 * 1024**2]
     use_tp: bool = False

+    # Vision vars
+    use_vision: bool = False
+    vision_model: Optional[ExLlamaV2VisionTower] = None
+
     # Load state
     model_is_loading: bool = False
     model_loaded: bool = False

@@ -144,6 +150,9 @@ class ExllamaV2Container:
         # Apply a model's config overrides while respecting user settings
         kwargs = await self.set_model_overrides(**kwargs)

+        # Set vision state
+        self.use_vision = unwrap(kwargs.get("vision"), True)
+
         # Prepare the draft model config if necessary
         draft_args = unwrap(kwargs.get("draft_model"), {})
         draft_model_name = draft_args.get("draft_model_name")

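Aside: `unwrap` is a tabbyAPI utility used throughout this hunk. A minimal sketch of its assumed behavior (fall back to the default only when the value is missing), not the project's exact implementation:

from typing import Optional, TypeVar

T = TypeVar("T")

def unwrap(value: Optional[T], default: T) -> T:
    # Use the default only when the value was never set (None).
    return default if value is None else value

# kwargs.get("vision") is None when the flag was not supplied,
# so the container only falls back in that case.
print(unwrap(None, True))    # True
print(unwrap(False, True))   # False
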
@@ -608,6 +617,14 @@ class ExllamaV2Container:
             input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
             self.draft_model.forward(input_ids, cache=self.cache, preprocess_only=True)

+        # Load vision tower if it exists
+        if self.use_vision:
+            self.vision_model = ExLlamaV2VisionTower(self.config)
+
+            for value in self.vision_model.load_gen(callback_gen=progress_callback):
+                if value:
+                    yield value
+
         self.model = ExLlamaV2(self.config)
         if not self.quiet:
             logger.info("Loading model: " + self.config.model_dir)

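Once loaded, the tower's purpose is to turn PIL images into embeddings that the generator can consume next to the prompt. A rough usage sketch; the helper name and keyword arguments are assumptions about exllamav2's vision API and may differ between versions:

from PIL import Image

# Assumes `container` is a loaded ExllamaV2Container with use_vision enabled.
image = Image.open("example.png")  # placeholder path

# Assumed exllamav2 API: convert the image into a multimodal embedding.
embedding = container.vision_model.get_image_embeddings(
    model=container.model,
    tokenizer=container.tokenizer,
    image=image,
)
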
@@ -270,6 +270,12 @@ class ModelConfig(BaseConfigModel):
             "NOTE: Only works with chat completion message lists!"
         ),
     )
+    vision: Optional[bool] = Field(
+        False,
+        description=(
+            "Enables vision support if the model supports it. (default: False)"
+        ),
+    )
     num_experts_per_token: Optional[int] = Field(
         None,
         description=(

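As a quick check of the new field's semantics, a stripped-down sketch (not the project's actual BaseConfigModel): the flag defaults to False when omitted and can be overridden by user config.

from typing import Optional
from pydantic import BaseModel, Field

class MiniModelConfig(BaseModel):
    # Mirrors the new field: False when omitted, overridable.
    vision: Optional[bool] = Field(False, description="Enables vision support")

print(MiniModelConfig().vision)             # False
print(MiniModelConfig(vision=True).vision)  # True
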
@@ -33,6 +33,7 @@ class ModelType(Enum):
     MODEL = "model"
     DRAFT = "draft"
     EMBEDDING = "embedding"
+    VISION = "vision"


 def load_progress(module, modules):

@@ -70,17 +71,26 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
     # Create a new container
     container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)

-    model_type = "draft" if container.draft_config else "model"
+    # Add possible types of models that can be loaded
+    model_type = [ModelType.MODEL]
+
+    if container.use_vision:
+        model_type.insert(0, ModelType.VISION)
+
+    if container.draft_config:
+        model_type.insert(0, ModelType.DRAFT)
+
     load_status = container.load_gen(load_progress, **kwargs)

     progress = get_loading_progress_bar()
     progress.start()

     try:
+        index = 0
         async for module, modules in load_status:
             if module == 0:
                 loading_task = progress.add_task(
-                    f"[cyan]Loading {model_type} modules", total=modules
+                    f"[cyan]Loading {model_type[index].value} modules", total=modules
                 )
             else:
                 progress.advance(loading_task)

@@ -89,10 +99,10 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):

             if module == modules:
                 # Switch to model progress if the draft model is loaded
-                if model_type == "draft":
-                    model_type = "model"
-                else:
+                if index == len(model_type):
                     progress.stop()
+                else:
+                    index += 1
     finally:
         progress.stop()

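The front-inserts above keep the label list in load order (draft model first, then the vision tower, then the main model), and `index` walks that list as each module group finishes. A standalone illustration of the resulting label sequence, with both optional stages enabled:

from enum import Enum

class ModelType(Enum):
    MODEL = "model"
    DRAFT = "draft"
    VISION = "vision"

model_type = [ModelType.MODEL]
use_vision = True
has_draft_config = True

if use_vision:
    model_type.insert(0, ModelType.VISION)
if has_draft_config:
    model_type.insert(0, ModelType.DRAFT)

# Labels shown by the progress bar, in the order the stages load.
print([t.value for t in model_type])  # ['draft', 'vision', 'model']
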
@@ -124,6 +124,9 @@ model:
   # NOTE: Only works with chat completion message lists!
   prompt_template:

+  # Enables vision support if the model supports it. (default: False)
+  vision: false
+
   # Number of experts to use per token.
   # Fetched from the model's config.json if empty.
   # NOTE: For MoE models only.

@@ -107,6 +107,7 @@ class ModelLoadRequest(BaseModel):
     cache_mode: Optional[str] = None
     chunk_size: Optional[int] = None
     prompt_template: Optional[str] = None
+    vision: Optional[bool] = None
     num_experts_per_token: Optional[int] = None

     # Non-config arguments

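With the field exposed on ModelLoadRequest, a client can opt into vision when loading a model over the API. A hedged sketch of such a request; the endpoint path, auth header, port, and field names other than `vision` are assumptions about a typical tabbyAPI deployment:

import requests

# Hypothetical admin request; adjust host, key, and model name for your setup.
response = requests.post(
    "http://127.0.0.1:5000/v1/model/load",
    headers={"x-admin-key": "<admin key>"},
    json={
        "name": "my-vision-model",  # placeholder model folder
        "vision": True,             # the flag added in this commit
    },
)
print(response.status_code)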