Previously, the parameters under the "model" block in config.yml only handled the loading of a model on startup. This meant that any subsequent API request required each parameter to be filled out or use a sane default (usually defaults to the model's config.json). However, there are cases where admins may want an argument from the config to apply if the parameter isn't provided in the request body. To help alleviate this, add a mechanism that works like sampler overrides where users can specify a flag that acts as a fallback. Therefore, this change both preserves the source of truth of what parameters the admin is loading and adds some convenience for users that want customizable defaults for their requests. This behavior may change in the future, but I think it solves the issue for now. Signed-off-by: kingbri <bdashore3@proton.me>
111 lines
3.1 KiB
Python
111 lines
3.1 KiB
Python
"""
|
|
Manages the storage and utility of model containers.
|
|
|
|
Containers exist as a common interface for backends.
|
|
"""
|
|
|
|
import pathlib
|
|
from loguru import logger
|
|
from typing import Optional
|
|
|
|
from backends.exllamav2.model import ExllamaV2Container
|
|
from common import config
|
|
from common.logger import get_loading_progress_bar
|
|
from common.utils import unwrap
|
|
|
|
# Global model container
# None while no model is loaded: load_model_gen assigns a new container
# and unload_model resets it back to None.
container: Optional[ExllamaV2Container] = None
|
|
|
|
|
|
def load_progress(module, modules):
    """
    Pass-through progress callback.

    Yields exactly one (module, modules) tuple built from its arguments.
    """

    yield (module, modules)
|
|
|
|
|
|
async def unload_model(skip_wait: bool = False):
    """
    Unloads the current model and clears the global container.

    Does nothing when no container is set, so calling this without a
    loaded model is a no-op instead of an AttributeError on None.

    Args:
        skip_wait: Forwarded unchanged to the container's unload().
    """

    global container

    # Nothing to unload; keep the call idempotent
    if container is None:
        return

    await container.unload(skip_wait=skip_wait)
    container = None
|
|
|
|
|
|
async def load_model_gen(model_path: pathlib.Path, **kwargs):
    """
    Async generator that loads a model and reports its progress.

    Raises a ValueError if the container already holds a loaded model
    whose directory name matches model_path. Otherwise, any model that is
    present in the container gets unloaded before the new one is created.

    Yields (module, modules, model_type) tuples, where model_type is
    either "draft" or "model".
    """

    global container

    if container and container.model:
        current_name = container.get_model_path().name

        # Refuse to load the same model on top of itself
        if current_name == model_path.name and container.model_loaded:
            raise ValueError(
                f'Model "{current_name}" is already loaded! Aborting.'
            )

        # Otherwise, clear out whatever the container currently holds
        logger.info("Unloading existing model.")
        await unload_model()

    container = ExllamaV2Container(model_path.resolve(), False, **kwargs)

    # Progress starts out labelled as "draft" when a draft config exists
    if container.draft_config:
        model_type = "draft"
    else:
        model_type = "model"

    status_stream = container.load_gen(load_progress, **kwargs)

    progress_bar = get_loading_progress_bar()
    progress_bar.start()

    try:
        async for module, modules in status_stream:
            # Module 0 opens a new task, every later module advances it
            if module == 0:
                task_id = progress_bar.add_task(
                    f"[cyan]Loading {model_type} modules", total=modules
                )
            else:
                progress_bar.advance(task_id)

            yield module, modules, model_type

            # Only the final module of a pass needs extra handling
            if module != modules:
                continue

            if model_type == "draft":
                # Draft pass is done, the next pass reports the main model
                model_type = "model"
            else:
                progress_bar.stop()
    finally:
        # Always stop the bar, even if loading fails or is cancelled
        progress_bar.stop()
|
|
|
|
|
|
async def load_model(model_path: pathlib.Path, **kwargs):
    """Loads a model by running load_model_gen until it is exhausted."""

    progress_stream = load_model_gen(model_path, **kwargs)

    # Progress updates are discarded, only the side effect of loading matters
    async for _update in progress_stream:
        continue
|
|
|
|
|
|
async def load_loras(lora_dir, **kwargs):
    """
    Wrapper to load loras.

    Loras that the container already reports are unloaded first. Returns
    whatever the container's load_loras() call returns.
    """

    active_loras = container.get_loras()
    if len(active_loras) > 0:
        await unload_loras()

    load_result = await container.load_loras(lora_dir, **kwargs)
    return load_result
|
|
|
|
|
|
async def unload_loras():
    """Wrapper that asks the container to unload with loras_only=True."""

    lora_unload = container.unload(loras_only=True)
    await lora_unload
|
|
|
|
|
|
def get_config_default(key, fallback=None, is_draft=False):
    """
    Looks up a default value for key in the model config.

    The config is only consulted when key is listed under the model
    config's "use_as_default" entry; otherwise fallback is returned
    directly. With is_draft set, the value is read from the draft model
    config instead of the main model config.
    """

    model_config = config.model_config()
    allowed_keys = unwrap(model_config.get("use_as_default"), [])

    # Keys the user did not opt in never read from the config
    if key not in allowed_keys:
        return fallback

    source = config.draft_model_config() if is_draft else model_config
    return unwrap(source.get(key), fallback)
|