tabbyAPI-ollama/common/model.py
kingbri 933268f7e2 API: Integrate OpenAPI export script
Gate the OpenAPI export behind an env var checked within the main
function. This allows for easy export by simply running main.

In addition, the env variable provides global and explicit state for
disabling conditional wheel imports (e.g. Exl2 and torch), which
previously caused errors.

Signed-off-by: kingbri <bdashore3@proton.me>
2024-07-08 12:34:32 -04:00
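
For context, the export path this enables might look like the sketch below
(the main-function wiring is an assumption, not this repo's verbatim code):
setting the env var before common.model is imported skips the conditional
backend import further down in this file.

    import os
    os.environ["EXPORT_OPENAPI"] = "1"  # any non-empty value is truthy here

    from common import model  # safe even without exllamav2/torch installed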

"""
Manages the storage and utility of model containers.
Containers exist as a common interface for backends.
"""

import os
import pathlib
from typing import Optional

from loguru import logger

from common import config
from common.logger import get_loading_progress_bar
from common.utils import unwrap

# Skip the heavyweight backend import when exporting the OpenAPI spec;
# the exllamav2/torch wheels may not be installed in that environment.
if not os.getenv("EXPORT_OPENAPI"):
    from backends.exllamav2.model import ExllamaV2Container

# Global model container
# (string annotation avoids a NameError when the import above is skipped,
# since module-level annotations are evaluated at runtime)
container: Optional["ExllamaV2Container"] = None


def load_progress(module, modules):
    """Wrapper callback for load progress."""
    yield module, modules


async def unload_model(skip_wait: bool = False):
    """Unloads a model"""
    global container

    await container.unload(skip_wait=skip_wait)
    container = None


async def load_model_gen(model_path: pathlib.Path, **kwargs):
    """Generator to load a model"""
    global container

    # Check if the model is already loaded
    if container and container.model:
        loaded_model_name = container.get_model_path().name

        if loaded_model_name == model_path.name and container.model_loaded:
            raise ValueError(
                f'Model "{loaded_model_name}" is already loaded! Aborting.'
            )

    # Unload the existing model
    if container and container.model:
        logger.info("Unloading existing model.")
        await unload_model()

    container = ExllamaV2Container(model_path.resolve(), False, **kwargs)

    model_type = "draft" if container.draft_config else "model"
    load_status = container.load_gen(load_progress, **kwargs)

    progress = get_loading_progress_bar()
    progress.start()

    try:
        async for module, modules in load_status:
            if module == 0:
                loading_task = progress.add_task(
                    f"[cyan]Loading {model_type} modules", total=modules
                )
            else:
                progress.advance(loading_task)

            yield module, modules, model_type

            if module == modules:
                # Switch to model progress if the draft model is loaded
                if model_type == "draft":
                    model_type = "model"
                else:
                    progress.stop()
    finally:
        progress.stop()
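

# `load_model` below simply drains the generator for callers that don't need
# progress output; consumers that do can iterate `load_model_gen` directly.
# A sketch (`report` is a hypothetical callback, not part of this module):
#
#   async for module, modules, model_type in load_model_gen(path):
#       report(f"{model_type}: {module}/{modules}")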
async def load_model(model_path: pathlib.Path, **kwargs):
    async for _ in load_model_gen(model_path, **kwargs):
        pass


async def load_loras(lora_dir, **kwargs):
    """Wrapper to load loras."""
    if len(container.get_loras()) > 0:
        await unload_loras()

    return await container.load_loras(lora_dir, **kwargs)


async def unload_loras():
    """Wrapper to unload loras"""
    await container.unload(loras_only=True)


def get_config_default(key, fallback=None, is_draft=False):
    """Fetches a default value from model config if allowed by the user."""
    model_config = config.model_config()
    default_keys = unwrap(model_config.get("use_as_default"), [])

    if key in default_keys:
        # Is this a draft model load parameter?
        if is_draft:
            draft_config = config.draft_model_config()
            return unwrap(draft_config.get(key), fallback)
        else:
            return unwrap(model_config.get(key), fallback)
    else:
        return fallback
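

# Usage sketch (hypothetical config values): with a model config section of
#   use_as_default: ["max_seq_len"]
#   max_seq_len: 8192
# get_config_default("max_seq_len", fallback=4096) returns 8192, while any
# key not listed under "use_as_default" simply returns its fallback.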