Config: Embeddings: Make embeddings_device a default when API loading
When loading from the API, the fallback for embeddings_device will be the same as the config. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
54aeebaec1
commit
3e42211c3e
3 changed files with 30 additions and 9 deletions
|
|
@ -5,6 +5,7 @@ Containers exist as a common interface for backends.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
|
from enum import Enum
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
@ -31,6 +32,12 @@ if not do_export_openapi:
|
||||||
embeddings_container: Optional[InfinityContainer] = None
|
embeddings_container: Optional[InfinityContainer] = None
|
||||||
|
|
||||||
|
|
||||||
|
class ModelType(Enum):
|
||||||
|
MODEL = "model"
|
||||||
|
DRAFT = "draft"
|
||||||
|
EMBEDDING = "embedding"
|
||||||
|
|
||||||
|
|
||||||
def load_progress(module, modules):
|
def load_progress(module, modules):
|
||||||
"""Wrapper callback for load progress."""
|
"""Wrapper callback for load progress."""
|
||||||
yield module, modules
|
yield module, modules
|
||||||
|
|
@ -142,16 +149,23 @@ async def unload_embedding_model():
|
||||||
embeddings_container = None
|
embeddings_container = None
|
||||||
|
|
||||||
|
|
||||||
def get_config_default(key, fallback=None, is_draft=False):
|
def get_config_default(key: str, fallback=None, model_type: str = "model"):
|
||||||
"""Fetches a default value from model config if allowed by the user."""
|
"""Fetches a default value from model config if allowed by the user."""
|
||||||
|
|
||||||
model_config = config.model_config()
|
model_config = config.model_config()
|
||||||
default_keys = unwrap(model_config.get("use_as_default"), [])
|
default_keys = unwrap(model_config.get("use_as_default"), [])
|
||||||
|
|
||||||
|
# Add extra keys to defaults
|
||||||
|
default_keys.append("embeddings_device")
|
||||||
|
|
||||||
if key in default_keys:
|
if key in default_keys:
|
||||||
# Is this a draft model load parameter?
|
# Is this a draft model load parameter?
|
||||||
if is_draft:
|
if model_type == "draft":
|
||||||
draft_config = config.draft_model_config()
|
draft_config = config.draft_model_config()
|
||||||
return unwrap(draft_config.get(key), fallback)
|
return unwrap(draft_config.get(key), fallback)
|
||||||
|
elif model_type == "embedding":
|
||||||
|
embeddings_config = config.embeddings_config()
|
||||||
|
return unwrap(embeddings_config.get(key), fallback)
|
||||||
else:
|
else:
|
||||||
return unwrap(model_config.get(key), fallback)
|
return unwrap(model_config.get(key), fallback)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -209,11 +209,14 @@ embeddings:
|
||||||
# Overrides directory to look for embedding models (default: models)
|
# Overrides directory to look for embedding models (default: models)
|
||||||
embedding_model_dir: models
|
embedding_model_dir: models
|
||||||
|
|
||||||
# An initial embedding model to load on the infinity backend (default: None)
|
|
||||||
embedding_model_name:
|
|
||||||
|
|
||||||
# Device to load embedding models on (default: cpu)
|
# Device to load embedding models on (default: cpu)
|
||||||
# Possible values: cpu, auto, cuda
|
# Possible values: cpu, auto, cuda
|
||||||
# NOTE: It's recommended to load embedding models on the CPU.
|
# NOTE: It's recommended to load embedding models on the CPU.
|
||||||
# If you'd like to load on an AMD gpu, set this value to "cuda" as well.
|
# If you'd like to load on an AMD gpu, set this value to "cuda" as well.
|
||||||
embeddings_device: cpu
|
embeddings_device: cpu
|
||||||
|
|
||||||
|
# The below parameters only apply for initial loads
|
||||||
|
# All API based loads do NOT inherit these settings unless specified in use_as_default
|
||||||
|
|
||||||
|
# An initial embedding model to load on the infinity backend (default: None)
|
||||||
|
embedding_model_name:
|
||||||
|
|
|
||||||
|
|
@ -53,19 +53,19 @@ class DraftModelLoadRequest(BaseModel):
|
||||||
# Config arguments
|
# Config arguments
|
||||||
draft_rope_scale: Optional[float] = Field(
|
draft_rope_scale: Optional[float] = Field(
|
||||||
default_factory=lambda: get_config_default(
|
default_factory=lambda: get_config_default(
|
||||||
"draft_rope_scale", 1.0, is_draft=True
|
"draft_rope_scale", 1.0, model_type="draft"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
draft_rope_alpha: Optional[float] = Field(
|
draft_rope_alpha: Optional[float] = Field(
|
||||||
description="Automatically calculated if not present",
|
description="Automatically calculated if not present",
|
||||||
default_factory=lambda: get_config_default(
|
default_factory=lambda: get_config_default(
|
||||||
"draft_rope_alpha", None, is_draft=True
|
"draft_rope_alpha", None, model_type="draft"
|
||||||
),
|
),
|
||||||
examples=[1.0],
|
examples=[1.0],
|
||||||
)
|
)
|
||||||
draft_cache_mode: Optional[str] = Field(
|
draft_cache_mode: Optional[str] = Field(
|
||||||
default_factory=lambda: get_config_default(
|
default_factory=lambda: get_config_default(
|
||||||
"draft_cache_mode", "FP16", is_draft=True
|
"draft_cache_mode", "FP16", model_type="draft"
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -139,7 +139,11 @@ class ModelLoadRequest(BaseModel):
|
||||||
|
|
||||||
class EmbeddingModelLoadRequest(BaseModel):
|
class EmbeddingModelLoadRequest(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
embeddings_device: Optional[str] = None
|
embeddings_device: Optional[str] = Field(
|
||||||
|
default_factory=lambda: get_config_default(
|
||||||
|
"embeddings_device", model_type="embedding"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ModelLoadResponse(BaseModel):
|
class ModelLoadResponse(BaseModel):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue