feat: add embeddings support via sentence-transformers

This commit is contained in:
AlpinDale 2024-07-26 02:45:07 +00:00
parent a1c3f6cc1c
commit f20cd330ef
5 changed files with 210 additions and 1 deletions

145
endpoints/OAI/embeddings.py Normal file
View file

@ -0,0 +1,145 @@
"""
This file is derived from
[text-generation-webui openai extension embeddings](https://github.com/oobabooga/text-generation-webui/blob/1a7c027386f43b84f3ca3b0ff04ca48d861c2d7a/extensions/openai/embeddings.py)
and modified.
The changes introduced are: Suppression of progress bar,
typing/pydantic classes moved into this file,
embeddings function declared async.
"""
import os
import base64
import numpy as np
from transformers import AutoModel
embeddings_params_initialized = False
def initialize_embedding_params():
'''
using 'lazy loading' to avoid circular import
so this function will be executed only once
'''
global embeddings_params_initialized
if not embeddings_params_initialized:
global st_model, embeddings_model, embeddings_device
st_model = os.environ.get("OPENAI_EMBEDDING_MODEL",
'all-mpnet-base-v2')
embeddings_model = None
# OPENAI_EMBEDDING_DEVICE: auto (best or cpu),
# cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep,
# hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta,
# hpu, mtia, privateuseone
embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", 'cpu')
if embeddings_device.lower() == 'auto':
embeddings_device = None
embeddings_params_initialized = True
def load_embedding_model(model: str):
try:
from sentence_transformers import SentenceTransformer
except ModuleNotFoundError:
print("The sentence_transformers module has not been found. " +
"Please install it manually with " +
"pip install -U sentence-transformers.")
raise ModuleNotFoundError from None
initialize_embedding_params()
global embeddings_device, embeddings_model
try:
print(f"Try embedding model: {model} on {embeddings_device}")
if 'jina-embeddings' in model:
# trust_remote_code is needed to use the encode method
embeddings_model = AutoModel.from_pretrained(
model, trust_remote_code=True)
embeddings_model = embeddings_model.to(embeddings_device)
else:
embeddings_model = SentenceTransformer(
model,
device=embeddings_device,
)
print(f"Loaded embedding model: {model}")
except Exception as e:
embeddings_model = None
raise Exception(f"Error: Failed to load embedding model: {model}",
internal_message=repr(e)) from None
def get_embeddings_model():
initialize_embedding_params()
global embeddings_model, st_model
if st_model and not embeddings_model:
load_embedding_model(st_model) # lazy load the model
return embeddings_model
def get_embeddings_model_name() -> str:
initialize_embedding_params()
global st_model
return st_model
def get_embeddings(input: list) -> np.ndarray:
model = get_embeddings_model()
embedding = model.encode(input,
convert_to_numpy=True,
normalize_embeddings=True,
convert_to_tensor=False,
show_progress_bar=False)
return embedding
async def embeddings(input: list,
encoding_format: str,
model: str = None) -> dict:
if model is None:
model = st_model
else:
load_embedding_model(model)
embeddings = get_embeddings(input)
if encoding_format == "base64":
data = [{
"object": "embedding",
"embedding": float_list_to_base64(emb),
"index": n
} for n, emb in enumerate(embeddings)]
else:
data = [{
"object": "embedding",
"embedding": emb.tolist(),
"index": n
} for n, emb in enumerate(embeddings)]
response = {
"object": "list",
"data": data,
"model": st_model if model is None else model,
"usage": {
"prompt_tokens": 0,
"total_tokens": 0,
}
}
return response
def float_list_to_base64(float_array: np.ndarray) -> str:
# Convert the list to a float32 array that the OpenAPI client expects
# float_array = np.array(float_list, dtype="float32")
# Get raw bytes
bytes_array = float_array.tobytes()
# Encode bytes into base64
encoded_bytes = base64.b64encode(bytes_array)
# Turn raw base64 encoded bytes into ASCII
ascii_string = encoded_bytes.decode('ascii')
return ascii_string

View file

@ -1,5 +1,6 @@
import asyncio
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import JSONResponse
from sse_starlette import EventSourceResponse
from sys import maxsize
@ -8,11 +9,16 @@ from common.auth import check_api_key
from common.model import check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.utils import unwrap
import endpoints.OAI.embeddings as OAIembeddings
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.chat_completion import (
ChatCompletionRequest,
ChatCompletionResponse,
)
from endpoints.OAI.types.embedding import (
EmbeddingsRequest,
EmbeddingsResponse
)
from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template,
generate_chat_completion,
@ -125,3 +131,20 @@ async def chat_completion_request(
disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
)
return response
# Embeddings endpoint
@router.post(
"/v1/embeddings",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
response_model=EmbeddingsResponse
)
async def handle_embeddings(request: EmbeddingsRequest):
input = request.input
if not input:
raise JSONResponse(status_code=400,
content={"error": "Missing required argument input"})
model = request.model if request.model else None
response = await OAIembeddings.embeddings(input, request.encoding_format,
model)
return JSONResponse(response)

View file

@ -0,0 +1,39 @@
from typing import List, Optional
from pydantic import BaseModel, Field
class UsageInfo(BaseModel):
prompt_tokens: int = 0
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class EmbeddingsRequest(BaseModel):
input: List[str] = Field(
..., description="List of input texts to generate embeddings for.")
encoding_format: str = Field(
"float",
description="Encoding format for the embeddings. "
"Can be 'float' or 'base64'.")
model: Optional[str] = Field(
None,
description="Name of the embedding model to use. "
"If not provided, the default model will be used.")
class EmbeddingObject(BaseModel):
object: str = Field("embedding", description="Type of the object.")
embedding: List[float] = Field(
..., description="Embedding values as a list of floats.")
index: int = Field(
...,
description="Index of the input text corresponding to "
"the embedding.")
class EmbeddingsResponse(BaseModel):
object: str = Field("list", description="Type of the response object.")
data: List[EmbeddingObject] = Field(
..., description="List of embedding objects.")
model: str = Field(..., description="Name of the embedding model used.")
usage: UsageInfo = Field(..., description="Information about token usage.")

View file

@ -47,7 +47,8 @@ dependencies = [
[project.optional-dependencies]
extras = [
# Heavy dependencies that aren't for everyday use
"outlines"
"outlines",
"sentence-transformers"
]
dev = [
"ruff == 0.3.2"

1
tabbyAPI Submodule

@ -0,0 +1 @@
Subproject commit 1650e6e6406edf797576c077aaceafcf28895c26