Embeddings: Migrate and organize Infinity
Use Infinity as a separate backend and handle the model within the common module. This separates the embeddings model from the endpoint, which allows for model loading/unloading in core.

Signed-off-by: kingbri <bdashore3@proton.me>
parent ac1afcc588
commit fbf1455db1
6 changed files with 165 additions and 83 deletions
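Only the endpoint-facing half of the migration appears in the hunks below; the container itself lives in the other changed files. As a rough sketch of the interface the router code now relies on (the names and return shape are inferred from the calls in this diff, not taken from the actual common/model implementation):

    # Hypothetical caller's view of the migrated embeddings container.
    # Inferred from model.embeddings_container.model_dir and
    # model.embeddings_container.generate() as used in this commit.
    from common import model


    async def embed_example():
        container = model.embeddings_container
        result = await container.generate(["hello world"])
        vectors = result.get("embeddings")  # one vector per input string
        usage = result.get("usage")         # token count used for UsageInfo
        print(container.model_dir.name, len(vectors), usage)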
@@ -23,7 +23,7 @@ from endpoints.OAI.utils.completion import (
     generate_completion,
     stream_generate_completion,
 )
-from endpoints.OAI.utils.embeddings import embeddings
+from endpoints.OAI.utils.embeddings import get_embeddings
 
 
 router = APIRouter()
@@ -134,7 +134,12 @@ async def chat_completion_request(
     "/v1/embeddings",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def handle_embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse:
-    response = await embeddings(data)
+async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
+    embeddings_task = asyncio.create_task(get_embeddings(data, request))
+    response = await run_with_request_disconnect(
+        request,
+        embeddings_task,
+        f"Embeddings request {request.state.id} cancelled by user.",
+    )
 
     return response
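The handler now runs the embeddings call as a task so it can be torn down when the caller disconnects. run_with_request_disconnect is an existing helper elsewhere in the codebase; a minimal sketch of the disconnect-polling pattern such a helper typically implements (an assumption for illustration, not the repo's actual implementation):

    # Illustrative disconnect-aware awaiter: poll Request.is_disconnected()
    # while the task runs and cancel it if the client goes away.
    import asyncio

    from fastapi import HTTPException, Request
    from loguru import logger


    async def run_with_request_disconnect(
        request: Request, call_task: asyncio.Task, disconnect_message: str
    ):
        while not call_task.done():
            if await request.is_disconnected():
                call_task.cancel()
                logger.error(disconnect_message)
                raise HTTPException(422, disconnect_message)
            await asyncio.sleep(0.5)

        return call_task.result()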
@@ -8,11 +8,11 @@ embeddings function declared async.
 """
 
 import base64
-import pathlib
+from fastapi import Request
 import numpy as np
 from loguru import logger
 
-from common import config
+from common import model
 from endpoints.OAI.types.embedding import (
     EmbeddingObject,
     EmbeddingsRequest,
@@ -20,84 +20,6 @@ from endpoints.OAI.types.embedding import (
     UsageInfo,
 )
 
-# Conditionally import infinity embeddings engine
-# Required so the logger doesn't take over tabby's logging handlers
-try:
-    from infinity_emb import EngineArgs, AsyncEmbeddingEngine
-
-    has_infinity_emb = True
-except ImportError:
-    has_infinity_emb = False
-
-
-embeddings_model = None
-
-
-def load_embedding_model(model_path: pathlib.Path, device: str):
-    if not has_infinity_emb:
-        logger.error(
-            "Skipping embeddings because infinity-emb is not installed.\n"
-            "Please run the following command in your environment "
-            "to install extra packages:\n"
-            "pip install -U .[extras]"
-        )
-        raise ModuleNotFoundError
-
-    global embeddings_model
-    try:
-        engine_args = EngineArgs(
-            model_name_or_path=str(model_path.resolve()),
-            engine="torch",
-            device="cpu",
-            bettertransformer=False,
-            model_warmup=False,
-        )
-        embeddings_model = AsyncEmbeddingEngine.from_args(engine_args)
-        logger.info(f"Trying to load embeddings model: {model_path.name} on {device}")
-    except Exception as e:
-        embeddings_model = None
-        raise e
-
-
-async def embeddings(data: EmbeddingsRequest) -> dict:
-    embeddings_config = config.embeddings_config()
-
-    # Use CPU by default
-    device = embeddings_config.get("embeddings_device", "cpu")
-    if device == "auto":
-        device = None
-
-    model_path = pathlib.Path(embeddings_config.get("embeddings_model_dir"))
-    model_path: pathlib.Path = model_path / embeddings_config.get(
-        "embeddings_model_name"
-    )
-    if not model_path:
-        logger.info("Embeddings model path not found")
-
-    load_embedding_model(model_path, device)
-
-    async with embeddings_model:
-        embeddings, usage = await embeddings_model.embed(data.input)
-
-    # OAI expects a return of base64 if the input is base64
-    embedding_data = [
-        EmbeddingObject(
-            embedding=float_list_to_base64(emb)
-            if data.encoding_format == "base64"
-            else emb.tolist(),
-            index=n,
-        )
-        for n, emb in enumerate(embeddings)
-    ]
-
-    response = EmbeddingsResponse(
-        data=embedding_data,
-        model=model_path.name,
-        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
-    )
-
-    return response
-
 
 def float_list_to_base64(float_array: np.ndarray) -> str:
     """
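Everything removed above is the Infinity engine setup that used to run inside the request handler. For orientation, a hedged sketch of how the same setup can sit behind explicit load/unload calls in the common module (the class and method names are illustrative; the EngineArgs options and the embed call are copied from the removed code):

    # Illustrative container for the removed Infinity setup; not the actual
    # common-module implementation introduced by this commit.
    import pathlib

    from infinity_emb import EngineArgs, AsyncEmbeddingEngine


    class EmbeddingModelContainer:
        def __init__(self, model_dir: pathlib.Path, device: str = "cpu"):
            self.model_dir = model_dir
            self.device = device
            self.engine = None

        def load(self):
            engine_args = EngineArgs(
                model_name_or_path=str(self.model_dir.resolve()),
                engine="torch",
                device=self.device,
                bettertransformer=False,
                model_warmup=False,
            )
            self.engine = AsyncEmbeddingEngine.from_args(engine_args)

        async def generate(self, input):
            # Mirrors the removed code's per-request async-with usage; a
            # long-lived container could start the engine once instead.
            async with self.engine:
                embeddings, usage = await self.engine.embed(input)
            return {"embeddings": embeddings, "usage": usage}

        def unload(self):
            # Dropping the engine reference lets core load a different model
            self.engine = None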
@@ -111,3 +33,32 @@ def float_list_to_base64(float_array: np.ndarray) -> str:
     # Turn raw base64 encoded bytes into ASCII
     ascii_string = encoded_bytes.decode("ascii")
     return ascii_string
+
+
+async def get_embeddings(data: EmbeddingsRequest, request: Request) -> dict:
+    model_path = model.embeddings_container.model_dir
+
+    logger.info(f"Received embeddings request {request.state.id}")
+    embedding_data = await model.embeddings_container.generate(data.input)
+
+    # OAI expects a return of base64 if the input is base64
+    embedding_object = [
+        EmbeddingObject(
+            embedding=float_list_to_base64(emb)
+            if data.encoding_format == "base64"
+            else emb.tolist(),
+            index=n,
+        )
+        for n, emb in enumerate(embedding_data.get("embeddings"))
+    ]
+
+    usage = embedding_data.get("usage")
+    response = EmbeddingsResponse(
+        data=embedding_object,
+        model=model_path.name,
+        usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
+    )
+
+    logger.info(f"Finished embeddings request {request.state.id}")
+
+    return response
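When encoding_format is "base64", the embedding is returned as a base64 string rather than a JSON list of floats. A small client-side sketch for decoding it, assuming the server packs raw little-endian float32 bytes before base64-encoding (the usual OpenAI-style convention; check float_list_to_base64 for the exact dtype):

    # Hypothetical client-side decoder for base64 embeddings returned by
    # /v1/embeddings with encoding_format="base64".
    import base64

    import numpy as np


    def base64_to_float_list(encoded: str) -> np.ndarray:
        raw_bytes = base64.b64decode(encoded)
        # Assumes float32; adjust the dtype if the server encodes differently
        return np.frombuffer(raw_bytes, dtype=np.float32)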