diff --git a/endpoints/OAI/utils/embeddings.py b/endpoints/OAI/utils/embeddings.py index cf5b799..1ce611c 100644 --- a/endpoints/OAI/utils/embeddings.py +++ b/endpoints/OAI/utils/embeddings.py @@ -7,37 +7,41 @@ typing/pydantic classes moved into this file, embeddings function declared async. """ -import asyncio -import os import base64 import pathlib -from loguru import logger import numpy as np -from transformers import AutoModel +from loguru import logger from common import config -from common.utils import unwrap from endpoints.OAI.types.embedding import ( EmbeddingObject, EmbeddingsRequest, EmbeddingsResponse, + UsageInfo, ) +# Conditionally import infinity embeddings engine +# Required so the logger doesn't take over tabby's logging handlers +try: + from infinity_emb import EngineArgs, AsyncEmbeddingEngine + + has_infinity_emb = True +except ImportError: + has_infinity_emb = False + embeddings_model = None def load_embedding_model(model_path: pathlib.Path, device: str): - try: - from infinity_emb import EngineArgs, AsyncEmbeddingEngine - except ModuleNotFoundError: + if not has_infinity_emb: logger.error( "Skipping embeddings because infinity-emb is not installed.\n" "Please run the following command in your environment " "to install extra packages:\n" "pip install -U .[extras]" ) - raise ModuleNotFoundError from None + raise ModuleNotFoundError global embeddings_model try: @@ -76,30 +80,22 @@ async def embeddings(data: EmbeddingsRequest) -> dict: embeddings, usage = await embeddings_model.embed(data.input) # OAI expects a return of base64 if the input is base64 - if data.encoding_format == "base64": - embedding_data = [ - { - "object": "embedding", - "embedding": float_list_to_base64(emb), - "index": n, - } - for n, emb in enumerate(embeddings) - ] - else: - embedding_data = [ - {"object": "embedding", "embedding": emb.tolist(), "index": n} - for n, emb in enumerate(embeddings) - ] + embedding_data = [ + EmbeddingObject( + embedding=float_list_to_base64(emb) + if data.encoding_format == "base64" + else emb.tolist(), + index=n, + ) + for n, emb in enumerate(embeddings) + ] + + response = EmbeddingsResponse( + data=embedding_data, + model=model_path.name, + usage=UsageInfo(prompt_tokens=usage, total_tokens=usage), + ) - response = { - "object": "list", - "data": embedding_data, - "model": model_path.name, - "usage": { - "prompt_tokens": usage, - "total_tokens": usage, - }, - } return response