Model: Add vision loading support
Adds the ability to load the vision components of text + image models. This requires an explicit flag in the config because there is no way to automatically determine whether the vision tower should be used. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
cc2516790d
commit
69ac0eb8aa
5 changed files with 42 additions and 5 deletions
|
|
@ -270,6 +270,12 @@ class ModelConfig(BaseConfigModel):
|
|||
"NOTE: Only works with chat completion message lists!"
|
||||
),
|
||||
)
|
||||
vision: Optional[bool] = Field(
|
||||
False,
|
||||
description=(
|
||||
"Enables vision support if the model supports it. (default: False)"
|
||||
),
|
||||
)
|
||||
num_experts_per_token: Optional[int] = Field(
|
||||
None,
|
||||
description=(
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ class ModelType(Enum):
|
|||
MODEL = "model"
|
||||
DRAFT = "draft"
|
||||
EMBEDDING = "embedding"
|
||||
VISION = "vision"
|
||||
|
||||
|
||||
def load_progress(module, modules):
|
||||
|
|
@ -70,17 +71,26 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
|||
# Create a new container
|
||||
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
|
||||
|
||||
model_type = "draft" if container.draft_config else "model"
|
||||
# Add possible types of models that can be loaded
|
||||
model_type = [ModelType.MODEL]
|
||||
|
||||
if container.use_vision:
|
||||
model_type.insert(0, ModelType.VISION)
|
||||
|
||||
if container.draft_config:
|
||||
model_type.insert(0, ModelType.DRAFT)
|
||||
|
||||
load_status = container.load_gen(load_progress, **kwargs)
|
||||
|
||||
progress = get_loading_progress_bar()
|
||||
progress.start()
|
||||
|
||||
try:
|
||||
index = 0
|
||||
async for module, modules in load_status:
|
||||
if module == 0:
|
||||
loading_task = progress.add_task(
|
||||
f"[cyan]Loading {model_type} modules", total=modules
|
||||
f"[cyan]Loading {model_type[index].value} modules", total=modules
|
||||
)
|
||||
else:
|
||||
progress.advance(loading_task)
|
||||
|
|
@ -89,10 +99,10 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
|||
|
||||
if module == modules:
|
||||
# Switch to model progress if the draft model is loaded
|
||||
if model_type == "draft":
|
||||
model_type = "model"
|
||||
else:
|
||||
if index == len(model_type):
|
||||
progress.stop()
|
||||
else:
|
||||
index += 1
|
||||
finally:
|
||||
progress.stop()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue