Config: Embeddings: Make embeddings_device a default when API loading
When loading from the API, the fallback for embeddings_device will be the same as the config. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
54aeebaec1
commit
3e42211c3e
3 changed files with 30 additions and 9 deletions
|
|
@ -5,6 +5,7 @@ Containers exist as a common interface for backends.
|
|||
"""
|
||||
|
||||
import pathlib
|
||||
from enum import Enum
|
||||
from fastapi import HTTPException
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
|
@ -31,6 +32,12 @@ if not do_export_openapi:
|
|||
embeddings_container: Optional[InfinityContainer] = None
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
MODEL = "model"
|
||||
DRAFT = "draft"
|
||||
EMBEDDING = "embedding"
|
||||
|
||||
|
||||
def load_progress(module, modules):
|
||||
"""Wrapper callback for load progress."""
|
||||
yield module, modules
|
||||
|
|
@ -142,16 +149,23 @@ async def unload_embedding_model():
|
|||
embeddings_container = None
|
||||
|
||||
|
||||
def get_config_default(key, fallback=None, is_draft=False):
|
||||
def get_config_default(key: str, fallback=None, model_type: str = "model"):
|
||||
"""Fetches a default value from model config if allowed by the user."""
|
||||
|
||||
model_config = config.model_config()
|
||||
default_keys = unwrap(model_config.get("use_as_default"), [])
|
||||
|
||||
# Add extra keys to defaults
|
||||
default_keys.append("embeddings_device")
|
||||
|
||||
if key in default_keys:
|
||||
# Is this a draft model load parameter?
|
||||
if is_draft:
|
||||
if model_type == "draft":
|
||||
draft_config = config.draft_model_config()
|
||||
return unwrap(draft_config.get(key), fallback)
|
||||
elif model_type == "embedding":
|
||||
embeddings_config = config.embeddings_config()
|
||||
return unwrap(embeddings_config.get(key), fallback)
|
||||
else:
|
||||
return unwrap(model_config.get(key), fallback)
|
||||
else:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue