Embeddings: Use response classes instead of dicts

Follows the existing code style.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-07-29 14:15:40 -04:00
parent 3f21d9ef96
commit ac1afcc588

View file

@@ -7,37 +7,41 @@ typing/pydantic classes moved into this file,
embeddings function declared async.
"""
import asyncio
import os
import base64
import pathlib
from loguru import logger
import numpy as np
from transformers import AutoModel
from loguru import logger
from common import config
from common.utils import unwrap
from endpoints.OAI.types.embedding import (
EmbeddingObject,
EmbeddingsRequest,
EmbeddingsResponse,
UsageInfo,
)
# Conditionally import infinity embeddings engine
# Required so the logger doesn't take over tabby's logging handlers
try:
from infinity_emb import EngineArgs, AsyncEmbeddingEngine
has_infinity_emb = True
except ImportError:
has_infinity_emb = False
embeddings_model = None
def load_embedding_model(model_path: pathlib.Path, device: str):
try:
from infinity_emb import EngineArgs, AsyncEmbeddingEngine
except ModuleNotFoundError:
if not has_infinity_emb:
logger.error(
"Skipping embeddings because infinity-emb is not installed.\n"
"Please run the following command in your environment "
"to install extra packages:\n"
"pip install -U .[extras]"
)
raise ModuleNotFoundError from None
raise ModuleNotFoundError
global embeddings_model
try:
@@ -76,30 +80,22 @@ async def embeddings(data: EmbeddingsRequest) -> dict:
embeddings, usage = await embeddings_model.embed(data.input)
# OAI expects a return of base64 if the input is base64
if data.encoding_format == "base64":
embedding_data = [
{
"object": "embedding",
"embedding": float_list_to_base64(emb),
"index": n,
}
for n, emb in enumerate(embeddings)
]
else:
embedding_data = [
{"object": "embedding", "embedding": emb.tolist(), "index": n}
for n, emb in enumerate(embeddings)
]
embedding_data = [
EmbeddingObject(
embedding=float_list_to_base64(emb)
if data.encoding_format == "base64"
else emb.tolist(),
index=n,
)
for n, emb in enumerate(embeddings)
]
response = EmbeddingsResponse(
data=embedding_data,
model=model_path.name,
usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
)
response = {
"object": "list",
"data": embedding_data,
"model": model_path.name,
"usage": {
"prompt_tokens": usage,
"total_tokens": usage,
},
}
return response