From dc3dcc9c0ddf721ee67a54b2395df271f0393d2a Mon Sep 17 00:00:00 2001
From: kingbri
Date: Tue, 30 Jul 2024 15:32:26 -0400
Subject: [PATCH] Embeddings: Update config, args, and parameter names

Use embeddings_device as the device parameter to remove ambiguity.

Signed-off-by: kingbri
---
 backends/infinity/model.py    |  2 +-
 common/args.py                | 20 ++++++++++++++++++++
 common/config.py              |  5 +++++
 config_sample.yml             | 23 ++++++++++++++++-------
 endpoints/core/types/model.py |  2 +-
 5 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/backends/infinity/model.py b/backends/infinity/model.py
index 4c9bb69..35a4df4 100644
--- a/backends/infinity/model.py
+++ b/backends/infinity/model.py
@@ -35,7 +35,7 @@ class InfinityContainer:
         self.model_is_loading = True
 
         # Use cpu by default
-        device = unwrap(kwargs.get("device"), "cpu")
+        device = unwrap(kwargs.get("embeddings_device"), "cpu")
 
         engine_args = EngineArgs(
             model_name_or_path=str(self.model_dir),
diff --git a/common/args.py b/common/args.py
index e57de78..0548eaf 100644
--- a/common/args.py
+++ b/common/args.py
@@ -23,6 +23,7 @@ def init_argparser():
     )
     add_network_args(parser)
     add_model_args(parser)
+    add_embeddings_args(parser)
     add_logging_args(parser)
     add_developer_args(parser)
     add_sampling_args(parser)
@@ -209,3 +210,22 @@ def add_sampling_args(parser: argparse.ArgumentParser):
     sampling_group.add_argument(
         "--override-preset", type=str, help="Select a sampler override preset"
     )
+
+
+def add_embeddings_args(parser: argparse.ArgumentParser):
+    """Adds arguments specific to embeddings"""
+
+    embeddings_group = parser.add_argument_group("embeddings")
+    embeddings_group.add_argument(
+        "--embedding-model-dir",
+        type=str,
+        help="Overrides the directory to look for embedding models",
+    )
+    embeddings_group.add_argument(
+        "--embedding-model-name", type=str, help="An initial embedding model to load"
+    )
+    embeddings_group.add_argument(
+        "--embeddings-device",
+        type=str,
+        help="Device to use for embeddings. Options: (cpu, auto, cuda)",
+    )
diff --git a/common/config.py b/common/config.py
index 5546240..9b2f654 100644
--- a/common/config.py
+++ b/common/config.py
@@ -59,6 +59,11 @@ def from_args(args: dict):
         cur_developer_config = developer_config()
         GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
 
+    embeddings_override = args.get("embeddings")
+    if embeddings_override:
+        cur_embeddings_config = embeddings_config()
+        GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}
+
 
 def sampling_config():
     """Returns the sampling parameter config from the global config"""
diff --git a/config_sample.yml b/config_sample.yml
index 71a58d2..09ae000 100644
--- a/config_sample.yml
+++ b/config_sample.yml
@@ -72,13 +72,6 @@ developer:
   # Otherwise, the priority will be set to high
   #realtime_process_priority: False
 
-embeddings:
-  embedding_model_dir: models
-
-  embedding_model_name:
-
-  embeddings_device: cpu
-
 # Options for model overrides and loading
 # Please read the comments to understand how arguments are handled between initial and API loads
 model:
@@ -208,3 +201,19 @@ model:
   #loras:
   #- name: lora1
   #  scaling: 1.0
+
+# Options for embedding models and loading
+# NOTE: Embeddings requires the "extras" feature to be installed
+# Install it via "pip install .[extras]"
+embeddings:
+  # Overrides the directory to look for embedding models (default: models)
+  embedding_model_dir: models
+
+  # An initial embedding model to load on the infinity backend (default: None)
+  embedding_model_name:
+
+  # Device to load embedding models on (default: cpu)
+  # Possible values: cpu, auto, cuda
+  # NOTE: It's recommended to load embedding models on the CPU.
+  # If you'd like to load on an AMD GPU, set this value to "cuda" as well.
+  embeddings_device: cpu
diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py
index c107dde..8b3d83e 100644
--- a/endpoints/core/types/model.py
+++ b/endpoints/core/types/model.py
@@ -139,7 +139,7 @@ class ModelLoadRequest(BaseModel):
 
 class EmbeddingModelLoadRequest(BaseModel):
     name: str
-    device: Optional[str] = None
+    embeddings_device: Optional[str] = None
 
 
 class ModelLoadResponse(BaseModel):
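
A quick usage sketch for reviewers (not part of the patch itself): it shows the renamed request field on the schema changed above. The embedding model name is a placeholder assumption; only the EmbeddingModelLoadRequest fields and the "cpu" fallback are taken from this diff.

# Minimal sketch, assuming the repo root is on sys.path.
from endpoints.core.types.model import EmbeddingModelLoadRequest

request = EmbeddingModelLoadRequest(
    name="bge-small-en-v1.5",     # hypothetical embedding model name
    embeddings_device="cuda",     # was "device" before this patch
)

# Omitting embeddings_device leaves it as None, and the Infinity backend then
# falls back to "cpu" via unwrap(kwargs.get("embeddings_device"), "cpu").
print(request.name, request.embeddings_device)

The same key name is what the new --embeddings-device CLI flag and the embeddings_device entry in config_sample.yml feed through common/config.py, so the request schema, config file, and backend now agree on a single, unambiguous parameter.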