API + Model: Apply config.yml defaults for all load paths
There are two ways to load a model: (1) via the load endpoint, and (2) inline with a completion. The defaults were not being applied on the inline load, so rewrite the loading path to fix that. While doing this, set up a defaults dictionary once rather than comparing keys at runtime, and remove the pydantic default lambda on all the model load fields. This makes the code cleaner and establishes a clear config tree for loading models. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
7baef05b49
commit
b9e5693c1b
3 changed files with 41 additions and 67 deletions
|
|
@ -13,7 +13,6 @@ from typing import Optional
|
|||
from common.logger import get_loading_progress_bar
|
||||
from common.networking import handle_request_error
|
||||
from common.tabby_config import config
|
||||
from common.utils import unwrap
|
||||
from endpoints.utils import do_export_openapi
|
||||
|
||||
if not do_export_openapi:
|
||||
|
|
@ -67,6 +66,10 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
|||
logger.info("Unloading existing model.")
|
||||
await unload_model()
|
||||
|
||||
# Merge with config defaults
|
||||
kwargs = {**config.model_defaults, **kwargs}
|
||||
|
||||
# Create a new container
|
||||
container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)
|
||||
|
||||
model_type = "draft" if container.draft_config else "model"
|
||||
|
|
@ -149,25 +152,6 @@ async def unload_embedding_model():
|
|||
embeddings_container = None
|
||||
|
||||
|
||||
# FIXME: Maybe make this a one-time function instead of a dynamic default
|
||||
def get_config_default(key: str, model_type: str = "model"):
    """Fetch a default value from the config if allowed by the user.

    A key is eligible when it is listed under ``use_as_default`` in the
    model config block, or when it is ``embeddings_device``, which is
    always treated as eligible.

    Args:
        key: Name of the load parameter to look up.
        model_type: Selects the config block to read from. ``"draft"``
            reads ``config.draft_model``, ``"embedding"`` reads
            ``config.embeddings``, and any other value reads
            ``config.model``.

    Returns:
        The value stored under ``key`` in the selected config block, or
        ``None`` when the key is not eligible or is absent from that block.
    """

    # Build a fresh list instead of appending in place: unwrap presumably
    # hands back the config's own list when "use_as_default" is set, so an
    # in-place append would grow the user's config on every call.
    default_keys = [
        *unwrap(config.model.get("use_as_default"), []),
        # Extra key that is always allowed as a default
        "embeddings_device",
    ]

    if key not in default_keys:
        return None

    # Pick the config block that matches the kind of model being loaded
    if model_type == "draft":
        return config.draft_model.get(key)
    elif model_type == "embedding":
        return config.embeddings.get(key)
    else:
        return config.model.get(key)
|
||||
|
||||
|
||||
async def check_model_container():
|
||||
"""FastAPI depends that checks if a model isn't loaded or currently loading."""
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,9 @@ from common.utils import unwrap, merge_dicts
|
|||
|
||||
|
||||
class TabbyConfig:
|
||||
"""Common config class for TabbyAPI. Loaded into sub-dictionaries from YAML file."""
|
||||
|
||||
# Sub-blocks of yaml
|
||||
network: dict = {}
|
||||
logging: dict = {}
|
||||
model: dict = {}
|
||||
|
|
@ -16,6 +19,9 @@ class TabbyConfig:
|
|||
developer: dict = {}
|
||||
embeddings: dict = {}
|
||||
|
||||
# Persistent defaults
|
||||
model_defaults: dict = {}
|
||||
|
||||
def load(self, arguments: Optional[dict] = None):
|
||||
"""Synchronously loads the global application config"""
|
||||
|
||||
|
|
@ -36,6 +42,14 @@ class TabbyConfig:
|
|||
self.developer = unwrap(merged_config.get("developer"), {})
|
||||
self.embeddings = unwrap(merged_config.get("embeddings"), {})
|
||||
|
||||
# Set model defaults dict once to prevent on-demand reconstruction
|
||||
default_keys = unwrap(self.model.get("use_as_default"), [])
|
||||
for key in default_keys:
|
||||
if key in self.model:
|
||||
self.model_defaults[key] = config.model[key]
|
||||
elif key in self.draft_model:
|
||||
self.model_defaults[key] = config.draft_model[key]
|
||||
|
||||
def _from_file(self, config_path: pathlib.Path):
|
||||
"""loads config from a given file path"""
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue