Embeddings: Migrate and organize Infinity

Use Infinity as a separate backend and handle the model within the
common module. This separates the embeddings model from the endpoint,
which allows for model loading/unloading in core.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-07-30 11:00:23 -04:00
parent ac1afcc588
commit fbf1455db1
6 changed files with 165 additions and 83 deletions


@@ -0,0 +1,56 @@
import gc
import pathlib
import torch

from typing import List, Optional

from common.utils import unwrap

# Conditionally import infinity to sidestep its logger
# TODO: Make this prettier
try:
    from infinity_emb import EngineArgs, AsyncEmbeddingEngine

    has_infinity_emb = True
except ImportError:
    has_infinity_emb = False


class InfinityContainer:
    model_dir: pathlib.Path

    # Conditionally set the type hint based on importability
    # TODO: Clean this up
    if has_infinity_emb:
        engine: Optional[AsyncEmbeddingEngine] = None
    else:
        engine = None

    def __init__(self, model_directory: pathlib.Path):
        self.model_dir = model_directory

    async def load(self, **kwargs):
        # Use CPU by default
        device = unwrap(kwargs.get("device"), "cpu")

        engine_args = EngineArgs(
            model_name_or_path=str(self.model_dir),
            engine="torch",
            device=device,
            bettertransformer=False,
            model_warmup=False,
        )

        self.engine = AsyncEmbeddingEngine.from_args(engine_args)
        await self.engine.astart()

    async def unload(self):
        await self.engine.astop()
        self.engine = None

        gc.collect()
        torch.cuda.empty_cache()

    async def generate(self, sentence_input: List[str]):
        result_embeddings, usage = await self.engine.embed(sentence_input)

        return {"embeddings": result_embeddings, "usage": usage}


@@ -20,6 +20,15 @@ if not do_export_openapi:

    # Global model container
    container: Optional[ExllamaV2Container] = None
    embeddings_container = None

    # Type hint the infinity emb container if it exists
    from backends.infinity.model import has_infinity_emb

    if has_infinity_emb:
        from backends.infinity.model import InfinityContainer

        embeddings_container: Optional[InfinityContainer] = None


def load_progress(module, modules):
@@ -100,6 +109,30 @@ async def unload_loras():
    await container.unload(loras_only=True)


async def load_embeddings_model(model_path: pathlib.Path, **kwargs):
    global embeddings_container

    # Break out if infinity isn't installed
    if not has_infinity_emb:
        logger.warning(
            "Skipping embeddings because infinity-emb is not installed.\n"
            "Please run the following command in your environment "
            "to install extra packages:\n"
            "pip install -U .[extras]"
        )
        return

    embeddings_container = InfinityContainer(model_path)
    await embeddings_container.load(**kwargs)


async def unload_embeddings_model():
    global embeddings_container

    await embeddings_container.unload()
    embeddings_container = None


def get_config_default(key, fallback=None, is_draft=False):
    """Fetches a default value from model config if allowed by the user."""
@@ -126,3 +159,15 @@ async def check_model_container():
        ).error.message
        raise HTTPException(400, error_message)


async def check_embeddings_container():
    """FastAPI depends that checks if an embeddings model is loaded."""

    if embeddings_container is None:
        error_message = handle_request_error(
            "No embeddings models are currently loaded.",
            exc_info=False,
        ).error.message
        raise HTTPException(400, error_message)
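A rough sketch of how core code can now drive these helpers without touching the endpoint layer; the path and device arguments are illustrative, not taken from the diff:

import pathlib

from common import model


async def swap_embeddings_model(new_path: pathlib.Path):
    # Unload any currently loaded embeddings model before loading a new one
    if model.embeddings_container:
        await model.unload_embeddings_model()

    # Keyword arguments are forwarded to InfinityContainer.load()
    await model.load_embeddings_model(new_path, device="cpu")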


@@ -1,13 +1,26 @@
import asyncio
import signal
import sys

from loguru import logger
from types import FrameType

from common import model


def signal_handler(*_):
    """Signal handler for main function. Run before uvicorn starts."""

    logger.warning("Shutdown signal called. Exiting gracefully.")

    # Run async unloads for model
    loop = asyncio.get_running_loop()
    if model.container:
        loop.create_task(model.container.unload())

    if model.embeddings_container:
        loop.create_task(model.embeddings_container.unload())

    # Exit the program
    sys.exit(0)
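Registration of the handler is not part of this hunk; presumably it is wired up with the standard library before uvicorn starts, along these lines (the module path is assumed for illustration):

import signal

from common.signals import signal_handler  # assumed location of this file

# Unload models gracefully on Ctrl+C or a termination request
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)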


@@ -23,7 +23,7 @@ from endpoints.OAI.utils.completion import (
    generate_completion,
    stream_generate_completion,
)
from endpoints.OAI.utils.embeddings import embeddings
from endpoints.OAI.utils.embeddings import get_embeddings

router = APIRouter()

@@ -134,7 +134,12 @@
    "/v1/embeddings",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def handle_embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse:
    response = await embeddings(data)
async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
    embeddings_task = asyncio.create_task(get_embeddings(data, request))

    response = await run_with_request_disconnect(
        request,
        embeddings_task,
        f"Embeddings request {request.state.id} cancelled by user.",
    )

    return response
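From the client side the route stays OpenAI-compatible. A sketch of calling it, where the base URL, port, and API key header are assumptions about a local deployment rather than details from this commit:

import requests

response = requests.post(
    "http://127.0.0.1:5000/v1/embeddings",  # assumed local tabbyAPI address
    headers={"x-api-key": "<your key>"},  # assumed auth header name
    json={"input": ["hello world"], "encoding_format": "float"},
)

# Response follows EmbeddingsResponse: a list of EmbeddingObject entries
print(response.json()["data"][0]["embedding"][:8])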


@@ -8,11 +8,11 @@ embeddings function declared async.
"""

import base64
import pathlib
from fastapi import Request
import numpy as np
from loguru import logger

from common import config
from common import model
from endpoints.OAI.types.embedding import (
    EmbeddingObject,
    EmbeddingsRequest,
@@ -20,84 +20,6 @@ from endpoints.OAI.types.embedding import (
    UsageInfo,
)
# Conditionally import infinity embeddings engine
# Required so the logger doesn't take over tabby's logging handlers
try:
    from infinity_emb import EngineArgs, AsyncEmbeddingEngine

    has_infinity_emb = True
except ImportError:
    has_infinity_emb = False

embeddings_model = None


def load_embedding_model(model_path: pathlib.Path, device: str):
    if not has_infinity_emb:
        logger.error(
            "Skipping embeddings because infinity-emb is not installed.\n"
            "Please run the following command in your environment "
            "to install extra packages:\n"
            "pip install -U .[extras]"
        )
        raise ModuleNotFoundError

    global embeddings_model

    try:
        engine_args = EngineArgs(
            model_name_or_path=str(model_path.resolve()),
            engine="torch",
            device="cpu",
            bettertransformer=False,
            model_warmup=False,
        )

        embeddings_model = AsyncEmbeddingEngine.from_args(engine_args)
        logger.info(f"Trying to load embeddings model: {model_path.name} on {device}")
    except Exception as e:
        embeddings_model = None
        raise e


async def embeddings(data: EmbeddingsRequest) -> dict:
    embeddings_config = config.embeddings_config()

    # Use CPU by default
    device = embeddings_config.get("embeddings_device", "cpu")
    if device == "auto":
        device = None

    model_path = pathlib.Path(embeddings_config.get("embeddings_model_dir"))
    model_path: pathlib.Path = model_path / embeddings_config.get(
        "embeddings_model_name"
    )
    if not model_path:
        logger.info("Embeddings model path not found")

    load_embedding_model(model_path, device)

    async with embeddings_model:
        embeddings, usage = await embeddings_model.embed(data.input)

    # OAI expects a return of base64 if the input is base64
    embedding_data = [
        EmbeddingObject(
            embedding=float_list_to_base64(emb)
            if data.encoding_format == "base64"
            else emb.tolist(),
            index=n,
        )
        for n, emb in enumerate(embeddings)
    ]

    response = EmbeddingsResponse(
        data=embedding_data,
        model=model_path.name,
        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
    )

    return response
def float_list_to_base64(float_array: np.ndarray) -> str:
    """
@@ -111,3 +33,32 @@ def float_list_to_base64(float_array: np.ndarray) -> str:
    # Turn raw base64 encoded bytes into ASCII
    ascii_string = encoded_bytes.decode("ascii")

    return ascii_string
async def get_embeddings(data: EmbeddingsRequest, request: Request) -> dict:
    model_path = model.embeddings_container.model_dir

    logger.info(f"Received embeddings request {request.state.id}")

    embedding_data = await model.embeddings_container.generate(data.input)

    # OAI expects a return of base64 if the input is base64
    embedding_object = [
        EmbeddingObject(
            embedding=float_list_to_base64(emb)
            if data.encoding_format == "base64"
            else emb.tolist(),
            index=n,
        )
        for n, emb in enumerate(embedding_data.get("embeddings"))
    ]

    usage = embedding_data.get("usage")

    response = EmbeddingsResponse(
        data=embedding_object,
        model=model_path.name,
        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
    )

    logger.info(f"Finished embeddings request {request.state.id}")

    return response
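When encoding_format is "base64", the embedding comes back as base64-encoded raw float bytes instead of a JSON list. A client-side decoding sketch, assuming float_list_to_base64 packs float32 values (the usual convention for this response format):

import base64

import numpy as np


def decode_embedding(b64_string: str) -> np.ndarray:
    # Reverse of float_list_to_base64: base64 -> raw bytes -> float32 array
    raw_bytes = base64.b64decode(b64_string)
    return np.frombuffer(raw_bytes, dtype=np.float32)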

main.py

@@ -87,6 +87,18 @@ async def entrypoint_async():
        lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
        await model.container.load_loras(lora_dir.resolve(), **lora_config)

    # If an initial embedding model name is specified, create a separate container
    # and load the model
    embedding_config = config.embeddings_config()
    embedding_model_name = embedding_config.get("embeddings_model_name")

    if embedding_model_name:
        embedding_model_path = pathlib.Path(
            unwrap(embedding_config.get("embeddings_model_dir"), "models")
        )

        embedding_model_path = embedding_model_path / embedding_model_name

        await model.load_embeddings_model(embedding_model_path, **embedding_config)

    await start_api(host, port)
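For this block to do anything, config.embeddings_config() has to return the embeddings section of the user config; the keys below are the ones read here, with placeholder values:

# Shape of the dict this code expects from config.embeddings_config()
embedding_config = {
    "embeddings_model_dir": "models",  # unwrapped with "models" as the fallback
    "embeddings_model_name": "example-embedding-model",  # loading is skipped when unset
}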