Embeddings: Migrate and organize Infinity
Use Infinity as a separate backend and handle the model within the common module. This separates out the embeddings model from the endpoint which allows for model loading/unloading in core. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
ac1afcc588
commit
fbf1455db1
6 changed files with 165 additions and 83 deletions
|
|
@ -20,6 +20,15 @@ if not do_export_openapi:
|
|||
|
||||
# Global model container
|
||||
container: Optional[ExllamaV2Container] = None
|
||||
embeddings_container = None
|
||||
|
||||
# Type hint the infinity emb container if it exists
|
||||
from backends.infinity.model import has_infinity_emb
|
||||
|
||||
if has_infinity_emb:
|
||||
from backends.infinity.model import InfinityContainer
|
||||
|
||||
embeddings_container: Optional[InfinityContainer] = None
|
||||
|
||||
|
||||
def load_progress(module, modules):
|
||||
|
|
@ -100,6 +109,30 @@ async def unload_loras():
|
|||
await container.unload(loras_only=True)
|
||||
|
||||
|
||||
async def load_embeddings_model(model_path: pathlib.Path, **kwargs):
|
||||
global embeddings_container
|
||||
|
||||
# Break out if infinity isn't installed
|
||||
if not has_infinity_emb:
|
||||
logger.warning(
|
||||
"Skipping embeddings because infinity-emb is not installed.\n"
|
||||
"Please run the following command in your environment "
|
||||
"to install extra packages:\n"
|
||||
"pip install -U .[extras]"
|
||||
)
|
||||
return
|
||||
|
||||
embeddings_container = InfinityContainer(model_path)
|
||||
await embeddings_container.load(**kwargs)
|
||||
|
||||
|
||||
async def unload_embeddings_model():
|
||||
global embeddings_container
|
||||
|
||||
await embeddings_container.unload()
|
||||
embeddings_container = None
|
||||
|
||||
|
||||
def get_config_default(key, fallback=None, is_draft=False):
|
||||
"""Fetches a default value from model config if allowed by the user."""
|
||||
|
||||
|
|
@ -126,3 +159,15 @@ async def check_model_container():
|
|||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
|
||||
async def check_embeddings_container():
|
||||
"""FastAPI depends that checks if an embeddings model is loaded."""
|
||||
|
||||
if embeddings_container is None:
|
||||
error_message = handle_request_error(
|
||||
"No embeddings models are currently loaded.",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue