API + Model: Apply config.yml defaults for all load paths

There are two ways to load a model:
1. Via the load endpoint
2. Inline with a completion

The defaults were not applying on the inline load, so rewrite to fix
that. However, while doing this, set up a defaults dictionary rather
than comparing it at runtime and remove the pydantic default lambda
on all the model load fields.

This makes the code cleaner and establishes a clear config tree for
loading models.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-09-10 23:35:35 -04:00
parent 7baef05b49
commit b9e5693c1b
3 changed files with 41 additions and 67 deletions

View file

@ -13,7 +13,6 @@ from typing import Optional
from common.logger import get_loading_progress_bar
from common.networking import handle_request_error
from common.tabby_config import config
from common.utils import unwrap
from endpoints.utils import do_export_openapi
if not do_export_openapi:
@ -67,6 +66,10 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
logger.info("Unloading existing model.")
await unload_model()
# Merge with config defaults
kwargs = {**config.model_defaults, **kwargs}
# Create a new container
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
model_type = "draft" if container.draft_config else "model"
@ -149,25 +152,6 @@ async def unload_embedding_model():
embeddings_container = None
# FIXME: Maybe make this a one-time function instead of a dynamic default
def get_config_default(key: str, model_type: str = "model"):
"""Fetches a default value from model config if allowed by the user."""
default_keys = unwrap(config.model.get("use_as_default"), [])
# Add extra keys to defaults
default_keys.append("embeddings_device")
if key in default_keys:
# Is this a draft model load parameter?
if model_type == "draft":
return config.draft_model.get(key)
elif model_type == "embedding":
return config.embeddings.get(key)
else:
return config.model.get(key)
async def check_model_container():
"""FastAPI depends that checks if a model isn't loaded or currently loading."""