Embeddings: Add model load checks

Same as the normal model container.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-07-30 11:17:36 -04:00
parent 01c7702859
commit f13d0fb8b3
3 changed files with 17 additions and 4 deletions

View file

@ -18,6 +18,8 @@ except ImportError:
class InfinityContainer:
model_dir: pathlib.Path
model_is_loading: bool = False
model_loaded: bool = False
# Conditionally set the type hint based on importablity
# TODO: Clean this up
@ -30,6 +32,8 @@ class InfinityContainer:
self.model_dir = model_directory
async def load(self, **kwargs):
self.model_is_loading = True
# Use cpu by default
device = unwrap(kwargs.get("device"), "cpu")
@ -44,6 +48,9 @@ class InfinityContainer:
self.engine = AsyncEmbeddingEngine.from_args(engine_args)
await self.engine.astart()
self.model_loaded = True
logger.info("Embedding model successfully loaded.")
async def unload(self):
await self.engine.astop()
self.engine = None

View file

@ -162,9 +162,15 @@ async def check_model_container():
async def check_embeddings_container():
"""FastAPI depends that checks if an embeddings model is loaded."""
"""
FastAPI depends that checks if an embeddings model is loaded.
if embeddings_container is None:
This is the same as the model container check, but with embeddings instead.
"""
if embeddings_container is None or not (
embeddings_container.model_is_loading or embeddings_container.model_loaded
):
error_message = handle_request_error(
"No embeddings models are currently loaded.",
exc_info=False,

View file

@ -5,7 +5,7 @@ from sys import maxsize
from common import config, model
from common.auth import check_api_key
from common.model import check_model_container
from common.model import check_embeddings_container, check_model_container
from common.networking import handle_request_error, run_with_request_disconnect
from common.utils import unwrap
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
@ -132,7 +132,7 @@ async def chat_completion_request(
# Embeddings endpoint
@router.post(
"/v1/embeddings",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
dependencies=[Depends(check_api_key), Depends(check_embeddings_container)],
)
async def embeddings(request: Request, data: EmbeddingsRequest) -> EmbeddingsResponse:
embeddings_task = asyncio.create_task(get_embeddings(data, request))