Embeddings: Use response classes instead of dicts
Follows the existing code style. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
3f21d9ef96
commit
ac1afcc588
1 changed file with 28 additions and 32 deletions
|
|
@ -7,37 +7,41 @@ typing/pydantic classes moved into this file,
|
|||
embeddings function declared async.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import base64
|
||||
import pathlib
|
||||
from loguru import logger
|
||||
import numpy as np
|
||||
from transformers import AutoModel
|
||||
from loguru import logger
|
||||
|
||||
from common import config
|
||||
from common.utils import unwrap
|
||||
from endpoints.OAI.types.embedding import (
|
||||
EmbeddingObject,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
UsageInfo,
|
||||
)
|
||||
|
||||
# Conditionally import infinity embeddings engine
|
||||
# Required so the logger doesn't take over tabby's logging handlers
|
||||
try:
|
||||
from infinity_emb import EngineArgs, AsyncEmbeddingEngine
|
||||
|
||||
has_infinity_emb = True
|
||||
except ImportError:
|
||||
has_infinity_emb = False
|
||||
|
||||
|
||||
embeddings_model = None
|
||||
|
||||
|
||||
def load_embedding_model(model_path: pathlib.Path, device: str):
|
||||
try:
|
||||
from infinity_emb import EngineArgs, AsyncEmbeddingEngine
|
||||
except ModuleNotFoundError:
|
||||
if not has_infinity_emb:
|
||||
logger.error(
|
||||
"Skipping embeddings because infinity-emb is not installed.\n"
|
||||
"Please run the following command in your environment "
|
||||
"to install extra packages:\n"
|
||||
"pip install -U .[extras]"
|
||||
)
|
||||
raise ModuleNotFoundError from None
|
||||
raise ModuleNotFoundError
|
||||
|
||||
global embeddings_model
|
||||
try:
|
||||
|
|
@ -76,30 +80,22 @@ async def embeddings(data: EmbeddingsRequest) -> dict:
|
|||
embeddings, usage = await embeddings_model.embed(data.input)
|
||||
|
||||
# OAI expects base64-encoded embeddings when the requested encoding_format is "base64"
|
||||
if data.encoding_format == "base64":
|
||||
embedding_data = [
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": float_list_to_base64(emb),
|
||||
"index": n,
|
||||
}
|
||||
for n, emb in enumerate(embeddings)
|
||||
]
|
||||
else:
|
||||
embedding_data = [
|
||||
{"object": "embedding", "embedding": emb.tolist(), "index": n}
|
||||
for n, emb in enumerate(embeddings)
|
||||
]
|
||||
embedding_data = [
|
||||
EmbeddingObject(
|
||||
embedding=float_list_to_base64(emb)
|
||||
if data.encoding_format == "base64"
|
||||
else emb.tolist(),
|
||||
index=n,
|
||||
)
|
||||
for n, emb in enumerate(embeddings)
|
||||
]
|
||||
|
||||
response = EmbeddingsResponse(
|
||||
data=embedding_data,
|
||||
model=model_path.name,
|
||||
usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
|
||||
)
|
||||
|
||||
response = {
|
||||
"object": "list",
|
||||
"data": embedding_data,
|
||||
"model": model_path.name,
|
||||
"usage": {
|
||||
"prompt_tokens": usage,
|
||||
"total_tokens": usage,
|
||||
},
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue