Embeddings: Update config, args, and parameter names
Use embeddings_device as the device parameter name to remove ambiguity.

Signed-off-by: kingbri <bdashore3@proton.me>
parent bfa011e0ce
commit dc3dcc9c0d
5 changed files with 43 additions and 9 deletions
@@ -35,7 +35,7 @@ class InfinityContainer:
         self.model_is_loading = True

         # Use cpu by default
-        device = unwrap(kwargs.get("device"), "cpu")
+        device = unwrap(kwargs.get("embeddings_device"), "cpu")

         engine_args = EngineArgs(
             model_name_or_path=str(self.model_dir),
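The hunk above swaps the load kwarg from "device" to "embeddings_device". A minimal sketch of the lookup-with-default behavior, assuming `unwrap` simply falls back to the default when the value is None (the helper below is a stand-in for illustration, not the project's implementation):

    # Stand-in for the project's unwrap helper: return value unless it is None.
    def unwrap(value, default=None):
        return value if value is not None else default

    def resolve_embeddings_device(**kwargs) -> str:
        # Load kwargs now carry "embeddings_device" instead of the ambiguous "device"
        return unwrap(kwargs.get("embeddings_device"), "cpu")

    assert resolve_embeddings_device() == "cpu"
    assert resolve_embeddings_device(embeddings_device="cuda") == "cuda"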
@@ -23,6 +23,7 @@ def init_argparser():
     )
     add_network_args(parser)
     add_model_args(parser)
+    add_embeddings_args(parser)
     add_logging_args(parser)
     add_developer_args(parser)
     add_sampling_args(parser)
@@ -209,3 +210,22 @@ def add_sampling_args(parser: argparse.ArgumentParser):
     sampling_group.add_argument(
         "--override-preset", type=str, help="Select a sampler override preset"
     )
+
+
+def add_embeddings_args(parser: argparse.ArgumentParser):
+    """Adds arguments specific to embeddings"""
+
+    embeddings_group = parser.add_argument_group("embeddings")
+    embeddings_group.add_argument(
+        "--embedding-model-dir",
+        type=str,
+        help="Overrides the directory to look for models",
+    )
+    embeddings_group.add_argument(
+        "--embedding-model-name", type=str, help="An initial model to load"
+    )
+    embeddings_group.add_argument(
+        "--embeddings-device",
+        type=str,
+        help="Device to use for embeddings. Options: (cpu, auto, cuda)",
+    )
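For reference, a self-contained sketch of how the new flags behave once registered; argparse turns "--embeddings-device" into args.embeddings_device, which is the same spelling the config and API now use (this standalone parser mirrors the group above but is not the project's init_argparser):

    import argparse

    parser = argparse.ArgumentParser()
    embeddings_group = parser.add_argument_group("embeddings")
    embeddings_group.add_argument("--embedding-model-dir", type=str)
    embeddings_group.add_argument("--embedding-model-name", type=str)
    embeddings_group.add_argument(
        "--embeddings-device",
        type=str,
        help="Device to use for embeddings. Options: (cpu, auto, cuda)",
    )

    args = parser.parse_args(["--embeddings-device", "cuda"])
    # argparse converts the hyphenated flag into an underscore attribute
    print(args.embeddings_device)  # -> "cuda"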
@@ -59,6 +59,11 @@ def from_args(args: dict):
         cur_developer_config = developer_config()
         GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}

+    embeddings_override = args.get("embeddings")
+    if embeddings_override:
+        cur_embeddings_config = embeddings_config()
+        GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}
+

 def sampling_config():
     """Returns the sampling parameter config from the global config"""
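The merge uses Python's dict unpacking, so embeddings keys supplied as overrides win over the file-backed config while untouched keys pass through. A quick illustration with made-up values:

    # Hypothetical values, only to show the merge order of the two dicts
    cur_embeddings_config = {
        "embedding_model_dir": "models",
        "embeddings_device": "cpu",
    }
    embeddings_override = {"embeddings_device": "cuda"}  # e.g. from --embeddings-device cuda

    merged = {**cur_embeddings_config, **embeddings_override}
    print(merged)
    # {'embedding_model_dir': 'models', 'embeddings_device': 'cuda'}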
@@ -72,13 +72,6 @@ developer:
   # Otherwise, the priority will be set to high
   #realtime_process_priority: False

-embeddings:
-  embedding_model_dir: models
-
-  embedding_model_name:
-
-  embeddings_device: cpu
-
 # Options for model overrides and loading
 # Please read the comments to understand how arguments are handled between initial and API loads
 model:
@@ -208,3 +201,19 @@ model:
   #loras:
   #- name: lora1
   #  scaling: 1.0
+
+# Options for embedding models and loading.
+# NOTE: Embeddings requires the "extras" feature to be installed
+# Install it via "pip install .[extras]"
+embeddings:
+  # Overrides directory to look for embedding models (default: models)
+  embedding_model_dir: models
+
+  # An initial embedding model to load on the infinity backend (default: None)
+  embedding_model_name:
+
+  # Device to load embedding models on (default: cpu)
+  # Possible values: cpu, auto, cuda
+  # NOTE: It's recommended to load embedding models on the CPU.
+  # If you'd like to load on an AMD gpu, set this value to "cuda" as well.
+  embeddings_device: cpu
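As a sanity check on the sample config keys, a short sketch of what the relocated embeddings block looks like once parsed (PyYAML is an assumption about the loader here; the point is only the key names and the None value for the unset model name):

    import yaml  # PyYAML, assumed available for this illustration

    sample = """
    embeddings:
      embedding_model_dir: models
      embedding_model_name:
      embeddings_device: cpu
    """

    config = yaml.safe_load(sample)
    embeddings = config.get("embeddings", {})

    # An empty YAML value loads as None, so the model name stays unset
    print(embeddings["embeddings_device"])    # -> "cpu"
    print(embeddings["embedding_model_name"])  # -> None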
@@ -139,7 +139,7 @@ class ModelLoadRequest(BaseModel):

 class EmbeddingModelLoadRequest(BaseModel):
     name: str
-    device: Optional[str] = None
+    embeddings_device: Optional[str] = None


 class ModelLoadResponse(BaseModel):
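On the API side, the request body field is renamed in lockstep. A minimal sketch of the updated schema with an example payload (the model name below is invented for illustration):

    from typing import Optional
    from pydantic import BaseModel

    # Mirrors the updated request schema shown above; a sketch, not the project file
    class EmbeddingModelLoadRequest(BaseModel):
        name: str
        embeddings_device: Optional[str] = None

    # A load request now spells the device field to match the config and CLI
    request = EmbeddingModelLoadRequest(
        name="example-embedding-model",  # hypothetical model name
        embeddings_device="cuda",
    )
    print(request.embeddings_device)  # -> "cuda"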