Model: Move inline overrides to common

Inline overrides are applied across containers, so it doesn't make sense
to keep this method in a backend.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-04-20 17:51:57 -04:00
parent 034682fcf1
commit b751e0a1d5
2 changed files with 42 additions and 33 deletions

View file

@ -4,10 +4,12 @@ Manages the storage and utility of model containers.
Containers exist as a common interface for backends.
"""
import aiofiles
import pathlib
from enum import Enum
from fastapi import HTTPException
from loguru import logger
from ruamel.yaml import YAML
from typing import Optional
from backends.base_model_container import BaseModelContainer
@ -15,6 +17,7 @@ from common.logger import get_loading_progress_bar
from common.networking import handle_request_error
from common.tabby_config import config
from common.optional_dependencies import dependencies
from common.utils import unwrap
# Global variables for model container
container: Optional[BaseModelContainer] = None
@ -43,6 +46,37 @@ def load_progress(module, modules):
yield module, modules
# TODO: Change this to be in line with config.yml
async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
    """Sets overrides from a model folder's config yaml.

    Looks for ``tabby_config.yml`` directly inside ``model_dir`` and merges
    its top-level keys underneath ``kwargs``: any key already present in
    ``kwargs`` wins over the value from the file. The ``draft_model``
    section is merged one level deep with the same precedence, so individual
    draft options from the file survive when ``kwargs`` only sets some of
    them.

    Args:
        model_dir: Directory that may contain ``tabby_config.yml``.
        **kwargs: Load options that take priority over the file's values.

    Returns:
        A dict of the merged options. ``kwargs`` is returned as-is when the
        file does not exist or its top level is not a mapping.
    """

    override_config_path = model_dir / "tabby_config.yml"

    if not override_config_path.exists():
        return kwargs

    # Only hold the file open for the read; parsing and merging don't need it
    async with aiofiles.open(
        override_config_path, "r", encoding="utf8"
    ) as override_config_file:
        contents = await override_config_file.read()

    # Create a temporary YAML parser
    yaml = YAML(typ="safe")

    # An empty document parses to None; unwrap is presumably what turns that
    # into the {} fallback here (helper lives in common.utils)
    override_args = unwrap(yaml.load(contents), {})

    # A YAML document can also be a bare scalar or a list. Skip the file
    # instead of failing on .get() below.
    if not isinstance(override_args, dict):
        logger.warning(
            f"Ignoring inline overrides in {override_config_path}: "
            "the top level of the file must be a mapping."
        )
        return kwargs

    # Merge draft overrides beforehand
    draft_override_args = unwrap(override_args.get("draft_model"), {})
    if not isinstance(draft_override_args, dict):
        logger.warning(
            f"Ignoring draft_model overrides in {override_config_path}: "
            "the draft_model section must be a mapping."
        )

        # Drop the malformed section so it can't leak into the merged result
        override_args = {
            key: value
            for key, value in override_args.items()
            if key != "draft_model"
        }
        draft_override_args = {}

    if draft_override_args:
        kwargs["draft_model"] = {
            **draft_override_args,
            **unwrap(kwargs.get("draft_model"), {}),
        }

    # Merge the override and model kwargs (kwargs take priority)
    merged_kwargs = {**override_args, **kwargs}
    return merged_kwargs
async def unload_model(skip_wait: bool = False, shutdown: bool = False):
"""Unloads a model"""
global container
@ -70,8 +104,15 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
# Reset to prepare for a new container
container = None
# Merge with config defaults
# Model_dir is already provided
# TODO: Isolate the root cause
kwargs.pop("model_dir")
# Merge with config and inline defaults
kwargs = {**config.model_defaults, **kwargs}
kwargs = await apply_inline_overrides(model_path, **kwargs)
print(kwargs)
# Create a new container
new_container = await ExllamaV2Container.create(