From 3e42211c3e646063fa76f73abb0a128ef36af980 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 1 Aug 2024 13:59:49 -0400 Subject: [PATCH] Config: Embeddings: Make embeddings_device a default when API loading When loading from the API, the fallback for embeddings_device will be the same as the config. Signed-off-by: kingbri --- common/model.py | 18 ++++++++++++++++-- config_sample.yml | 9 ++++++--- endpoints/core/types/model.py | 12 ++++++++---- 3 files changed, 30 insertions(+), 9 deletions(-) diff --git a/common/model.py b/common/model.py index 80858d4..feedc9f 100644 --- a/common/model.py +++ b/common/model.py @@ -5,6 +5,7 @@ Containers exist as a common interface for backends. """ import pathlib +from enum import Enum from fastapi import HTTPException from loguru import logger from typing import Optional @@ -31,6 +32,12 @@ if not do_export_openapi: embeddings_container: Optional[InfinityContainer] = None +class ModelType(Enum): + MODEL = "model" + DRAFT = "draft" + EMBEDDING = "embedding" + + def load_progress(module, modules): """Wrapper callback for load progress.""" yield module, modules @@ -142,16 +149,23 @@ async def unload_embedding_model(): embeddings_container = None -def get_config_default(key, fallback=None, is_draft=False): +def get_config_default(key: str, fallback=None, model_type: str = "model"): """Fetches a default value from model config if allowed by the user.""" model_config = config.model_config() default_keys = unwrap(model_config.get("use_as_default"), []) + + # Add extra keys to defaults + default_keys.append("embeddings_device") + if key in default_keys: # Is this a draft model load parameter? 
- if is_draft: + if model_type == "draft": draft_config = config.draft_model_config() return unwrap(draft_config.get(key), fallback) + elif model_type == "embedding": + embeddings_config = config.embeddings_config() + return unwrap(embeddings_config.get(key), fallback) else: return unwrap(model_config.get(key), fallback) else: diff --git a/config_sample.yml b/config_sample.yml index f3a1c51..018ff61 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -209,11 +209,14 @@ embeddings: # Overrides directory to look for embedding models (default: models) embedding_model_dir: models - # An initial embedding model to load on the infinity backend (default: None) - embedding_model_name: # Device to load embedding models on (default: cpu) # Possible values: cpu, auto, cuda # NOTE: It's recommended to load embedding models on the CPU. # If you'd like to load on an AMD gpu, set this value to "cuda" as well. embeddings_device: cpu + + # The parameters below only apply to initial loads + # All API-based loads do NOT inherit these settings unless specified in use_as_default + + # An initial embedding model to load on the infinity backend (default: None) + embedding_model_name: diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 8b3d83e..1e2eb46 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -53,19 +53,19 @@ class DraftModelLoadRequest(BaseModel): # Config arguments draft_rope_scale: Optional[float] = Field( default_factory=lambda: get_config_default( - "draft_rope_scale", 1.0, is_draft=True + "draft_rope_scale", 1.0, model_type="draft" ) ) draft_rope_alpha: Optional[float] = Field( description="Automatically calculated if not present", default_factory=lambda: get_config_default( - "draft_rope_alpha", None, is_draft=True + "draft_rope_alpha", None, model_type="draft" ), examples=[1.0], ) draft_cache_mode: Optional[str] = Field( default_factory=lambda: get_config_default( - "draft_cache_mode", "FP16", 
is_draft=True + "draft_cache_mode", "FP16", model_type="draft" ) ) @@ -139,7 +139,11 @@ class ModelLoadRequest(BaseModel): class EmbeddingModelLoadRequest(BaseModel): name: str - embeddings_device: Optional[str] = None + embeddings_device: Optional[str] = Field( + default_factory=lambda: get_config_default( + "embeddings_device", model_type="embedding" + ) + ) class ModelLoadResponse(BaseModel):