Embeddings: Migrate and organize Infinity
Use Infinity as a separate backend and handle the model within the common module. This separates the embeddings model from the endpoint, allowing the model to be loaded and unloaded in core.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent ac1afcc588
commit fbf1455db1
6 changed files with 165 additions and 83 deletions
backends/infinity/model.py (new file, +56)
@@ -0,0 +1,56 @@
import gc
import pathlib
import torch
from typing import List, Optional

from common.utils import unwrap

# Conditionally import infinity to sidestep its logger
# TODO: Make this prettier
try:
    from infinity_emb import EngineArgs, AsyncEmbeddingEngine

    has_infinity_emb = True
except ImportError:
    has_infinity_emb = False


class InfinityContainer:
    model_dir: pathlib.Path

    # Conditionally set the type hint based on importability
    # TODO: Clean this up
    if has_infinity_emb:
        engine: Optional[AsyncEmbeddingEngine] = None
    else:
        engine = None

    def __init__(self, model_directory: pathlib.Path):
        self.model_dir = model_directory

    async def load(self, **kwargs):
        # Use cpu by default
        device = unwrap(kwargs.get("device"), "cpu")

        engine_args = EngineArgs(
            model_name_or_path=str(self.model_dir),
            engine="torch",
            device=device,
            bettertransformer=False,
            model_warmup=False,
        )

        self.engine = AsyncEmbeddingEngine.from_args(engine_args)
        await self.engine.astart()

    async def unload(self):
        await self.engine.astop()
        self.engine = None

        gc.collect()
        torch.cuda.empty_cache()

    async def generate(self, sentence_input: List[str]):
        result_embeddings, usage = await self.engine.embed(sentence_input)

        return {"embeddings": result_embeddings, "usage": usage}
@@ -20,6 +20,15 @@ if not do_export_openapi:

# Global model container
container: Optional[ExllamaV2Container] = None
embeddings_container = None

# Type hint the infinity emb container if it exists
from backends.infinity.model import has_infinity_emb

if has_infinity_emb:
    from backends.infinity.model import InfinityContainer

    embeddings_container: Optional[InfinityContainer] = None


def load_progress(module, modules):
@@ -100,6 +109,30 @@ async def unload_loras():
    await container.unload(loras_only=True)


async def load_embeddings_model(model_path: pathlib.Path, **kwargs):
    global embeddings_container

    # Break out if infinity isn't installed
    if not has_infinity_emb:
        logger.warning(
            "Skipping embeddings because infinity-emb is not installed.\n"
            "Please run the following command in your environment "
            "to install extra packages:\n"
            "pip install -U .[extras]"
        )
        return

    embeddings_container = InfinityContainer(model_path)
    await embeddings_container.load(**kwargs)


async def unload_embeddings_model():
    global embeddings_container

    await embeddings_container.unload()
    embeddings_container = None


def get_config_default(key, fallback=None, is_draft=False):
    """Fetches a default value from model config if allowed by the user."""
@@ -126,3 +159,15 @@ async def check_model_container():
        ).error.message

        raise HTTPException(400, error_message)


async def check_embeddings_container():
    """FastAPI depends that checks if an embeddings model is loaded."""

    if embeddings_container is None:
        error_message = handle_request_error(
            "No embeddings models are currently loaded.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)
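check_embeddings_container is not attached to the embeddings route in this commit (the route below still depends on check_model_container). A hedged sketch of how such a FastAPI dependency is typically wired; the import path and route body are assumptions, not part of this diff:

from fastapi import APIRouter, Depends

# Assumed import path for the dependency defined above
from common.model import check_embeddings_container

router = APIRouter()


# Hypothetical route: FastAPI resolves the dependency before the handler runs,
# so a missing embeddings model surfaces as the HTTP 400 raised above.
@router.post("/v1/embeddings", dependencies=[Depends(check_embeddings_container)])
async def embeddings_stub():
    ...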
@@ -1,13 +1,26 @@
import asyncio
import signal
import sys
from loguru import logger
from types import FrameType

from common import model


def signal_handler(*_):
    """Signal handler for main function. Run before uvicorn starts."""

    logger.warning("Shutdown signal called. Exiting gracefully.")

    # Run async unloads for model
    loop = asyncio.get_running_loop()
    if model.container:
        loop.create_task(model.container.unload())

    if model.embeddings_container:
        loop.create_task(model.embeddings_container.unload())

    # Exit the program
    sys.exit(0)
@@ -23,7 +23,7 @@ from endpoints.OAI.utils.completion import (
    generate_completion,
    stream_generate_completion,
)
-from endpoints.OAI.utils.embeddings import embeddings
+from endpoints.OAI.utils.embeddings import get_embeddings


router = APIRouter()
@@ -134,7 +134,12 @@ async def chat_completion_request(
    "/v1/embeddings",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
-async def handle_embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse:
-    response = await embeddings(data)
+async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
+    embeddings_task = asyncio.create_task(get_embeddings(data, request))
+    response = await run_with_request_disconnect(
+        request,
+        embeddings_task,
+        f"Embeddings request {request.state.id} cancelled by user.",
+    )

    return response
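For context, a hedged example of exercising the rewritten endpoint once an embeddings model is loaded; the host, port, and auth header are placeholders, and the request body mirrors the EmbeddingsRequest fields used in this diff:

import requests

# Placeholder URL and API key header; adjust for the actual deployment
resp = requests.post(
    "http://127.0.0.1:5000/v1/embeddings",
    headers={"x-api-key": "example-key"},
    json={
        "input": ["tabby is an OAI-compatible server"],
        # "encoding_format": "base64",  # optional; handled by get_embeddings
    },
)
resp.raise_for_status()
print(resp.json()["data"][0]["index"])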
@@ -8,11 +8,11 @@ embeddings function declared async.
"""

import base64
-import pathlib
+from fastapi import Request
import numpy as np
from loguru import logger

-from common import config
+from common import model
from endpoints.OAI.types.embedding import (
    EmbeddingObject,
    EmbeddingsRequest,
@@ -20,84 +20,6 @@ from endpoints.OAI.types.embedding import (
    UsageInfo,
)

# Conditionally import infinity embeddings engine
# Required so the logger doesn't take over tabby's logging handlers
try:
    from infinity_emb import EngineArgs, AsyncEmbeddingEngine

    has_infinity_emb = True
except ImportError:
    has_infinity_emb = False


embeddings_model = None


def load_embedding_model(model_path: pathlib.Path, device: str):
    if not has_infinity_emb:
        logger.error(
            "Skipping embeddings because infinity-emb is not installed.\n"
            "Please run the following command in your environment "
            "to install extra packages:\n"
            "pip install -U .[extras]"
        )
        raise ModuleNotFoundError

    global embeddings_model
    try:
        engine_args = EngineArgs(
            model_name_or_path=str(model_path.resolve()),
            engine="torch",
            device="cpu",
            bettertransformer=False,
            model_warmup=False,
        )
        embeddings_model = AsyncEmbeddingEngine.from_args(engine_args)
        logger.info(f"Trying to load embeddings model: {model_path.name} on {device}")
    except Exception as e:
        embeddings_model = None
        raise e


async def embeddings(data: EmbeddingsRequest) -> dict:
    embeddings_config = config.embeddings_config()

    # Use CPU by default
    device = embeddings_config.get("embeddings_device", "cpu")
    if device == "auto":
        device = None

    model_path = pathlib.Path(embeddings_config.get("embeddings_model_dir"))
    model_path: pathlib.Path = model_path / embeddings_config.get(
        "embeddings_model_name"
    )
    if not model_path:
        logger.info("Embeddings model path not found")

    load_embedding_model(model_path, device)

    async with embeddings_model:
        embeddings, usage = await embeddings_model.embed(data.input)

    # OAI expects a return of base64 if the input is base64
    embedding_data = [
        EmbeddingObject(
            embedding=float_list_to_base64(emb)
            if data.encoding_format == "base64"
            else emb.tolist(),
            index=n,
        )
        for n, emb in enumerate(embeddings)
    ]

    response = EmbeddingsResponse(
        data=embedding_data,
        model=model_path.name,
        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
    )

    return response


def float_list_to_base64(float_array: np.ndarray) -> str:
    """
@@ -111,3 +33,32 @@ def float_list_to_base64(float_array: np.ndarray) -> str:
    # Turn raw base64 encoded bytes into ASCII
    ascii_string = encoded_bytes.decode("ascii")
    return ascii_string


async def get_embeddings(data: EmbeddingsRequest, request: Request) -> dict:
    model_path = model.embeddings_container.model_dir

    logger.info(f"Received embeddings request {request.state.id}")
    embedding_data = await model.embeddings_container.generate(data.input)

    # OAI expects a return of base64 if the input is base64
    embedding_object = [
        EmbeddingObject(
            embedding=float_list_to_base64(emb)
            if data.encoding_format == "base64"
            else emb.tolist(),
            index=n,
        )
        for n, emb in enumerate(embedding_data.get("embeddings"))
    ]

    usage = embedding_data.get("usage")
    response = EmbeddingsResponse(
        data=embedding_object,
        model=model_path.name,
        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
    )

    logger.info(f"Finished embeddings request {request.state.id}")

    return response
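When encoding_format is "base64", each embedding comes back as a base64 string rather than a float list. A client-side decoding sketch, assuming the bytes are packed as float32 (the packing lives in float_list_to_base64, whose body is mostly outside this diff):

import base64

import numpy as np


def base64_to_floats(encoded: str) -> np.ndarray:
    # Assumes raw float32 bytes were base64-encoded by the server;
    # adjust the dtype if float_list_to_base64 packs differently.
    raw = base64.b64decode(encoded)
    return np.frombuffer(raw, dtype=np.float32)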
main.py (+12)
@@ -87,6 +87,18 @@ async def entrypoint_async():
        lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
        await model.container.load_loras(lora_dir.resolve(), **lora_config)

    # If an initial embedding model name is specified, create a separate container
    # and load the model
    embedding_config = config.embeddings_config()
    embedding_model_name = embedding_config.get("embeddings_model_name")
    if embedding_model_name:
        embedding_model_path = pathlib.Path(
            unwrap(embedding_config.get("embeddings_model_dir"), "models")
        )
        embedding_model_path = embedding_model_path / embedding_model_name

        await model.load_embeddings_model(embedding_model_path, **embedding_config)

    await start_api(host, port)
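The embeddings_* keys read above come from config.embeddings_config(). A hedged illustration of the mapping it is assumed to return; only the key names are taken from this diff, the values are placeholders:

# Illustrative only: assumed shape of config.embeddings_config()
embeddings_config = {
    "embeddings_model_dir": "models",               # base folder for embedding models
    "embeddings_model_name": "my-embedding-model",  # placeholder model folder name
    "embeddings_device": "cpu",                     # device hint ("cpu" or "auto")
}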