ruff: formatting

This commit is contained in:
AlpinDale 2024-07-26 02:53:14 +00:00
parent 765d3593b3
commit 5adfab1cbd
3 changed files with 52 additions and 55 deletions

View file

@ -16,24 +16,22 @@ embeddings_params_initialized = False
def initialize_embedding_params():
'''
"""
using 'lazy loading' to avoid circular import
so this function will be executed only once
'''
"""
global embeddings_params_initialized
if not embeddings_params_initialized:
global st_model, embeddings_model, embeddings_device
st_model = os.environ.get("OPENAI_EMBEDDING_MODEL",
'all-mpnet-base-v2')
st_model = os.environ.get("OPENAI_EMBEDDING_MODEL", "all-mpnet-base-v2")
embeddings_model = None
# OPENAI_EMBEDDING_DEVICE: auto (best or cpu),
# cpu, cuda, ipu, xpu, mkldnn, opengl, opencl, ideep,
# hip, ve, fpga, ort, xla, lazy, vulkan, mps, meta,
# hpu, mtia, privateuseone
embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", 'cpu')
if embeddings_device.lower() == 'auto':
embeddings_device = os.environ.get("OPENAI_EMBEDDING_DEVICE", "cpu")
if embeddings_device.lower() == "auto":
embeddings_device = None
embeddings_params_initialized = True
@ -43,19 +41,20 @@ def load_embedding_model(model: str):
try:
from sentence_transformers import SentenceTransformer
except ModuleNotFoundError:
print("The sentence_transformers module has not been found. " +
"Please install it manually with " +
"pip install -U sentence-transformers.")
print(
"The sentence_transformers module has not been found. "
+ "Please install it manually with "
+ "pip install -U sentence-transformers."
)
raise ModuleNotFoundError from None
initialize_embedding_params()
global embeddings_device, embeddings_model
try:
print(f"Try embedding model: {model} on {embeddings_device}")
if 'jina-embeddings' in model:
if "jina-embeddings" in model:
# trust_remote_code is needed to use the encode method
embeddings_model = AutoModel.from_pretrained(
model, trust_remote_code=True)
embeddings_model = AutoModel.from_pretrained(model, trust_remote_code=True)
embeddings_model = embeddings_model.to(embeddings_device)
else:
embeddings_model = SentenceTransformer(
@ -66,8 +65,9 @@ def load_embedding_model(model: str):
print(f"Loaded embedding model: {model}")
except Exception as e:
embeddings_model = None
raise Exception(f"Error: Failed to load embedding model: {model}",
internal_message=repr(e)) from None
raise Exception(
f"Error: Failed to load embedding model: {model}", internal_message=repr(e)
) from None
def get_embeddings_model():
@ -87,17 +87,17 @@ def get_embeddings_model_name() -> str:
def get_embeddings(input: list) -> np.ndarray:
model = get_embeddings_model()
embedding = model.encode(input,
convert_to_numpy=True,
normalize_embeddings=True,
convert_to_tensor=False,
show_progress_bar=False)
embedding = model.encode(
input,
convert_to_numpy=True,
normalize_embeddings=True,
convert_to_tensor=False,
show_progress_bar=False,
)
return embedding
async def embeddings(input: list,
encoding_format: str,
model: str = None) -> dict:
async def embeddings(input: list, encoding_format: str, model: str = None) -> dict:
if model is None:
model = st_model
else:
@ -105,17 +105,15 @@ async def embeddings(input: list,
embeddings = get_embeddings(input)
if encoding_format == "base64":
data = [{
"object": "embedding",
"embedding": float_list_to_base64(emb),
"index": n
} for n, emb in enumerate(embeddings)]
data = [
{"object": "embedding", "embedding": float_list_to_base64(emb), "index": n}
for n, emb in enumerate(embeddings)
]
else:
data = [{
"object": "embedding",
"embedding": emb.tolist(),
"index": n
} for n, emb in enumerate(embeddings)]
data = [
{"object": "embedding", "embedding": emb.tolist(), "index": n}
for n, emb in enumerate(embeddings)
]
response = {
"object": "list",
@ -124,7 +122,7 @@ async def embeddings(input: list,
"usage": {
"prompt_tokens": 0,
"total_tokens": 0,
}
},
}
return response
@ -140,6 +138,5 @@ def float_list_to_base64(float_array: np.ndarray) -> str:
encoded_bytes = base64.b64encode(bytes_array)
# Turn raw base64 encoded bytes into ASCII
ascii_string = encoded_bytes.decode('ascii')
ascii_string = encoded_bytes.decode("ascii")
return ascii_string

View file

@ -15,10 +15,7 @@ from endpoints.OAI.types.chat_completion import (
ChatCompletionRequest,
ChatCompletionResponse,
)
from endpoints.OAI.types.embedding import (
EmbeddingsRequest,
EmbeddingsResponse
)
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template,
generate_chat_completion,
@ -132,19 +129,19 @@ async def chat_completion_request(
)
return response
# Embeddings endpoint
@router.post(
"/v1/embeddings",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
response_model=EmbeddingsResponse
response_model=EmbeddingsResponse,
)
async def handle_embeddings(request: EmbeddingsRequest):
input = request.input
if not input:
raise JSONResponse(status_code=400,
content={"error": "Missing required argument input"})
raise JSONResponse(
status_code=400, content={"error": "Missing required argument input"}
)
model = request.model if request.model else None
response = await OAIembeddings.embeddings(input, request.encoding_format,
model)
response = await OAIembeddings.embeddings(input, request.encoding_format, model)
return JSONResponse(response)

View file

@ -8,32 +8,35 @@ class UsageInfo(BaseModel):
total_tokens: int = 0
completion_tokens: Optional[int] = 0
class EmbeddingsRequest(BaseModel):
input: List[str] = Field(
..., description="List of input texts to generate embeddings for.")
..., description="List of input texts to generate embeddings for."
)
encoding_format: str = Field(
"float",
description="Encoding format for the embeddings. "
"Can be 'float' or 'base64'.")
"Can be 'float' or 'base64'.",
)
model: Optional[str] = Field(
None,
description="Name of the embedding model to use. "
"If not provided, the default model will be used.")
"If not provided, the default model will be used.",
)
class EmbeddingObject(BaseModel):
object: str = Field("embedding", description="Type of the object.")
embedding: List[float] = Field(
..., description="Embedding values as a list of floats.")
..., description="Embedding values as a list of floats."
)
index: int = Field(
...,
description="Index of the input text corresponding to "
"the embedding.")
..., description="Index of the input text corresponding to " "the embedding."
)
class EmbeddingsResponse(BaseModel):
object: str = Field("list", description="Type of the response object.")
data: List[EmbeddingObject] = Field(
..., description="List of embedding objects.")
data: List[EmbeddingObject] = Field(..., description="List of embedding objects.")
model: str = Field(..., description="Name of the embedding model used.")
usage: UsageInfo = Field(..., description="Information about token usage.")