Embeddings: Add model management
Embedding models are managed on a separate backend, but are run in parallel with the model itself. Therefore, manage this in a separate container with separate routes. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
f13d0fb8b3
commit
bfa011e0ce
6 changed files with 135 additions and 19 deletions
|
|
@ -57,8 +57,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
|
|||
f'Model "{loaded_model_name}" is already loaded! Aborting.'
|
||||
)
|
||||
|
||||
# Unload the existing model
|
||||
if container and container.model:
|
||||
logger.info("Unloading existing model.")
|
||||
await unload_model()
|
||||
|
||||
|
|
@ -109,24 +107,35 @@ async def unload_loras():
|
|||
await container.unload(loras_only=True)
|
||||
|
||||
|
||||
async def load_embeddings_model(model_path: pathlib.Path, **kwargs):
|
||||
async def load_embedding_model(model_path: pathlib.Path, **kwargs):
|
||||
global embeddings_container
|
||||
|
||||
# Break out if infinity isn't installed
|
||||
if not has_infinity_emb:
|
||||
logger.warning(
|
||||
raise ImportError(
|
||||
"Skipping embeddings because infinity-emb is not installed.\n"
|
||||
"Please run the following command in your environment "
|
||||
"to install extra packages:\n"
|
||||
"pip install -U .[extras]"
|
||||
)
|
||||
return
|
||||
|
||||
# Check if the model is already loaded
|
||||
if embeddings_container and embeddings_container.engine:
|
||||
loaded_model_name = embeddings_container.model_dir.name
|
||||
|
||||
if loaded_model_name == model_path.name and embeddings_container.model_loaded:
|
||||
raise ValueError(
|
||||
f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.'
|
||||
)
|
||||
|
||||
logger.info("Unloading existing embeddings model.")
|
||||
await unload_embedding_model()
|
||||
|
||||
embeddings_container = InfinityContainer(model_path)
|
||||
await embeddings_container.load(**kwargs)
|
||||
|
||||
|
||||
async def unload_embeddings_model():
|
||||
async def unload_embedding_model():
|
||||
global embeddings_container
|
||||
|
||||
await embeddings_container.unload()
|
||||
|
|
@ -172,7 +181,7 @@ async def check_embeddings_container():
|
|||
embeddings_container.model_is_loading or embeddings_container.model_loaded
|
||||
):
|
||||
error_message = handle_request_error(
|
||||
"No embeddings models are currently loaded.",
|
||||
"No embedding models are currently loaded.",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
|
|
|
|||
|
|
@ -73,9 +73,9 @@ developer:
|
|||
#realtime_process_priority: False
|
||||
|
||||
embeddings:
|
||||
embeddings_model_dir: models
|
||||
embedding_model_dir: models
|
||||
|
||||
embeddings_model_name:
|
||||
embedding_model_name:
|
||||
|
||||
embeddings_device: cpu
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from sse_starlette import EventSourceResponse
|
|||
from common import config, model, sampling
|
||||
from common.auth import check_admin_key, check_api_key, get_key_permission
|
||||
from common.downloader import hf_repo_download
|
||||
from common.model import check_model_container
|
||||
from common.model import check_embeddings_container, check_model_container
|
||||
from common.networking import handle_request_error, run_with_request_disconnect
|
||||
from common.templating import PromptTemplate, get_all_templates
|
||||
from common.utils import unwrap
|
||||
|
|
@ -15,6 +15,7 @@ from endpoints.core.types.auth import AuthPermissionResponse
|
|||
from endpoints.core.types.download import DownloadRequest, DownloadResponse
|
||||
from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
|
||||
from endpoints.core.types.model import (
|
||||
EmbeddingModelLoadRequest,
|
||||
ModelCard,
|
||||
ModelList,
|
||||
ModelLoadRequest,
|
||||
|
|
@ -253,6 +254,93 @@ async def unload_loras():
|
|||
await model.unload_loras()
|
||||
|
||||
|
||||
@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)])
|
||||
async def list_embedding_models(request: Request) -> ModelList:
|
||||
"""
|
||||
Lists all embedding models in the model directory.
|
||||
|
||||
Requires an admin key to see all embedding models.
|
||||
"""
|
||||
|
||||
if get_key_permission(request) == "admin":
|
||||
embedding_model_dir = unwrap(
|
||||
config.embeddings_config().get("embedding_model_dir"), "models"
|
||||
)
|
||||
embedding_model_path = pathlib.Path(embedding_model_dir)
|
||||
|
||||
models = get_model_list(embedding_model_path.resolve())
|
||||
else:
|
||||
models = await get_current_model_list(model_type="embedding")
|
||||
|
||||
return models
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v1/model/embedding",
|
||||
dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
|
||||
)
|
||||
async def get_embedding_model() -> ModelList:
|
||||
"""Returns the currently loaded embedding model."""
|
||||
|
||||
return get_current_model_list(model_type="embedding")[0]
|
||||
|
||||
|
||||
@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)])
|
||||
async def load_embedding_model(
|
||||
request: Request, data: EmbeddingModelLoadRequest
|
||||
) -> ModelLoadResponse:
|
||||
# Verify request parameters
|
||||
if not data.name:
|
||||
error_message = handle_request_error(
|
||||
"A model name was not provided for load.",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
embedding_model_dir = pathlib.Path(
|
||||
unwrap(config.model_config().get("embedding_model_dir"), "models")
|
||||
)
|
||||
embedding_model_path = embedding_model_dir / data.name
|
||||
|
||||
if not embedding_model_path.exists():
|
||||
error_message = handle_request_error(
|
||||
"Could not find the embedding model path for load. "
|
||||
+ "Check model name or config.yml?",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
try:
|
||||
load_task = asyncio.create_task(
|
||||
model.load_embedding_model(embedding_model_path, **data.model_dump())
|
||||
)
|
||||
await run_with_request_disconnect(
|
||||
request, load_task, "Embedding model load request cancelled by user."
|
||||
)
|
||||
except Exception as exc:
|
||||
error_message = handle_request_error(str(exc)).error.message
|
||||
|
||||
raise HTTPException(400, error_message) from exc
|
||||
|
||||
response = ModelLoadResponse(
|
||||
model_type="embedding_model", module=1, modules=1, status="finished"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/model/embedding/unload",
|
||||
dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)],
|
||||
)
|
||||
async def unload_embedding_model():
|
||||
"""Unloads the current embedding model."""
|
||||
|
||||
await model.unload_embedding_model()
|
||||
|
||||
|
||||
# Encode tokens endpoint
|
||||
@router.post(
|
||||
"/v1/token/encode",
|
||||
|
|
|
|||
|
|
@ -137,6 +137,11 @@ class ModelLoadRequest(BaseModel):
|
|||
skip_queue: Optional[bool] = False
|
||||
|
||||
|
||||
class EmbeddingModelLoadRequest(BaseModel):
|
||||
name: str
|
||||
device: Optional[str] = None
|
||||
|
||||
|
||||
class ModelLoadResponse(BaseModel):
|
||||
"""Represents a model load response."""
|
||||
|
||||
|
|
|
|||
|
|
@ -32,15 +32,26 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N
|
|||
return model_card_list
|
||||
|
||||
|
||||
async def get_current_model_list(is_draft: bool = False):
|
||||
"""Gets the current model in list format and with path only."""
|
||||
async def get_current_model_list(model_type: str = "model"):
|
||||
"""
|
||||
Gets the current model in list format and with path only.
|
||||
|
||||
Unified for fetching both models and embedding models.
|
||||
"""
|
||||
|
||||
current_models = []
|
||||
model_path = None
|
||||
|
||||
# Make sure the model container exists
|
||||
if model.container:
|
||||
model_path = model.container.get_model_path(is_draft)
|
||||
if model_path:
|
||||
current_models.append(ModelCard(id=model_path.name))
|
||||
if model_type == "model" or model_type == "draft":
|
||||
if model.container:
|
||||
model_path = model.container.get_model_path(model_type == "draft")
|
||||
elif model_type == "embedding":
|
||||
if model.embeddings_container:
|
||||
model_path = model.embeddings_container.model_dir
|
||||
|
||||
if model_path:
|
||||
current_models.append(ModelCard(id=model_path.name))
|
||||
|
||||
return ModelList(data=current_models)
|
||||
|
||||
|
|
|
|||
9
main.py
9
main.py
|
|
@ -90,14 +90,17 @@ async def entrypoint_async():
|
|||
# If an initial embedding model name is specified, create a separate container
|
||||
# and load the model
|
||||
embedding_config = config.embeddings_config()
|
||||
embedding_model_name = embedding_config.get("embeddings_model_name")
|
||||
embedding_model_name = embedding_config.get("embedding_model_name")
|
||||
if embedding_model_name:
|
||||
embedding_model_path = pathlib.Path(
|
||||
unwrap(embedding_config.get("embeddings_model_dir"), "models")
|
||||
unwrap(embedding_config.get("embedding_model_dir"), "models")
|
||||
)
|
||||
embedding_model_path = embedding_model_path / embedding_model_name
|
||||
|
||||
await model.load_embeddings_model(embedding_model_path, **embedding_config)
|
||||
try:
|
||||
await model.load_embedding_model(embedding_model_path, **embedding_config)
|
||||
except ImportError as ex:
|
||||
logger.error(ex.msg)
|
||||
|
||||
await start_api(host, port)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue