Embeddings: Update config, args, and parameter names

Use embeddings_device as the parameter for device to remove ambiguity.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-07-30 15:32:26 -04:00
parent bfa011e0ce
commit dc3dcc9c0d
5 changed files with 43 additions and 9 deletions

View file

@ -35,7 +35,7 @@ class InfinityContainer:
self.model_is_loading = True
# Use cpu by default
device = unwrap(kwargs.get("device"), "cpu")
device = unwrap(kwargs.get("embeddings_device"), "cpu")
engine_args = EngineArgs(
model_name_or_path=str(self.model_dir),

View file

@ -23,6 +23,7 @@ def init_argparser():
)
add_network_args(parser)
add_model_args(parser)
add_embeddings_args(parser)
add_logging_args(parser)
add_developer_args(parser)
add_sampling_args(parser)
@ -209,3 +210,22 @@ def add_sampling_args(parser: argparse.ArgumentParser):
sampling_group.add_argument(
"--override-preset", type=str, help="Select a sampler override preset"
)
def add_embeddings_args(parser: argparse.ArgumentParser):
"""Adds arguments specific to embeddings"""
embeddings_group = parser.add_argument_group("embeddings")
embeddings_group.add_argument(
"--embedding-model-dir",
type=str,
help="Overrides the directory to look for models",
)
embeddings_group.add_argument(
"--embedding-model-name", type=str, help="An initial model to load"
)
embeddings_group.add_argument(
"--embeddings-device",
type=str,
help="Device to use for embeddings. Options: (cpu, auto, cuda)",
)

View file

@ -59,6 +59,11 @@ def from_args(args: dict):
cur_developer_config = developer_config()
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
embeddings_override = args.get("embeddings")
if embeddings_override:
cur_embeddings_config = embeddings_config()
GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}
def sampling_config():
"""Returns the sampling parameter config from the global config"""

View file

@ -72,13 +72,6 @@ developer:
# Otherwise, the priority will be set to high
#realtime_process_priority: False
embeddings:
embedding_model_dir: models
embedding_model_name:
embeddings_device: cpu
# Options for model overrides and loading
# Please read the comments to understand how arguments are handled between initial and API loads
model:
@ -208,3 +201,19 @@ model:
#loras:
#- name: lora1
# scaling: 1.0
# Options for embedding models and loading.
# NOTE: Embeddings requires the "extras" feature to be installed
# Install it via "pip install .[extras]"
embeddings:
# Overrides directory to look for embedding models (default: models)
embedding_model_dir: models
# An initial embedding model to load on the infinity backend (default: None)
embedding_model_name:
# Device to load embedding models on (default: cpu)
# Possible values: cpu, auto, cuda
# NOTE: It's recommended to load embedding models on the CPU.
# If you'd like to load on an AMD gpu, set this value to "cuda" as well.
embeddings_device: cpu

View file

@ -139,7 +139,7 @@ class ModelLoadRequest(BaseModel):
class EmbeddingModelLoadRequest(BaseModel):
name: str
device: Optional[str] = None
embeddings_device: Optional[str] = None
class ModelLoadResponse(BaseModel):