Embeddings: Switch to Infinity
Infinity-emb is an async batching engine for embeddings. This is preferable to sentence-transformers since it handles scalable usecases without the need for external thread intervention. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
c9a5d2c363
commit
3f21d9ef96
4 changed files with 87 additions and 100 deletions
|
|
@ -95,3 +95,8 @@ def logging_config():
|
|||
def developer_config():
|
||||
"""Returns the developer specific config from the global config"""
|
||||
return unwrap(GLOBAL_CONFIG.get("developer"), {})
|
||||
|
||||
|
||||
def embeddings_config():
|
||||
"""Returns the embeddings config from the global config"""
|
||||
return unwrap(GLOBAL_CONFIG.get("embeddings"), {})
|
||||
|
|
|
|||
|
|
@ -72,6 +72,13 @@ developer:
|
|||
# Otherwise, the priority will be set to high
|
||||
#realtime_process_priority: False
|
||||
|
||||
embeddings:
|
||||
embeddings_model_dir: models
|
||||
|
||||
embeddings_model_name:
|
||||
|
||||
embeddings_device: cpu
|
||||
|
||||
# Options for model overrides and loading
|
||||
# Please read the comments to understand how arguments are handled between initial and API loads
|
||||
model:
|
||||
|
|
|
|||
|
|
@ -135,6 +135,6 @@ async def chat_completion_request(
|
|||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def handle_embeddings(data: EmbeddingsRequest) -> EmbeddingsResponse:
|
||||
response = await embeddings(data.input, data.encoding_format, data.model)
|
||||
response = await embeddings(data)
|
||||
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -7,135 +7,110 @@ typing/pydantic classes moved into this file,
|
|||
embeddings function declared async.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import base64
|
||||
import pathlib
|
||||
from loguru import logger
|
||||
import numpy as np
|
||||
from transformers import AutoModel
|
||||
|
||||
embeddings_params_initialized = False
|
||||
from common import config
|
||||
from common.utils import unwrap
|
||||
from endpoints.OAI.types.embedding import (
|
||||
EmbeddingObject,
|
||||
EmbeddingsRequest,
|
||||
EmbeddingsResponse,
|
||||
)
|
||||
|
||||
|
||||
def initialize_embedding_params():
|
||||
"""
|
||||
using 'lazy loading' to avoid circular import
|
||||
so this function will be executed only once
|
||||
"""
|
||||
global embeddings_params_initialized
|
||||
if not embeddings_params_initialized:
|
||||
global st_model, embeddings_model, embeddings_device
|
||||
|
||||
st_model = os.environ.get("OPENAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
|
||||
embeddings_model = None
|
||||
# OPENAI_EMBEDDING_DEVICE: auto (best or cpu),
|
||||
# cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep,
|
||||
# hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta,
|
||||
# hpu, mtia, privateuseone
|
||||
embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", "cpu")
|
||||
if embeddings_device.lower() == "auto":
|
||||
embeddings_device = None
|
||||
|
||||
embeddings_params_initialized = True
|
||||
embeddings_model = None
|
||||
|
||||
|
||||
def load_embedding_model(model: str):
|
||||
def load_embedding_model(model_path: pathlib.Path, device: str):
|
||||
try:
|
||||
from sentence_transformers import SentenceTransformer
|
||||
from infinity_emb import EngineArgs, AsyncEmbeddingEngine
|
||||
except ModuleNotFoundError:
|
||||
print(
|
||||
"The sentence_transformers module has not been found. "
|
||||
+ "Please install it manually with "
|
||||
+ "pip install -U sentence-transformers."
|
||||
logger.error(
|
||||
"Skipping embeddings because infinity-emb is not installed.\n"
|
||||
"Please run the following command in your environment "
|
||||
"to install extra packages:\n"
|
||||
"pip install -U .[extras]"
|
||||
)
|
||||
raise ModuleNotFoundError from None
|
||||
|
||||
initialize_embedding_params()
|
||||
global embeddings_device, embeddings_model
|
||||
global embeddings_model
|
||||
try:
|
||||
print(f"Try embedding model: {model} on {embeddings_device}")
|
||||
if "jina-embeddings" in model:
|
||||
# trust_remote_code is needed to use the encode method
|
||||
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True)
|
||||
embeddings_model = embeddings_model.to(embeddings_device)
|
||||
else:
|
||||
embeddings_model = SentenceTransformer(
|
||||
model,
|
||||
device=embeddings_device,
|
||||
)
|
||||
|
||||
print(f"Loaded embedding model: {model}")
|
||||
engine_args = EngineArgs(
|
||||
model_name_or_path=str(model_path.resolve()),
|
||||
engine="torch",
|
||||
device="cpu",
|
||||
bettertransformer=False,
|
||||
model_warmup=False,
|
||||
)
|
||||
embeddings_model = AsyncEmbeddingEngine.from_args(engine_args)
|
||||
logger.info(f"Trying to load embeddings model: {model_path.name} on {device}")
|
||||
except Exception as e:
|
||||
embeddings_model = None
|
||||
raise Exception(
|
||||
f"Error: Failed to load embedding model: {model}", internal_message=repr(e)
|
||||
) from None
|
||||
raise e
|
||||
|
||||
|
||||
def get_embeddings_model():
|
||||
initialize_embedding_params()
|
||||
global embeddings_model, st_model
|
||||
if st_model and not embeddings_model:
|
||||
load_embedding_model(st_model) # lazy load the model
|
||||
async def embeddings(data: EmbeddingsRequest) -> dict:
|
||||
embeddings_config = config.embeddings_config()
|
||||
|
||||
return embeddings_model
|
||||
# Use CPU by default
|
||||
device = embeddings_config.get("embeddings_device", "cpu")
|
||||
if device == "auto":
|
||||
device = None
|
||||
|
||||
|
||||
def get_embeddings_model_name() -> str:
|
||||
initialize_embedding_params()
|
||||
global st_model
|
||||
return st_model
|
||||
|
||||
|
||||
def get_embeddings(input: list) -> np.ndarray:
|
||||
model = get_embeddings_model()
|
||||
embedding = model.encode(
|
||||
input,
|
||||
convert_to_numpy=True,
|
||||
normalize_embeddings=True,
|
||||
convert_to_tensor=False,
|
||||
show_progress_bar=False,
|
||||
model_path = pathlib.Path(embeddings_config.get("embeddings_model_dir"))
|
||||
model_path: pathlib.Path = model_path / embeddings_config.get(
|
||||
"embeddings_model_name"
|
||||
)
|
||||
return embedding
|
||||
if not model_path:
|
||||
logger.info("Embeddings model path not found")
|
||||
|
||||
load_embedding_model(model_path, device)
|
||||
|
||||
async def embeddings(input: list, encoding_format: str, model: str = None) -> dict:
|
||||
if model is None:
|
||||
model = st_model
|
||||
else:
|
||||
load_embedding_model(model)
|
||||
async with embeddings_model:
|
||||
embeddings, usage = await embeddings_model.embed(data.input)
|
||||
|
||||
embeddings = get_embeddings(input)
|
||||
if encoding_format == "base64":
|
||||
data = [
|
||||
{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n}
|
||||
for n, emb in enumerate(embeddings)
|
||||
]
|
||||
else:
|
||||
data = [
|
||||
{"object": "embedding", "embedding": emb.tolist(), "index": n}
|
||||
for n, emb in enumerate(embeddings)
|
||||
]
|
||||
# OAI expects a return of base64 if the input is base64
|
||||
if data.encoding_format == "base64":
|
||||
embedding_data = [
|
||||
{
|
||||
"object": "embedding",
|
||||
"embedding": float_list_to_base64(emb),
|
||||
"index": n,
|
||||
}
|
||||
for n, emb in enumerate(embeddings)
|
||||
]
|
||||
else:
|
||||
embedding_data = [
|
||||
{"object": "embedding", "embedding": emb.tolist(), "index": n}
|
||||
for n, emb in enumerate(embeddings)
|
||||
]
|
||||
|
||||
response = {
|
||||
"object": "list",
|
||||
"data": data,
|
||||
"model": st_model if model is None else model,
|
||||
"usage": {
|
||||
"prompt_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
},
|
||||
}
|
||||
return response
|
||||
response = {
|
||||
"object": "list",
|
||||
"data": embedding_data,
|
||||
"model": model_path.name,
|
||||
"usage": {
|
||||
"prompt_tokens": usage,
|
||||
"total_tokens": usage,
|
||||
},
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
def float_list_to_base64(float_array: np.ndarray) -> str:
|
||||
# Convert the list to a float32 array that the OpenAPI client expects
|
||||
# float_array = np.array(float_list, dtype="float32")
|
||||
"""
|
||||
Converts the provided list to a float32 array for OpenAI
|
||||
Ex. float_array = np.array(float_list, dtype="float32")
|
||||
"""
|
||||
|
||||
# Get raw bytes
|
||||
bytes_array = float_array.tobytes()
|
||||
|
||||
# Encode bytes into base64
|
||||
encoded_bytes = base64.b64encode(bytes_array)
|
||||
# Encode raw bytes into base64
|
||||
encoded_bytes = base64.b64encode(float_array.tobytes())
|
||||
|
||||
# Turn raw base64 encoded bytes into ASCII
|
||||
ascii_string = encoded_bytes.decode("ascii")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue