From fbf1455db18a3ac2f4f312796150e99536d2361c Mon Sep 17 00:00:00 2001
From: kingbri
Date: Tue, 30 Jul 2024 11:00:23 -0400
Subject: [PATCH] Embeddings: Migrate and organize Infinity

Use Infinity as a separate backend and handle the model within the
common module. This separates out the embeddings model from the
endpoint, which allows for model loading/unloading in core.

Signed-off-by: kingbri
---
 backends/infinity/model.py        |  56 +++++++++++++++
 common/model.py                   |  45 ++++++++++++
 common/signals.py                 |  13 ++++
 endpoints/OAI/router.py           |  11 ++-
 endpoints/OAI/utils/embeddings.py | 111 +++++++++---------------------
 main.py                           |  12 ++++
 6 files changed, 165 insertions(+), 83 deletions(-)
 create mode 100644 backends/infinity/model.py

diff --git a/backends/infinity/model.py b/backends/infinity/model.py
new file mode 100644
index 0000000..2d4ae83
--- /dev/null
+++ b/backends/infinity/model.py
@@ -0,0 +1,56 @@
+import gc
+import pathlib
+import torch
+from typing import List, Optional
+
+from common.utils import unwrap
+
+# Conditionally import infinity to sidestep its logger
+# TODO: Make this prettier
+try:
+    from infinity_emb import EngineArgs, AsyncEmbeddingEngine
+
+    has_infinity_emb = True
+except ImportError:
+    has_infinity_emb = False
+
+
+class InfinityContainer:
+    model_dir: pathlib.Path
+
+    # Conditionally set the type hint based on importability
+    # TODO: Clean this up
+    if has_infinity_emb:
+        engine: Optional[AsyncEmbeddingEngine] = None
+    else:
+        engine = None
+
+    def __init__(self, model_directory: pathlib.Path):
+        self.model_dir = model_directory
+
+    async def load(self, **kwargs):
+        # Use CPU by default
+        device = unwrap(kwargs.get("device"), "cpu")
+
+        engine_args = EngineArgs(
+            model_name_or_path=str(self.model_dir),
+            engine="torch",
+            device=device,
+            bettertransformer=False,
+            model_warmup=False,
+        )
+
+        self.engine = AsyncEmbeddingEngine.from_args(engine_args)
+        await self.engine.astart()
+
+    async def unload(self):
+        await self.engine.astop()
+        self.engine = None
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    async def generate(self, sentence_input: List[str]):
+        result_embeddings, usage = await self.engine.embed(sentence_input)
+
+        return {"embeddings": result_embeddings, "usage": usage}
diff --git a/common/model.py b/common/model.py
index a6477c2..b4b259e 100644
--- a/common/model.py
+++ b/common/model.py
@@ -20,6 +20,15 @@ if not do_export_openapi:
 
     # Global model container
     container: Optional[ExllamaV2Container] = None
+    embeddings_container = None
+
+    # Type hint the infinity emb container if it exists
+    from backends.infinity.model import has_infinity_emb
+
+    if has_infinity_emb:
+        from backends.infinity.model import InfinityContainer
+
+        embeddings_container: Optional[InfinityContainer] = None
 
 
 def load_progress(module, modules):
@@ -100,6 +109,30 @@ async def unload_loras():
 
     await container.unload(loras_only=True)
 
 
+async def load_embeddings_model(model_path: pathlib.Path, **kwargs):
+    global embeddings_container
+
+    # Break out if infinity isn't installed
+    if not has_infinity_emb:
+        logger.warning(
+            "Skipping embeddings because infinity-emb is not installed.\n"
+            "Please run the following command in your environment "
+            "to install extra packages:\n"
+            "pip install -U .[extras]"
+        )
+        return
+
+    embeddings_container = InfinityContainer(model_path)
+    await embeddings_container.load(**kwargs)
+
+
+async def unload_embeddings_model():
+    global embeddings_container
+
+    await embeddings_container.unload()
+    embeddings_container = None
+
+
 def get_config_default(key, fallback=None, is_draft=False):
     """Fetches a default value from model config if allowed by the user."""
@@ -126,3 +159,15 @@ async def check_model_container():
         ).error.message
 
         raise HTTPException(400, error_message)
+
+
+async def check_embeddings_container():
+    """FastAPI depends that checks if an embeddings model is loaded."""
+
+    if embeddings_container is None:
+        error_message = handle_request_error(
+            "No embeddings models are currently loaded.",
+            exc_info=False,
+        ).error.message
+
+        raise HTTPException(400, error_message)
diff --git a/common/signals.py b/common/signals.py
index 07d7564..d4b67bc 100644
--- a/common/signals.py
+++ b/common/signals.py
@@ -1,13 +1,26 @@
+import asyncio
 import signal
 import sys
 from loguru import logger
 from types import FrameType
 
+from common import model
+
 
 def signal_handler(*_):
     """Signal handler for main function. Run before uvicorn starts."""
 
     logger.warning("Shutdown signal called. Exiting gracefully.")
+
+    # Run async unloads for model
+    loop = asyncio.get_running_loop()
+    if model.container:
+        loop.create_task(model.container.unload())
+
+    if model.embeddings_container:
+        loop.create_task(model.embeddings_container.unload())
+
+    # Exit the program
     sys.exit(0)
 
 
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 2cad876..b702e52 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -23,7 +23,7 @@ from endpoints.OAI.utils.completion import (
     generate_completion,
     stream_generate_completion,
 )
-from endpoints.OAI.utils.embeddings import embeddings
+from endpoints.OAI.utils.embeddings import get_embeddings
 
 
 router = APIRouter()
@@ -134,7 +134,12 @@ async def chat_completion_request(
     "/v1/embeddings",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def handle_embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse:
-    response = await embeddings(data)
+async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
+    embeddings_task = asyncio.create_task(get_embeddings(data, request))
+    response = await run_with_request_disconnect(
+        request,
+        embeddings_task,
+        f"Embeddings request {request.state.id} cancelled by user.",
+    )
     return response
 
diff --git a/endpoints/OAI/utils/embeddings.py b/endpoints/OAI/utils/embeddings.py
index 1ce611c..5b43953 100644
--- a/endpoints/OAI/utils/embeddings.py
+++ b/endpoints/OAI/utils/embeddings.py
@@ -8,11 +8,11 @@ embeddings function declared async.
 """
 
 import base64
-import pathlib
+from fastapi import Request
 import numpy as np
 from loguru import logger
 
-from common import config
+from common import model
 from endpoints.OAI.types.embedding import (
     EmbeddingObject,
     EmbeddingsRequest,
@@ -20,84 +20,6 @@ from endpoints.OAI.types.embedding import (
     UsageInfo,
 )
 
-# Conditionally import infinity embeddings engine
-# Required so the logger doesn't take over tabby's logging handlers
-try:
-    from infinity_emb import EngineArgs, AsyncEmbeddingEngine
-
-    has_infinity_emb = True
-except ImportError:
-    has_infinity_emb = False
-
-
-embeddings_model = None
-
-
-def load_embedding_model(model_path: pathlib.Path, device: str):
-    if not has_infinity_emb:
-        logger.error(
-            "Skipping embeddings because infinity-emb is not installed.\n"
-            "Please run the following command in your environment "
-            "to install extra packages:\n"
-            "pip install -U .[extras]"
-        )
-        raise ModuleNotFoundError
-
-    global embeddings_model
-    try:
-        engine_args = EngineArgs(
-            model_name_or_path=str(model_path.resolve()),
-            engine="torch",
-            device="cpu",
-            bettertransformer=False,
-            model_warmup=False,
-        )
-        embeddings_model = AsyncEmbeddingEngine.from_args(engine_args)
-        logger.info(f"Trying to load embeddings model: {model_path.name} on {device}")
-    except Exception as e:
-        embeddings_model = None
-        raise e
-
-
-async def embeddings(data: EmbeddingsRequest) -> dict:
-    embeddings_config = config.embeddings_config()
-
-    # Use CPU by default
-    device = embeddings_config.get("embeddings_device", "cpu")
-    if device == "auto":
-        device = None
-
-    model_path = pathlib.Path(embeddings_config.get("embeddings_model_dir"))
-    model_path: pathlib.Path = model_path / embeddings_config.get(
-        "embeddings_model_name"
-    )
-    if not model_path:
-        logger.info("Embeddings model path not found")
-
-    load_embedding_model(model_path, device)
-
-    async with embeddings_model:
-        embeddings, usage = await embeddings_model.embed(data.input)
-
-    # OAI expects a return of base64 if the input is base64
-    embedding_data = [
-        EmbeddingObject(
-            embedding=float_list_to_base64(emb)
-            if data.encoding_format == "base64"
-            else emb.tolist(),
-            index=n,
-        )
-        for n, emb in enumerate(embeddings)
-    ]
-
-    response = EmbeddingsResponse(
-        data=embedding_data,
-        model=model_path.name,
-        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
-    )
-
-    return response
-
 
 def float_list_to_base64(float_array: np.ndarray) -> str:
     """
@@ -111,3 +33,32 @@ def float_list_to_base64(float_array: np.ndarray) -> str:
     # Turn raw base64 encoded bytes into ASCII
     ascii_string = encoded_bytes.decode("ascii")
     return ascii_string
+
+
+async def get_embeddings(data: EmbeddingsRequest, request: Request) -> dict:
+    model_path = model.embeddings_container.model_dir
+
+    logger.info(f"Received embeddings request {request.state.id}")
+    embedding_data = await model.embeddings_container.generate(data.input)
+
+    # OAI expects a return of base64 if the input is base64
+    embedding_object = [
+        EmbeddingObject(
+            embedding=float_list_to_base64(emb)
+            if data.encoding_format == "base64"
+            else emb.tolist(),
+            index=n,
+        )
+        for n, emb in enumerate(embedding_data.get("embeddings"))
+    ]
+
+    usage = embedding_data.get("usage")
+    response = EmbeddingsResponse(
+        data=embedding_object,
+        model=model_path.name,
+        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
+    )
+
+    logger.info(f"Finished embeddings request {request.state.id}")
+
+    return response
diff --git a/main.py b/main.py
index c62a381..56873c4 100644
--- a/main.py
+++ b/main.py
@@ -87,6 +87,18 @@ async def entrypoint_async():
         lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
         await model.container.load_loras(lora_dir.resolve(), **lora_config)
 
+    # If an initial embedding model name is specified, create a separate container
+    # and load the model
+    embedding_config = config.embeddings_config()
+    embedding_model_name = embedding_config.get("embeddings_model_name")
+    if embedding_model_name:
+        embedding_model_path = pathlib.Path(
+            unwrap(embedding_config.get("embeddings_model_dir"), "models")
+        )
+        embedding_model_path = embedding_model_path / embedding_model_name
+
+        await model.load_embeddings_model(embedding_model_path, **embedding_config)
+
     await start_api(host, port)
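
Reviewer note, not part of the patch: below is a minimal sketch of exercising the new backend on its own, assuming infinity-emb is installed (pip install -U .[extras]) and that "models/all-MiniLM-L6-v2" is a placeholder path to any local embeddings model infinity-emb can load. It only uses the InfinityContainer API introduced in backends/infinity/model.py.

import asyncio
import pathlib

from backends.infinity.model import InfinityContainer

# Placeholder path: substitute any local embeddings model directory
MODEL_DIR = pathlib.Path("models/all-MiniLM-L6-v2")


async def main():
    container = InfinityContainer(MODEL_DIR)

    # load() only reads the "device" kwarg and falls back to "cpu"
    await container.load(device="cpu")

    # generate() returns {"embeddings": [...], "usage": <token count>}
    result = await container.generate(["hello world", "embeddings smoke test"])
    print(f"usage: {result['usage']}, vectors: {len(result['embeddings'])}")

    await container.unload()


asyncio.run(main())

This is the same load/generate/unload sequence that load_embeddings_model and unload_embeddings_model in common/model.py wrap, so it doubles as a quick check that the container API matches what get_embeddings in the OAI endpoint now expects.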