Config: Embeddings: Make embeddings_device a default when loading from the API

When loading from the API, the fallback for embeddings_device will be
the value set in the config.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-08-01 13:59:49 -04:00
parent 54aeebaec1
commit 3e42211c3e
3 changed files with 30 additions and 9 deletions
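
For illustration, an API embedding load that omits embeddings_device now inherits the config value instead of being left unset. The endpoint path, port, model name, and admin-key header in this sketch are assumptions for the example, not taken from the diff.

import requests

payload = {"name": "example-embedding-model"}  # embeddings_device intentionally omitted

resp = requests.post(
    "http://127.0.0.1:5000/v1/model/embedding/load",  # assumed endpoint path/port
    json=payload,
    headers={"x-admin-key": "example-admin-key"},     # assumed auth header
    timeout=30,
)
print(resp.status_code)
# With this change, the server fills embeddings_device from the embeddings
# section of config.yml (e.g. "cpu") rather than leaving it as None.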


@@ -5,6 +5,7 @@ Containers exist as a common interface for backends.
 """
 import pathlib
+from enum import Enum
 from fastapi import HTTPException
 from loguru import logger
 from typing import Optional
@@ -31,6 +32,12 @@ if not do_export_openapi:
 embeddings_container: Optional[InfinityContainer] = None


+class ModelType(Enum):
+    MODEL = "model"
+    DRAFT = "draft"
+    EMBEDDING = "embedding"
+
+
 def load_progress(module, modules):
     """Wrapper callback for load progress."""
     yield module, modules
@@ -142,16 +149,23 @@ async def unload_embedding_model():
     embeddings_container = None


-def get_config_default(key, fallback=None, is_draft=False):
+def get_config_default(key: str, fallback=None, model_type: str = "model"):
     """Fetches a default value from model config if allowed by the user."""

     model_config = config.model_config()
     default_keys = unwrap(model_config.get("use_as_default"), [])
+
+    # Add extra keys to defaults
+    default_keys.append("embeddings_device")
+
     if key in default_keys:
         # Is this a draft model load parameter?
-        if is_draft:
+        if model_type == "draft":
             draft_config = config.draft_model_config()
             return unwrap(draft_config.get(key), fallback)
+        elif model_type == "embedding":
+            embeddings_config = config.embeddings_config()
+            return unwrap(embeddings_config.get(key), fallback)
         else:
             return unwrap(model_config.get(key), fallback)
     else:
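
To make the new dispatch concrete, here is a condensed, self-contained stand-in for the helper above, with the config sections stubbed out as dicts. The key names mirror the diff; the values are invented.

from typing import Any

# Stubbed config sections (invented values; the real ones come from config.yml).
MODEL_CFG = {"use_as_default": ["cache_mode"], "cache_mode": "Q4"}
DRAFT_CFG = {"draft_cache_mode": "Q8"}
EMBEDDINGS_CFG = {"embeddings_device": "cpu"}


def config_default_sketch(key: str, fallback: Any = None, model_type: str = "model"):
    """Mirrors the revised helper: embeddings_device is always a default key."""
    default_keys = list(MODEL_CFG.get("use_as_default", []))
    default_keys.append("embeddings_device")

    if key not in default_keys:
        return fallback
    if model_type == "draft":
        return DRAFT_CFG.get(key, fallback)
    if model_type == "embedding":
        return EMBEDDINGS_CFG.get(key, fallback)
    return MODEL_CFG.get(key, fallback)


print(config_default_sketch("embeddings_device", "auto", model_type="embedding"))  # cpu
print(config_default_sketch("cache_mode", "FP16"))                                 # Q4
print(config_default_sketch("max_seq_len", 4096))                                  # 4096 (not a default key)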


@@ -209,11 +209,14 @@ embeddings:
   # Overrides directory to look for embedding models (default: models)
   embedding_model_dir: models

-  # An initial embedding model to load on the infinity backend (default: None)
-  embedding_model_name:
-
   # Device to load embedding models on (default: cpu)
   # Possible values: cpu, auto, cuda
   # NOTE: It's recommended to load embedding models on the CPU.
   # If you'd like to load on an AMD gpu, set this value to "cuda" as well.
   embeddings_device: cpu
+
+  # The below parameters only apply for initial loads
+  # All API based loads do NOT inherit these settings unless specified in use_as_default
+
+  # An initial embedding model to load on the infinity backend (default: None)
+  embedding_model_name:
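
Since the new comments point users at use_as_default, here is a small runnable sketch (using PyYAML, with an invented config excerpt) of how keys listed there, plus the always-added embeddings_device, become defaults for API loads.

import yaml

# Invented config excerpt for illustration; only the shape matters.
CONFIG_TEXT = """
model:
  use_as_default: ["cache_mode", "embeddings_device"]
embeddings:
  embeddings_device: cpu
  embedding_model_name:
"""

config = yaml.safe_load(CONFIG_TEXT)

default_keys = list(config["model"].get("use_as_default") or [])
default_keys.append("embeddings_device")  # always treated as a default after this commit

print(sorted(set(default_keys)))                       # -> ['cache_mode', 'embeddings_device']
print(config["embeddings"].get("embeddings_device"))   # -> cpu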


@@ -53,19 +53,19 @@ class DraftModelLoadRequest(BaseModel):
     # Config arguments
     draft_rope_scale: Optional[float] = Field(
         default_factory=lambda: get_config_default(
-            "draft_rope_scale", 1.0, is_draft=True
+            "draft_rope_scale", 1.0, model_type="draft"
         )
     )
     draft_rope_alpha: Optional[float] = Field(
         description="Automatically calculated if not present",
         default_factory=lambda: get_config_default(
-            "draft_rope_alpha", None, is_draft=True
+            "draft_rope_alpha", None, model_type="draft"
         ),
         examples=[1.0],
     )
     draft_cache_mode: Optional[str] = Field(
         default_factory=lambda: get_config_default(
-            "draft_cache_mode", "FP16", is_draft=True
+            "draft_cache_mode", "FP16", model_type="draft"
         )
     )

@@ -139,7 +139,11 @@ class ModelLoadRequest(BaseModel):
 class EmbeddingModelLoadRequest(BaseModel):
     name: str

-    embeddings_device: Optional[str] = None
+    embeddings_device: Optional[str] = Field(
+        default_factory=lambda: get_config_default(
+            "embeddings_device", model_type="embedding"
+        )
+    )


 class ModelLoadResponse(BaseModel):
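
Finally, a self-contained sketch of what the new default_factory buys at the request-model level. The class and helper here are simplified stand-ins for the real EmbeddingModelLoadRequest and get_config_default; the config value and model names are invented.

from typing import Optional

from pydantic import BaseModel, Field

# Stand-in for the embeddings section of config.yml (illustrative value).
EMBEDDINGS_CFG = {"embeddings_device": "cpu"}


def embedding_default(key: str, fallback=None):
    """Simplified stand-in for get_config_default(key, model_type="embedding")."""
    return EMBEDDINGS_CFG.get(key, fallback)


class EmbeddingLoadSketch(BaseModel):
    name: str
    embeddings_device: Optional[str] = Field(
        default_factory=lambda: embedding_default("embeddings_device")
    )


# A request that omits embeddings_device picks up the config value,
# while an explicit value still wins.
print(EmbeddingLoadSketch(name="example-model").embeddings_device)  # cpu
print(EmbeddingLoadSketch(name="example-model", embeddings_device="cuda").embeddings_device)  # cuda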