diff --git a/common/model.py b/common/model.py index 3776ff9..80858d4 100644 --- a/common/model.py +++ b/common/model.py @@ -57,8 +57,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): f'Model "{loaded_model_name}" is already loaded! Aborting.' ) - # Unload the existing model - if container and container.model: logger.info("Unloading existing model.") await unload_model() @@ -109,24 +107,35 @@ async def unload_loras(): await container.unload(loras_only=True) -async def load_embeddings_model(model_path: pathlib.Path, **kwargs): +async def load_embedding_model(model_path: pathlib.Path, **kwargs): global embeddings_container # Break out if infinity isn't installed if not has_infinity_emb: - logger.warning( + raise ImportError( "Skipping embeddings because infinity-emb is not installed.\n" "Please run the following command in your environment " "to install extra packages:\n" "pip install -U .[extras]" ) - return + + # Check if the model is already loaded + if embeddings_container and embeddings_container.engine: + loaded_model_name = embeddings_container.model_dir.name + + if loaded_model_name == model_path.name and embeddings_container.model_loaded: + raise ValueError( + f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.' 
+ ) + + logger.info("Unloading existing embeddings model.") + await unload_embedding_model() embeddings_container = InfinityContainer(model_path) await embeddings_container.load(**kwargs) -async def unload_embeddings_model(): +async def unload_embedding_model(): global embeddings_container await embeddings_container.unload() @@ -172,7 +181,7 @@ async def check_embeddings_container(): embeddings_container.model_is_loading or embeddings_container.model_loaded ): error_message = handle_request_error( - "No embeddings models are currently loaded.", + "No embedding models are currently loaded.", exc_info=False, ).error.message diff --git a/config_sample.yml b/config_sample.yml index 053feb6..71a58d2 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -73,9 +73,9 @@ developer: #realtime_process_priority: False embeddings: - embeddings_model_dir: models + embedding_model_dir: models - embeddings_model_name: + embedding_model_name: embeddings_device: cpu diff --git a/endpoints/core/router.py b/endpoints/core/router.py index cd0ed37..5aabd48 100644 --- a/endpoints/core/router.py +++ b/endpoints/core/router.py @@ -7,7 +7,7 @@ from sse_starlette import EventSourceResponse from common import config, model, sampling from common.auth import check_admin_key, check_api_key, get_key_permission from common.downloader import hf_repo_download -from common.model import check_model_container +from common.model import check_embeddings_container, check_model_container from common.networking import handle_request_error, run_with_request_disconnect from common.templating import PromptTemplate, get_all_templates from common.utils import unwrap @@ -15,6 +15,7 @@ from endpoints.core.types.auth import AuthPermissionResponse from endpoints.core.types.download import DownloadRequest, DownloadResponse from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse from endpoints.core.types.model import ( + EmbeddingModelLoadRequest, ModelCard, ModelList, ModelLoadRequest, 
@@ -253,6 +254,94 @@ async def unload_loras(): await model.unload_loras() +@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)]) +async def list_embedding_models(request: Request) -> ModelList: + """ + Lists all embedding models in the model directory. + + Requires an admin key to see all embedding models. + """ + + if get_key_permission(request) == "admin": + embedding_model_dir = unwrap( + config.embeddings_config().get("embedding_model_dir"), "models" + ) + embedding_model_path = pathlib.Path(embedding_model_dir) + + models = get_model_list(embedding_model_path.resolve()) + else: + models = await get_current_model_list(model_type="embedding") + + return models + + +@router.get( + "/v1/model/embedding", + dependencies=[Depends(check_api_key), Depends(check_embeddings_container)], +) +async def get_embedding_model() -> ModelCard: + """Returns the currently loaded embedding model.""" + + return (await get_current_model_list(model_type="embedding")).data[0] + + +@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)]) +async def load_embedding_model( + request: Request, data: EmbeddingModelLoadRequest +) -> ModelLoadResponse: + """Loads an embedding model from the embedding model directory.""" + + # Verify request parameters + if not data.name: + error_message = handle_request_error( + "A model name was not provided for load.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + embedding_model_dir = pathlib.Path( + unwrap(config.embeddings_config().get("embedding_model_dir"), "models") + ) + embedding_model_path = embedding_model_dir / data.name + + if not embedding_model_path.exists(): + error_message = handle_request_error( + "Could not find the embedding model path for load. 
" + + "Check model name or config.yml?", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) + + try: + load_task = asyncio.create_task( + model.load_embedding_model(embedding_model_path, **data.model_dump()) + ) + await run_with_request_disconnect( + request, load_task, "Embedding model load request cancelled by user." + ) + except Exception as exc: + error_message = handle_request_error(str(exc)).error.message + + raise HTTPException(400, error_message) from exc + + response = ModelLoadResponse( + model_type="embedding_model", module=1, modules=1, status="finished" + ) + + return response + + +@router.post( + "/v1/model/embedding/unload", + dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)], +) +async def unload_embedding_model(): + """Unloads the current embedding model.""" + + await model.unload_embedding_model() + + # Encode tokens endpoint @router.post( "/v1/token/encode", diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py index 30730b8..c107dde 100644 --- a/endpoints/core/types/model.py +++ b/endpoints/core/types/model.py @@ -137,6 +137,11 @@ class ModelLoadRequest(BaseModel): skip_queue: Optional[bool] = False +class EmbeddingModelLoadRequest(BaseModel): + name: str + device: Optional[str] = None + + class ModelLoadResponse(BaseModel): """Represents a model load response.""" diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py index 0cfb26a..fc61337 100644 --- a/endpoints/core/utils/model.py +++ b/endpoints/core/utils/model.py @@ -32,15 +32,26 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N return model_card_list -async def get_current_model_list(is_draft: bool = False): - """Gets the current model in list format and with path only.""" +async def get_current_model_list(model_type: str = "model"): + """ + Gets the current model in list format and with path only. + + Unified for fetching both models and embedding models. 
+ """ + current_models = [] + model_path = None # Make sure the model container exists - if model.container: - model_path = model.container.get_model_path(is_draft) - if model_path: - current_models.append(ModelCard(id=model_path.name)) + if model_type == "model" or model_type == "draft": + if model.container: + model_path = model.container.get_model_path(model_type == "draft") + elif model_type == "embedding": + if model.embeddings_container: + model_path = model.embeddings_container.model_dir + + if model_path: + current_models.append(ModelCard(id=model_path.name)) return ModelList(data=current_models) diff --git a/main.py b/main.py index 56873c4..bae2f98 100644 --- a/main.py +++ b/main.py @@ -90,14 +90,17 @@ async def entrypoint_async(): # If an initial embedding model name is specified, create a separate container # and load the model embedding_config = config.embeddings_config() - embedding_model_name = embedding_config.get("embeddings_model_name") + embedding_model_name = embedding_config.get("embedding_model_name") if embedding_model_name: embedding_model_path = pathlib.Path( - unwrap(embedding_config.get("embeddings_model_dir"), "models") + unwrap(embedding_config.get("embedding_model_dir"), "models") ) embedding_model_path = embedding_model_path / embedding_model_name - await model.load_embeddings_model(embedding_model_path, **embedding_config) + try: + await model.load_embedding_model(embedding_model_path, **embedding_config) + except ImportError as ex: + logger.error(ex.msg) await start_api(host, port)