Embeddings: Migrate and organize Infinity

Use Infinity as a separate backend and handle the model within the
common module. This separates the embeddings model from the endpoint,
which allows for model loading/unloading in core.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-07-30 11:00:23 -04:00
parent ac1afcc588
commit fbf1455db1
6 changed files with 165 additions and 83 deletions

View file

@ -23,7 +23,7 @@ from endpoints.OAI.utils.completion import (
generate_completion,
stream_generate_completion,
)
from endpoints.OAI.utils.embeddings import embeddings
from endpoints.OAI.utils.embeddings import get_embeddings
router = APIRouter()
@ -134,7 +134,12 @@ async def chat_completion_request(
"/v1/embeddings",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def handle_embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse:
response = await embeddings(data)
async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
    """
    Serve an embeddings request.

    The work is scheduled as its own asyncio task and handed to
    run_with_request_disconnect together with the request and the message
    to use if the request is cancelled; whatever that helper returns is
    returned to the client.
    """

    # Schedule generation first so the helper receives an already-running task
    task = asyncio.create_task(get_embeddings(data, request))
    cancel_message = f"Embeddings request {request.state.id} cancelled by user."

    return await run_with_request_disconnect(request, task, cancel_message)

View file

@ -8,11 +8,11 @@ embeddings function declared async.
"""
import base64
import pathlib
from fastapi import Request
import numpy as np
from loguru import logger
from common import config
from common import model
from endpoints.OAI.types.embedding import (
EmbeddingObject,
EmbeddingsRequest,
@ -20,84 +20,6 @@ from endpoints.OAI.types.embedding import (
UsageInfo,
)
# Conditionally import infinity embeddings engine
# Required so the logger doesn't take over tabby's logging handlers
try:
from infinity_emb import EngineArgs, AsyncEmbeddingEngine
has_infinity_emb = True
except ImportError:
has_infinity_emb = False
embeddings_model = None
def load_embedding_model(model_path: pathlib.Path, device: str):
    """
    Build the Infinity embeddings engine and publish it through the
    module-level ``embeddings_model`` global.

    Args:
        model_path: Directory of the embeddings model to load.
        device: Device name handed to Infinity. ``None`` (which the caller
            uses for "auto") is forwarded as "auto".

    Raises:
        ModuleNotFoundError: If infinity-emb could not be imported.
        Exception: Any error raised while constructing the engine is
            re-raised after ``embeddings_model`` has been reset to None.
    """

    global embeddings_model

    if not has_infinity_emb:
        logger.error(
            "Skipping embeddings because infinity-emb is not installed.\n"
            "Please run the following command in your environment "
            "to install extra packages:\n"
            "pip install -U .[extras]"
        )
        raise ModuleNotFoundError("infinity-emb is not installed")

    # Announce the attempt before doing the work so a failure is attributable
    logger.info(f"Trying to load embeddings model: {model_path.name} on {device}")

    try:
        engine_args = EngineArgs(
            model_name_or_path=str(model_path.resolve()),
            engine="torch",
            # Honor the requested device instead of always forcing CPU.
            # NOTE(review): assumes Infinity accepts "auto" as a device
            # name - confirm against the installed infinity-emb version.
            device="auto" if device is None else device,
            bettertransformer=False,
            model_warmup=False,
        )
        embeddings_model = AsyncEmbeddingEngine.from_args(engine_args)
    except Exception:
        # Never leave a half-initialized engine behind
        embeddings_model = None
        raise
async def embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse:
    """
    Load the configured embeddings model and embed the request input.

    Args:
        data: Request whose ``input`` is embedded and whose
            ``encoding_format`` selects base64 or float-list output.

    Returns:
        An EmbeddingsResponse holding one EmbeddingObject per embedding,
        the model directory name, and the usage value reported by the
        engine (used for both prompt and total tokens).
    """

    embeddings_config = config.embeddings_config()

    # Use CPU by default
    device = embeddings_config.get("embeddings_device", "cpu")
    if device == "auto":
        device = None

    model_dir = pathlib.Path(embeddings_config.get("embeddings_model_dir"))
    model_path = model_dir / embeddings_config.get("embeddings_model_name")

    # A Path object is always truthy, so check the filesystem instead.
    # Loading is still attempted so the engine reports the precise failure.
    if not model_path.exists():
        logger.info("Embeddings model path not found")

    load_embedding_model(model_path, device)

    async with embeddings_model:
        vectors, usage = await embeddings_model.embed(data.input)

        # OAI expects a return of base64 if the input is base64
        use_base64 = data.encoding_format == "base64"
        embedding_data = [
            EmbeddingObject(
                embedding=float_list_to_base64(emb) if use_base64 else emb.tolist(),
                index=n,
            )
            for n, emb in enumerate(vectors)
        ]

        response = EmbeddingsResponse(
            data=embedding_data,
            model=model_path.name,
            usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
        )

        return response
def float_list_to_base64(float_array: np.ndarray) -> str:
"""
@ -111,3 +33,32 @@ def float_list_to_base64(float_array: np.ndarray) -> str:
# Turn raw base64 encoded bytes into ASCII
ascii_string = encoded_bytes.decode("ascii")
return ascii_string
async def get_embeddings(
    data: EmbeddingsRequest, request: Request
) -> EmbeddingsResponse:
    """
    Generate embeddings with the loaded embeddings container.

    Args:
        data: Request whose ``input`` is embedded and whose
            ``encoding_format`` selects base64 or float-list output.
        request: Incoming request; only ``request.state.id`` is read, for
            logging.

    Returns:
        An EmbeddingsResponse holding one EmbeddingObject per embedding,
        the model directory name, and the usage value returned by the
        container (used for both prompt and total tokens).
    """

    model_path = model.embeddings_container.model_dir

    logger.info(f"Received embeddings request {request.state.id}")
    embedding_data = await model.embeddings_container.generate(data.input)

    # OAI expects a return of base64 if the input is base64
    use_base64 = data.encoding_format == "base64"
    embedding_object = [
        EmbeddingObject(
            embedding=float_list_to_base64(emb) if use_base64 else emb.tolist(),
            index=n,
        )
        for n, emb in enumerate(embedding_data.get("embeddings"))
    ]

    usage = embedding_data.get("usage")
    response = EmbeddingsResponse(
        data=embedding_object,
        model=model_path.name,
        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
    )

    logger.info(f"Finished embeddings request {request.state.id}")

    return response