From 279e900ea5ec095eff97d2996a966a1ea0aa663f Mon Sep 17 00:00:00 2001
From: Colin Kealty
Date: Tue, 4 Jun 2024 13:35:48 -0400
Subject: [PATCH] Add on the fly model loading to requests

---
 endpoints/OAI/router.py                | 57 +++++++++++++++++++++++++-
 endpoints/OAI/types/chat_completion.py |  1 +
 2 files changed, 57 insertions(+), 1 deletion(-)

diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index 0e4f27b..f4cc516 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -1,5 +1,6 @@
 import asyncio
 import pathlib
+from loguru import logger
 from fastapi import APIRouter, Depends, HTTPException, Header, Request
 from sse_starlette import EventSourceResponse
 from sys import maxsize
@@ -118,7 +119,7 @@ async def completion_request(
 # Chat completions endpoint
 @router.post(
     "/v1/chat/completions",
-    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+    dependencies=[Depends(check_api_key)],
 )
 async def chat_completion_request(
     request: Request, data: ChatCompletionRequest
@@ -129,6 +130,60 @@ async def chat_completion_request(
     If stream = true, this returns an SSE stream.
     """
 
+    # A request may name a model other than the one currently loaded. Swapping
+    # models is an admin-only action, so the key is re-checked here instead of
+    # relying on the route-level API key dependency.
+    if data.model is not None and (
+        model.container is None or model.container.get_model_path().name != data.model
+    ):
+        admin_valid = False
+        if "x_admin_key" in request.headers:
+            try:
+                await check_admin_key(
+                    x_admin_key=request.headers.get("x_admin_key"), authorization=None
+                )
+                admin_valid = True
+            except HTTPException:
+                # Not an admin key; fall through to the authorization header.
+                pass
+
+        if not admin_valid and "authorization" in request.headers:
+            try:
+                await check_admin_key(
+                    x_admin_key=None, authorization=request.headers.get("authorization")
+                )
+                admin_valid = True
+            except HTTPException:
+                # Not an admin key; the request proceeds without a model swap.
+                pass
+
+        if admin_valid:
+            logger.info(
+                f"New request for {data.model} which is not loaded, proper admin key provided, loading new model"
+            )
+
+            model_path = pathlib.Path(
+                unwrap(config.model_config().get("model_dir"), "models")
+            )
+            model_path = model_path / data.model
+
+            if not model_path.exists():
+                error_message = handle_request_error(
+                    "Could not find the model path for load. Check model name.",
+                    exc_info=False,
+                ).error.message
+
+                raise HTTPException(400, error_message)
+
+            await model.load_model(model_path)
+        else:
+            logger.info("No valid admin key found to change loaded model, ignoring")
+            # No swap happened, so a model must already be loaded; without this
+            # check a missing container is dereferenced below.
+            await check_model_container()
+    else:
+        await check_model_container()
+
     if model.container.prompt_template is None:
         error_message = handle_request_error(
             "Chat completions are disabled because a prompt template is not set.",
diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index be5cfea..b66277b 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -47,6 +47,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
     response_prefix: Optional[str] = None
+    model: Optional[str] = None
 
 
 class ChatCompletionResponse(BaseModel):