Model: Move inline overrides to common
This is applied across containers. Doesn't make sense to put this method in the backend. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
034682fcf1
commit
b751e0a1d5
2 changed files with 42 additions and 33 deletions
|
|
@ -4,10 +4,12 @@ Manages the storage and utility of model containers.
|
|||
Containers exist as a common interface for backends.
|
||||
"""
|
||||
|
||||
import aiofiles
|
||||
import pathlib
|
||||
from enum import Enum
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from ruamel.yaml import YAML
|
||||
from typing import Optional
|
||||
|
||||
from backends.base_model_container import BaseModelContainer
|
||||
|
|
@ -15,6 +17,7 @@ from common.logger import get_loading_progress_bar
|
|||
from common.networking import handle_request_error
|
||||
from common.tabby_config import config
|
||||
from common.optional_dependencies import dependencies
|
||||
from common.utils import unwrap
|
||||
|
||||
# Global variables for model container
|
||||
container: Optional[BaseModelContainer] = None
|
||||
|
|
@ -43,6 +46,37 @@ def load_progress(module, modules):
|
|||
yield module, modules
|
||||
|
||||
|
||||
# TODO: Change this to be inline with config.yml
|
||||
async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
    """
    Sets overrides from a model folder's config yaml.

    Reads ``tabby_config.yml`` from ``model_dir`` and merges its contents
    underneath ``kwargs``: a key already present in ``kwargs`` wins over the
    same key from the file. The ``draft_model`` sub-mapping is merged one
    level deep with the same precedence.

    Args:
        model_dir: Directory that may contain ``tabby_config.yml``.
        **kwargs: Load arguments that take priority over the file's values.

    Returns:
        The merged arguments, or ``kwargs`` unchanged when the file does not
        exist or its top level is not a mapping.
    """

    override_config_path = model_dir / "tabby_config.yml"

    # Nothing to apply when the model folder ships no override file
    if not override_config_path.exists():
        return kwargs

    # Only the read needs the file handle; close it before parsing
    async with aiofiles.open(
        override_config_path, "r", encoding="utf8"
    ) as override_config_file:
        contents = await override_config_file.read()

    # Create a temporary YAML parser
    yaml = YAML(typ="safe")

    # NOTE(review): an empty file parses to None; unwrap is expected to
    # substitute the {} default in that case - confirm against common.utils
    override_args = unwrap(yaml.load(contents), {})

    # A YAML document may have a list or scalar at its top level, which
    # cannot be merged into kwargs. Skip it instead of crashing on .get()
    if not isinstance(override_args, dict):
        logger.warning(
            f"Ignoring inline overrides in {override_config_path} because "
            "the file's top level is not a mapping."
        )
        return kwargs

    # Merge draft overrides beforehand
    draft_override_args = unwrap(override_args.get("draft_model"), {})
    if draft_override_args:
        kwargs["draft_model"] = {
            **draft_override_args,
            **unwrap(kwargs.get("draft_model"), {}),
        }

    # Merge the override and model kwargs (kwargs take precedence)
    return {**override_args, **kwargs}
|
||||
|
||||
|
||||
async def unload_model(skip_wait: bool = False, shutdown: bool = False):
|
||||
"""Unloads a model"""
|
||||
global container
|
||||
|
|
@ -70,8 +104,15 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
|||
# Reset to prepare for a new container
|
||||
container = None
|
||||
|
||||
# Merge with config defaults
|
||||
# Model_dir is already provided
|
||||
# TODO: Isolate the root cause
|
||||
kwargs.pop("model_dir")
|
||||
|
||||
# Merge with config and inline defaults
|
||||
kwargs = {**config.model_defaults, **kwargs}
|
||||
kwargs = await apply_inline_overrides(model_path, **kwargs)
|
||||
|
||||
print(kwargs)
|
||||
|
||||
# Create a new container
|
||||
new_container = await ExllamaV2Container.create(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue