From 64c2cc85c9caa718bb66fac9687caa4584790364 Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 23 Jul 2024 12:40:32 -0400 Subject: [PATCH] OAI: Migrate model depends into proper file Use amongst multiple routers. Signed-off-by: kingbri --- common/model.py | 14 ++++++++++++++ endpoints/OAI/router.py | 15 +-------------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/common/model.py b/common/model.py index b925f15..a6477c2 100644 --- a/common/model.py +++ b/common/model.py @@ -5,11 +5,13 @@ Containers exist as a common interface for backends. """ import pathlib +from fastapi import HTTPException from loguru import logger from typing import Optional from common import config from common.logger import get_loading_progress_bar +from common.networking import handle_request_error from common.utils import unwrap from endpoints.utils import do_export_openapi @@ -112,3 +114,15 @@ def get_config_default(key, fallback=None, is_draft=False): return unwrap(model_config.get(key), fallback) else: return fallback + + +async def check_model_container(): + """FastAPI depends that checks if a model isn't loaded or currently loading.""" + + if container is None or not (container.model_is_loading or container.model_loaded): + error_message = handle_request_error( + "No models are currently loaded.", + exc_info=False, + ).error.message + + raise HTTPException(400, error_message) diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 1297d87..4269c89 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -7,6 +7,7 @@ from sys import maxsize from common import config, model, sampling from common.auth import check_admin_key, check_api_key, get_key_permission from common.downloader import hf_repo_download +from common.model import check_model_container from common.networking import handle_request_error, run_with_request_disconnect from common.templating import PromptTemplate, get_all_templates from common.utils import unwrap @@ -60,20 +61,6 @@ from endpoints.OAI.utils.lora import get_active_loras, get_lora_list router = APIRouter() -async def check_model_container(): - """FastAPI depends that checks if a model isn't loaded or currently loading.""" - - if model.container is None or not ( - model.container.model_is_loading or model.container.model_loaded - ): - error_message = handle_request_error( - "No models are currently loaded.", - exc_info=False, - ).error.message - - raise HTTPException(400, error_message) - - # Completions endpoint @router.post( "/v1/completions",