From 5bb4995a7c6d0d8fa486b4d4c3a3e062155afde5 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Sat, 6 Apr 2024 01:24:54 -0400
Subject: [PATCH] API: Move OAI to APIRouter

This makes the API more modular for other API implementations in the future.

Signed-off-by: kingbri
---
 endpoints/OAI/{app.py => router.py} | 89 +++++++++--------------------
 endpoints/server.py                 | 47 +++++++++++++++
 main.py                             |  2 +-
 3 files changed, 74 insertions(+), 64 deletions(-)
 rename endpoints/OAI/{app.py => router.py} (88%)
 create mode 100644 endpoints/server.py

diff --git a/endpoints/OAI/app.py b/endpoints/OAI/router.py
similarity index 88%
rename from endpoints/OAI/app.py
rename to endpoints/OAI/router.py
index 6c523d2..6d50e34 100644
--- a/endpoints/OAI/app.py
+++ b/endpoints/OAI/router.py
@@ -1,8 +1,6 @@
 import asyncio
 import pathlib
-import uvicorn
-from fastapi import FastAPI, Depends, HTTPException, Header, Request
-from fastapi.middleware.cors import CORSMiddleware
+from fastapi import APIRouter, Depends, HTTPException, Header, Request
 from functools import partial
 from loguru import logger
 from sse_starlette import EventSourceResponse
@@ -15,7 +13,6 @@ from common.concurrency import (
     call_with_semaphore,
     generate_with_semaphore,
 )
-from common.logger import UVICORN_LOG_CONFIG
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.templating import (
     get_all_templates,
@@ -56,23 +53,8 @@ from endpoints.OAI.utils.completion import (
 from endpoints.OAI.utils.model import get_model_list, stream_model_load
 from endpoints.OAI.utils.lora import get_lora_list
 
-app = FastAPI(
-    title="TabbyAPI",
-    summary="An OAI compatible exllamav2 API that's both lightweight and fast",
-    description=(
-        "This docs page is not meant to send requests! Please use a service "
-        "like Postman or a frontend UI."
-    ),
-)
 
-# ALlow CORS requests
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+router = APIRouter()
 
 
 async def check_model_container():
@@ -90,8 +72,8 @@ async def check_model_container():
 
 
 # Model list endpoint
-@app.get("/v1/models", dependencies=[Depends(check_api_key)])
-@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
+@router.get("/v1/models", dependencies=[Depends(check_api_key)])
+@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
 async def list_models():
     """Lists all models in the model directory."""
     model_config = config.model_config()
@@ -108,7 +90,7 @@ async def list_models():
 
 
 # Currently loaded model endpoint
-@app.get(
+@router.get(
     "/v1/model",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
@@ -142,7 +124,7 @@ async def get_current_model():
     return model_card
 
 
-@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
+@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
 async def list_draft_models():
     """Lists all draft models in the model directory."""
     draft_model_dir = unwrap(
@@ -156,7 +138,7 @@ async def list_draft_models():
 
 
 # Load model endpoint
-@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
+@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
 async def load_model(request: Request, data: ModelLoadRequest):
     """Loads a model into the model container."""
 
@@ -209,7 +191,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
 
 
 # Unload model endpoint
-@app.post(
+@router.post(
     "/v1/model/unload",
     dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
@@ -218,15 +200,15 @@ async def unload_model():
     await model.unload_model()
 
 
-@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
-@app.get("/v1/template/list", dependencies=[Depends(check_api_key)])
+@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
+@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
 async def get_templates():
     templates = get_all_templates()
     template_strings = list(map(lambda template: template.stem, templates))
     return TemplateList(data=template_strings)
 
 
-@app.post(
+@router.post(
     "/v1/template/switch",
     dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
@@ -252,7 +234,7 @@ async def switch_template(data: TemplateSwitchRequest):
         raise HTTPException(400, error_message) from e
 
 
-@app.post(
+@router.post(
     "/v1/template/unload",
     dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
@@ -263,15 +245,15 @@ async def unload_template():
 
 
 # Sampler override endpoints
-@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
-@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
+@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
+@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
 async def list_sampler_overrides():
     """API wrapper to list all currently applied sampler overrides"""
 
     return sampling.overrides
 
 
-@app.post(
+@router.post(
     "/v1/sampling/override/switch",
     dependencies=[Depends(check_admin_key)],
 )
@@ -300,7 +282,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
         raise HTTPException(400, error_message)
 
 
-@app.post(
+@router.post(
     "/v1/sampling/override/unload",
     dependencies=[Depends(check_admin_key)],
 )
@@ -311,8 +293,8 @@ async def unload_sampler_override():
 
 
 # Lora list endpoint
-@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
-@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
+@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
+@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
 async def get_all_loras():
     """Lists all LoRAs in the lora directory."""
     lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
@@ -322,7 +304,7 @@ async def get_all_loras():
 
 
 # Currently loaded loras endpoint
-@app.get(
+@router.get(
     "/v1/lora",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
@@ -344,7 +326,7 @@ async def get_active_loras():
 
 
 # Load lora endpoint
-@app.post(
+@router.post(
     "/v1/lora/load",
     dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
@@ -388,7 +370,7 @@ async def load_lora(data: LoraLoadRequest):
 
 
 # Unload lora endpoint
-@app.post(
+@router.post(
     "/v1/lora/unload",
     dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
@@ -399,7 +381,7 @@ async def unload_loras():
 
 
 # Encode tokens endpoint
-@app.post(
+@router.post(
     "/v1/token/encode",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
@@ -413,7 +395,7 @@ async def encode_tokens(data: TokenEncodeRequest):
 
 
 # Decode tokens endpoint
-@app.post(
+@router.post(
     "/v1/token/decode",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
@@ -425,7 +407,7 @@ async def decode_tokens(data: TokenDecodeRequest):
     return response
 
 
-@app.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
+@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
 async def get_key_permission(
     x_admin_key: Optional[str] = Header(None),
     x_api_key: Optional[str] = Header(None),
@@ -452,7 +434,7 @@ async def get_key_permission(
 
 
 # Completions endpoint
-@app.post(
+@router.post(
     "/v1/completions",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
@@ -488,7 +470,7 @@ async def completion_request(request: Request, data: CompletionRequest):
 
 
 # Chat completions endpoint
-@app.post(
+@router.post(
     "/v1/chat/completions",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
@@ -536,22 +518,3 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
         disconnect_message="Chat completion generation cancelled by user.",
     )
     return response
-
-
-async def start_api(host: str, port: int):
-    """Isolated function to start the API server"""
-
-    # TODO: Move OAI API to a separate folder
-    logger.info(f"Developer documentation: http://{host}:{port}/redoc")
-    logger.info(f"Completions: http://{host}:{port}/v1/completions")
-    logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
-
-    config = uvicorn.Config(
-        app,
-        host=host,
-        port=port,
-        log_config=UVICORN_LOG_CONFIG,
-    )
-    server = uvicorn.Server(config)
-
-    await server.serve()
diff --git a/endpoints/server.py b/endpoints/server.py
new file mode 100644
index 0000000..f6515c2
--- /dev/null
+++ b/endpoints/server.py
@@ -0,0 +1,47 @@
+import uvicorn
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from loguru import logger
+
+from common.logger import UVICORN_LOG_CONFIG
+from endpoints.OAI.router import router as OAIRouter
+
+app = FastAPI(
+    title="TabbyAPI",
+    summary="An OAI compatible exllamav2 API that's both lightweight and fast",
+    description=(
+        "This docs page is not meant to send requests! Please use a service "
+        "like Postman or a frontend UI."
+    ),
+)
+
+# ALlow CORS requests
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+async def start_api(host: str, port: int):
+    """Isolated function to start the API server"""
+
+    # TODO: Move OAI API to a separate folder
+    logger.info(f"Developer documentation: http://{host}:{port}/redoc")
+    logger.info(f"Completions: http://{host}:{port}/v1/completions")
+    logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
+
+    # Add OAI router
+    app.include_router(OAIRouter)
+
+    config = uvicorn.Config(
+        app,
+        host=host,
+        port=port,
+        log_config=UVICORN_LOG_CONFIG,
+    )
+    server = uvicorn.Server(config)
+
+    await server.serve()
diff --git a/main.py b/main.py
index 5b95d87..330c736 100644
--- a/main.py
+++ b/main.py
@@ -15,7 +15,7 @@ from common.logger import setup_logger
 from common.networking import is_port_in_use
 from common.signals import signal_handler
 from common.utils import unwrap
-from endpoints.OAI.app import start_api
+from endpoints.server import start_api
 
 
 async def entrypoint(args: Optional[dict] = None):
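
A minimal sketch (not part of the patch) of what the new layout enables: since the OAI endpoints now hang off an APIRouter, another API implementation could ship its own router module and be mounted on the same FastAPI app in endpoints/server.py. The Kobold-style module, its prefix, and its endpoint below are hypothetical placeholders; only the OAI router and the app.include_router(OAIRouter) call exist in this patch.

    # hypothetical endpoints/Kobold/router.py (illustration only)
    from fastapi import APIRouter

    # A second, independent router following the same pattern as the OAI one
    router = APIRouter(prefix="/api/v1", tags=["kobold"])


    @router.get("/info")
    async def info():
        """Placeholder endpoint for a second, non-OAI API implementation."""
        return {"result": "hypothetical Kobold-compatible API"}


    # hypothetical additions in endpoints/server.py, next to the existing call:
    # from endpoints.Kobold.router import router as KoboldRouter
    # app.include_router(OAIRouter)
    # app.include_router(KoboldRouter)

Because include_router copies each route's own dependencies, the per-endpoint check_api_key / check_admin_key checks keep working unchanged no matter how many routers are mounted.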