API: Move to ModelManager
This is a shared module which manages the model container and provides extra utility functions around it to help slim down the API. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
8b46282aef
commit
b373b25235
5 changed files with 178 additions and 143 deletions
75
common/model.py
Normal file
75
common/model.py
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
"""
|
||||
Manages the storage and utility of model containers.
|
||||
|
||||
Containers exist as a common interface for backends.
|
||||
"""
|
||||
|
||||
import pathlib
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
from backends.exllamav2.model import ExllamaV2Container
|
||||
from common.logger import get_loading_progress_bar
|
||||
from common.utils import load_progress
|
||||
|
||||
|
||||
container: Optional[ExllamaV2Container] = None
|
||||
|
||||
|
||||
async def unload_model():
|
||||
"""Unloads a model"""
|
||||
global container
|
||||
|
||||
container.unload()
|
||||
container = None
|
||||
|
||||
|
||||
async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
||||
"""Generator to load a model"""
|
||||
global container
|
||||
|
||||
# Check if the model is already loaded
|
||||
if container and container.model:
|
||||
loaded_model_name = container.get_model_path().name
|
||||
|
||||
if loaded_model_name == model_path.name:
|
||||
raise ValueError(
|
||||
f'Model "{loaded_model_name}" is already loaded! Aborting.'
|
||||
)
|
||||
|
||||
# Unload the existing model
|
||||
if container and container.model:
|
||||
logger.info("Unloading existing model.")
|
||||
await unload_model()
|
||||
|
||||
container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
|
||||
|
||||
model_type = "draft" if container.draft_config else "model"
|
||||
load_status = container.load_gen(load_progress)
|
||||
|
||||
progress = get_loading_progress_bar()
|
||||
progress.start()
|
||||
|
||||
try:
|
||||
for module, modules in load_status:
|
||||
if module == 0:
|
||||
loading_task = progress.add_task(
|
||||
f"[cyan]Loading {model_type} modules", total=modules
|
||||
)
|
||||
else:
|
||||
progress.advance(loading_task)
|
||||
if module == modules:
|
||||
# Switch to model progress if the draft model is loaded
|
||||
if model_type == "draft":
|
||||
model_type = "model"
|
||||
else:
|
||||
progress.stop()
|
||||
|
||||
yield module, modules, model_type
|
||||
finally:
|
||||
progress.stop()
|
||||
|
||||
|
||||
async def load_model(model_path: pathlib.Path, **kwargs):
|
||||
async for _, _, _ in load_model_gen(model_path, **kwargs):
|
||||
pass
|
||||
Loading…
Add table
Add a link
Reference in a new issue