Embeddings: Migrate and organize Infinity

Use Infinity as a separate backend and handle the model within the
common module. This separates out the embeddings model from the endpoint
which allows for model loading/unloading in core.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-07-30 11:00:23 -04:00
parent ac1afcc588
commit fbf1455db1
6 changed files with 165 additions and 83 deletions

View file

@ -20,6 +20,15 @@ if not do_export_openapi:
# Global model container
container: Optional[ExllamaV2Container] = None
embeddings_container = None
# Type hint the infinity emb container if it exists
from backends.infinity.model import has_infinity_emb
if has_infinity_emb:
from backends.infinity.model import InfinityContainer
embeddings_container: Optional[InfinityContainer] = None
def load_progress(module, modules):
@ -100,6 +109,30 @@ async def unload_loras():
await container.unload(loras_only=True)
async def load_embeddings_model(model_path: pathlib.Path, **kwargs):
global embeddings_container
# Break out if infinity isn't installed
if not has_infinity_emb:
logger.warning(
"Skipping embeddings because infinity-emb is not installed.\n"
"Please run the following command in your environment "
"to install extra packages:\n"
"pip install -U .[extras]"
)
return
embeddings_container = InfinityContainer(model_path)
await embeddings_container.load(**kwargs)
async def unload_embeddings_model():
global embeddings_container
await embeddings_container.unload()
embeddings_container = None
def get_config_default(key, fallback=None, is_draft=False):
"""Fetches a default value from model config if allowed by the user."""
@ -126,3 +159,15 @@ async def check_model_container():
).error.message
raise HTTPException(400, error_message)
async def check_embeddings_container():
"""FastAPI depends that checks if an embeddings model is loaded."""
if embeddings_container is None:
error_message = handle_request_error(
"No embeddings models are currently loaded.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)