Merge pull request #158 from AlpinDale/embeddings

feat: add embeddings support via Infinity-emb
Brian Dashore 2024-07-31 20:33:12 -04:00 committed by GitHub
commit 1bf062559d
14 changed files with 443 additions and 11 deletions

View file

@ -0,0 +1,66 @@
import gc
import pathlib
import torch
from loguru import logger
from typing import List, Optional

from common.utils import unwrap

# Conditionally import infinity to sidestep its logger
# TODO: Make this prettier
try:
    from infinity_emb import EngineArgs, AsyncEmbeddingEngine

    has_infinity_emb = True
except ImportError:
    has_infinity_emb = False


class InfinityContainer:
    model_dir: pathlib.Path
    model_is_loading: bool = False
    model_loaded: bool = False

    # Conditionally set the type hint based on importability
    # TODO: Clean this up
    if has_infinity_emb:
        engine: Optional[AsyncEmbeddingEngine] = None
    else:
        engine = None

    def __init__(self, model_directory: pathlib.Path):
        self.model_dir = model_directory

    async def load(self, **kwargs):
        self.model_is_loading = True

        # Use cpu by default
        device = unwrap(kwargs.get("embeddings_device"), "cpu")

        engine_args = EngineArgs(
            model_name_or_path=str(self.model_dir),
            engine="torch",
            device=device,
            bettertransformer=False,
            model_warmup=False,
        )

        self.engine = AsyncEmbeddingEngine.from_args(engine_args)
        await self.engine.astart()

        self.model_loaded = True
        logger.info("Embedding model successfully loaded.")

    async def unload(self):
        await self.engine.astop()
        self.engine = None

        gc.collect()
        torch.cuda.empty_cache()

        logger.info("Embedding model unloaded.")

    async def generate(self, sentence_input: List[str]):
        result_embeddings, usage = await self.engine.embed(sentence_input)

        return {"embeddings": result_embeddings, "usage": usage}

View file

@ -23,6 +23,7 @@ def init_argparser():
    )

    add_network_args(parser)
    add_model_args(parser)
    add_embeddings_args(parser)
    add_logging_args(parser)
    add_developer_args(parser)
    add_sampling_args(parser)
@ -209,3 +210,22 @@ def add_sampling_args(parser: argparse.ArgumentParser):
    sampling_group.add_argument(
        "--override-preset", type=str, help="Select a sampler override preset"
    )


def add_embeddings_args(parser: argparse.ArgumentParser):
    """Adds arguments specific to embeddings"""

    embeddings_group = parser.add_argument_group("embeddings")
    embeddings_group.add_argument(
        "--embedding-model-dir",
        type=str,
        help="Overrides the directory to look for embedding models",
    )
    embeddings_group.add_argument(
        "--embedding-model-name", type=str, help="An initial embedding model to load"
    )
    embeddings_group.add_argument(
        "--embeddings-device",
        type=str,
        help="Device to use for embeddings. Options: (cpu, auto, cuda)",
    )

View file

@ -59,6 +59,11 @@ def from_args(args: dict):
        cur_developer_config = developer_config()
        GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}

    embeddings_override = args.get("embeddings")
    if embeddings_override:
        cur_embeddings_config = embeddings_config()
        GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}


def sampling_config():
    """Returns the sampling parameter config from the global config"""
@ -95,3 +100,8 @@ def logging_config():
def developer_config():
    """Returns the developer specific config from the global config"""
    return unwrap(GLOBAL_CONFIG.get("developer"), {})


def embeddings_config():
    """Returns the embeddings config from the global config"""
    return unwrap(GLOBAL_CONFIG.get("embeddings"), {})
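To illustrate the merge semantics used in from_args above (example values, not from the diff): CLI overrides win because they are unpacked after the YAML-derived config.

# Hypothetical values showing the {**config, **override} precedence
cur_embeddings_config = {"embedding_model_dir": "models", "embeddings_device": "cpu"}
embeddings_override = {"embeddings_device": "cuda"}

merged = {**cur_embeddings_config, **embeddings_override}
print(merged)  # {'embedding_model_dir': 'models', 'embeddings_device': 'cuda'}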

View file

@ -20,6 +20,15 @@ if not do_export_openapi:
# Global model container
container: Optional[ExllamaV2Container] = None
embeddings_container = None

# Type hint the infinity emb container if it exists
from backends.infinity.model import has_infinity_emb

if has_infinity_emb:
    from backends.infinity.model import InfinityContainer

    embeddings_container: Optional[InfinityContainer] = None
def load_progress(module, modules):
@ -48,8 +57,6 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
            f'Model "{loaded_model_name}" is already loaded! Aborting.'
        )

    # Unload the existing model
    if container and container.model:
        logger.info("Unloading existing model.")
        await unload_model()
@ -100,6 +107,41 @@ async def unload_loras():
    await container.unload(loras_only=True)


async def load_embedding_model(model_path: pathlib.Path, **kwargs):
    global embeddings_container

    # Break out if infinity isn't installed
    if not has_infinity_emb:
        raise ImportError(
            "Skipping embeddings because infinity-emb is not installed.\n"
            "Please run the following command in your environment "
            "to install extra packages:\n"
            "pip install -U .[extras]"
        )

    # Check if the model is already loaded
    if embeddings_container and embeddings_container.engine:
        loaded_model_name = embeddings_container.model_dir.name

        if loaded_model_name == model_path.name and embeddings_container.model_loaded:
            raise ValueError(
                f'Embeddings model "{loaded_model_name}" is already loaded! Aborting.'
            )

        logger.info("Unloading existing embeddings model.")
        await unload_embedding_model()

    embeddings_container = InfinityContainer(model_path)
    await embeddings_container.load(**kwargs)


async def unload_embedding_model():
    global embeddings_container

    await embeddings_container.unload()
    embeddings_container = None
def get_config_default(key, fallback=None, is_draft=False):
"""Fetches a default value from model config if allowed by the user."""
@ -126,3 +168,21 @@ async def check_model_container():
        ).error.message

        raise HTTPException(400, error_message)


async def check_embeddings_container():
    """
    FastAPI depends that checks if an embeddings model is loaded.

    This is the same as the model container check, but with embeddings instead.
    """

    if embeddings_container is None or not (
        embeddings_container.model_is_loading or embeddings_container.model_loaded
    ):
        error_message = handle_request_error(
            "No embedding models are currently loaded.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

View file

@ -1,16 +1,32 @@
import asyncio
import signal
import sys
from loguru import logger
from types import FrameType

from common import model


def signal_handler(*_):
    """Signal handler for main function. Run before uvicorn starts."""

    logger.warning("Shutdown signal called. Exiting gracefully.")

    # Run async unloads for model
    asyncio.ensure_future(signal_handler_async())

    # Exit the program
    sys.exit(0)


async def signal_handler_async(*_):
    if model.container:
        await model.container.unload()

    if model.embeddings_container:
        await model.embeddings_container.unload()


def uvicorn_signal_handler(signal_event: signal.Signals):
    """Overrides uvicorn's signal handler."""

View file

@ -201,3 +201,19 @@ model:
  #loras:
  #- name: lora1
  #  scaling: 1.0

# Options for embedding models and loading.
# NOTE: Embeddings requires the "extras" feature to be installed
# Install it via "pip install .[extras]"
embeddings:
  # Overrides directory to look for embedding models (default: models)
  embedding_model_dir: models

  # An initial embedding model to load on the infinity backend (default: None)
  embedding_model_name:

  # Device to load embedding models on (default: cpu)
  # Possible values: cpu, auto, cuda
  # NOTE: It's recommended to load embedding models on the CPU.
  # If you'd like to load on an AMD GPU, set this value to "cuda" as well.
  embeddings_device: cpu

View file

@ -5,7 +5,7 @@ from sys import maxsize
from common import config, model
from common.auth import check_api_key
from common.model import check_model_container
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.utils import unwrap
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
@ -13,6 +13,7 @@ from endpoints.OAI.types.chat_completion import (
    ChatCompletionRequest,
    ChatCompletionResponse,
)
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.chat_completion import (
    format_prompt_with_template,
    generate_chat_completion,
@ -22,6 +23,7 @@ from endpoints.OAI.utils.completion import (
    generate_completion,
    stream_generate_completion,
)
from endpoints.OAI.utils.embeddings import get_embeddings


api_name = "OAI"
@ -134,3 +136,19 @@ async def chat_completion_request(
        disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
    )

    return response


# Embeddings endpoint
@router.post(
    "/v1/embeddings",
    dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
)
async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
    embeddings_task = asyncio.create_task(get_embeddings(data, request))

    response = await run_with_request_disconnect(
        request,
        embeddings_task,
        f"Embeddings request {request.state.id} cancelled by user.",
    )

    return response
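A sketch of an OpenAI-style client call against this endpoint (not part of the diff); the host, port, and API key header are assumptions about a local tabbyAPI setup:

import requests

resp = requests.post(
    "http://127.0.0.1:5000/v1/embeddings",
    headers={"x-api-key": "<your-api-key>"},
    json={
        "input": ["The quick brown fox", "jumps over the lazy dog"],
        "encoding_format": "float",
    },
)
resp.raise_for_status()

# Each entry carries the embedding and the index of the input it belongs to
for item in resp.json()["data"]:
    print(item["index"], len(item["embedding"]))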

View file

@ -0,0 +1,42 @@
from typing import List, Optional, Union
from pydantic import BaseModel, Field


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class EmbeddingsRequest(BaseModel):
    input: List[str] = Field(
        ..., description="List of input texts to generate embeddings for."
    )
    encoding_format: str = Field(
        "float",
        description="Encoding format for the embeddings. "
        "Can be 'float' or 'base64'.",
    )
    model: Optional[str] = Field(
        None,
        description="Name of the embedding model to use. "
        "If not provided, the default model will be used.",
    )


class EmbeddingObject(BaseModel):
    object: str = Field("embedding", description="Type of the object.")
    embedding: Union[List[float], str] = Field(
        ...,
        description="Embedding values as a list of floats, "
        "or a base64-encoded string when 'base64' encoding is requested.",
    )
    index: int = Field(
        ..., description="Index of the input text corresponding to the embedding."
    )


class EmbeddingsResponse(BaseModel):
    object: str = Field("list", description="Type of the response object.")
    data: List[EmbeddingObject] = Field(..., description="List of embedding objects.")
    model: str = Field(..., description="Name of the embedding model used.")
    usage: UsageInfo = Field(..., description="Information about token usage.")
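A quick sketch (not from the diff) of how these models validate an OpenAI-style payload; the model name is a placeholder:

from endpoints.OAI.types.embedding import EmbeddingsRequest

payload = {"input": ["hello", "world"], "model": "bge-small-en-v1.5"}
request = EmbeddingsRequest(**payload)

# encoding_format falls back to its "float" default
print(request.encoding_format, request.input)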

View file

@ -0,0 +1,64 @@
"""
This file is derived from
[text-generation-webui openai extension embeddings](https://github.com/oobabooga/text-generation-webui/blob/1a7c027386f43b84f3ca3b0ff04ca48d861c2d7a/extensions/openai/embeddings.py)
and modified.
The changes introduced are: Suppression of progress bar,
typing/pydantic classes moved into this file,
embeddings function declared async.
"""
import base64
from fastapi import Request
import numpy as np
from loguru import logger
from common import model
from endpoints.OAI.types.embedding import (
EmbeddingObject,
EmbeddingsRequest,
EmbeddingsResponse,
UsageInfo,
)
def float_list_to_base64(float_array: np.ndarray) -> str:
"""
Converts the provided list to a float32 array for OpenAI
Ex. float_array = np.array(float_list, dtype="float32")
"""
# Encode raw bytes into base64
encoded_bytes = base64.b64encode(float_array.tobytes())
# Turn raw base64 encoded bytes into ASCII
ascii_string = encoded_bytes.decode("ascii")
return ascii_string
async def get_embeddings(data: EmbeddingsRequest, request: Request) -> dict:
model_path = model.embeddings_container.model_dir
logger.info(f"Recieved embeddings request {request.state.id}")
embedding_data = await model.embeddings_container.generate(data.input)
# OAI expects a return of base64 if the input is base64
embedding_object = [
EmbeddingObject(
embedding=float_list_to_base64(emb)
if data.encoding_format == "base64"
else emb.tolist(),
index=n,
)
for n, emb in enumerate(embedding_data.get("embeddings"))
]
usage = embedding_data.get("usage")
response = EmbeddingsResponse(
data=embedding_object,
model=model_path.name,
usage=UsageInfo(prompt_tokens=usage, total_tokens=usage),
)
logger.info(f"Finished embeddings request {request.state.id}")
return response
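A standalone sketch (not part of the diff) of the base64 round-trip behind encoding_format="base64": the server packs float32 bytes, and a client can unpack them with numpy:

import base64

import numpy as np

embedding = np.array([0.1, -0.2, 0.3], dtype="float32")

# Server side: float32 bytes -> base64 ASCII string (same idea as float_list_to_base64)
encoded = base64.b64encode(embedding.tobytes()).decode("ascii")

# Client side: base64 string -> float32 values
decoded = np.frombuffer(base64.b64decode(encoded), dtype="float32")
print(encoded, decoded.tolist())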

View file

@ -7,7 +7,7 @@ from sse_starlette import EventSourceResponse
from common import config, model, sampling
from common.auth import check_admin_key, check_api_key, get_key_permission
from common.downloader import hf_repo_download
from common.model import check_model_container
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
from common.utils import unwrap
@ -15,6 +15,7 @@ from endpoints.core.types.auth import AuthPermissionResponse
from endpoints.core.types.download import DownloadRequest, DownloadResponse
from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
from endpoints.core.types.model import (
    EmbeddingModelLoadRequest,
    ModelCard,
    ModelList,
    ModelLoadRequest,
@ -253,6 +254,93 @@ async def unload_loras():
    await model.unload_loras()


@router.get("/v1/model/embedding/list", dependencies=[Depends(check_api_key)])
async def list_embedding_models(request: Request) -> ModelList:
    """
    Lists all embedding models in the model directory.

    Requires an admin key to see all embedding models.
    """

    if get_key_permission(request) == "admin":
        embedding_model_dir = unwrap(
            config.embeddings_config().get("embedding_model_dir"), "models"
        )
        embedding_model_path = pathlib.Path(embedding_model_dir)

        models = get_model_list(embedding_model_path.resolve())
    else:
        models = await get_current_model_list(model_type="embedding")

    return models


@router.get(
    "/v1/model/embedding",
    dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
)
async def get_embedding_model() -> ModelCard:
    """Returns the currently loaded embedding model."""

    models = await get_current_model_list(model_type="embedding")

    return models.data[0]


@router.post("/v1/model/embedding/load", dependencies=[Depends(check_admin_key)])
async def load_embedding_model(
    request: Request, data: EmbeddingModelLoadRequest
) -> ModelLoadResponse:
    """Loads an embedding model into the Infinity container."""

    # Verify request parameters
    if not data.name:
        error_message = handle_request_error(
            "A model name was not provided for load.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    embedding_model_dir = pathlib.Path(
        unwrap(config.embeddings_config().get("embedding_model_dir"), "models")
    )
    embedding_model_path = embedding_model_dir / data.name

    if not embedding_model_path.exists():
        error_message = handle_request_error(
            "Could not find the embedding model path for load. "
            + "Check model name or config.yml?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    try:
        load_task = asyncio.create_task(
            model.load_embedding_model(embedding_model_path, **data.model_dump())
        )
        await run_with_request_disconnect(
            request, load_task, "Embedding model load request cancelled by user."
        )
    except Exception as exc:
        error_message = handle_request_error(str(exc)).error.message
        raise HTTPException(400, error_message) from exc

    response = ModelLoadResponse(
        model_type="embedding_model", module=1, modules=1, status="finished"
    )

    return response


@router.post(
    "/v1/model/embedding/unload",
    dependencies=[Depends(check_admin_key), Depends(check_embeddings_container)],
)
async def unload_embedding_model():
    """Unloads the current embedding model."""

    await model.unload_embedding_model()
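Illustrative admin-side calls for the new endpoints (not part of this diff); the host, port, model name, and admin key header are assumptions about a local setup:

import requests

BASE = "http://127.0.0.1:5000"
HEADERS = {"x-admin-key": "<your-admin-key>"}

# Load an embedding model from the configured embedding_model_dir
requests.post(
    f"{BASE}/v1/model/embedding/load",
    headers=HEADERS,
    json={"name": "bge-small-en-v1.5", "embeddings_device": "cpu"},
).raise_for_status()

# ...and unload it again
requests.post(f"{BASE}/v1/model/embedding/unload", headers=HEADERS).raise_for_status()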
# Encode tokens endpoint
@router.post(
"/v1/token/encode",

View file

@ -137,6 +137,11 @@ class ModelLoadRequest(BaseModel):
    skip_queue: Optional[bool] = False


class EmbeddingModelLoadRequest(BaseModel):
    name: str
    embeddings_device: Optional[str] = None


class ModelLoadResponse(BaseModel):
    """Represents a model load response."""

View file

@ -32,15 +32,26 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N
    return model_card_list


async def get_current_model_list(model_type: str = "model"):
    """
    Gets the current model in list format and with path only.

    Unified for fetching both models and embedding models.
    """

    current_models = []
    model_path = None

    # Make sure the model container exists
    if model_type == "model" or model_type == "draft":
        if model.container:
            model_path = model.container.get_model_path(model_type == "draft")
    elif model_type == "embedding":
        if model.embeddings_container:
            model_path = model.embeddings_container.model_dir

    if model_path:
        current_models.append(ModelCard(id=model_path.name))

    return ModelList(data=current_models)

main.py
View file

@ -87,6 +87,21 @@ async def entrypoint_async():
        lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
        await model.container.load_loras(lora_dir.resolve(), **lora_config)

    # If an initial embedding model name is specified, create a separate container
    # and load the model
    embedding_config = config.embeddings_config()
    embedding_model_name = embedding_config.get("embedding_model_name")

    if embedding_model_name:
        embedding_model_path = pathlib.Path(
            unwrap(embedding_config.get("embedding_model_dir"), "models")
        )
        embedding_model_path = embedding_model_path / embedding_model_name

        try:
            await model.load_embedding_model(embedding_model_path, **embedding_config)
        except ImportError as ex:
            logger.error(ex.msg)

    await start_api(host, port)

View file

@ -47,7 +47,8 @@ dependencies = [
[project.optional-dependencies]
extras = [
    # Heavy dependencies that aren't for everyday use
    "outlines",
    "sentence-transformers"
]
dev = [
    "ruff == 0.3.2"