From 300f0342337e0522e76c2cbd7ea7afdee7d8ccce Mon Sep 17 00:00:00 2001 From: kingbri Date: Tue, 23 Jul 2024 14:26:15 -0400 Subject: [PATCH] API: Add config option to select servers Always enable the core endpoints and allow servers to be selected as needed. Use the OAI server by default. Signed-off-by: kingbri --- config_sample.yml | 4 ++++ endpoints/server.py | 17 ++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/config_sample.yml b/config_sample.yml index 12458de..c5d6c9c 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -23,6 +23,10 @@ network: # NOTE: Only enable this for debug purposes send_tracebacks: False + # Select API servers to enable (default: ["OAI"]) + # Possible values: OAI + api_servers: ["OAI"] + # Options for logging logging: # Enable prompt logging (default: False) diff --git a/endpoints/server.py b/endpoints/server.py index dc501f5..dbe35de 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -1,10 +1,13 @@ +from typing import List import uvicorn from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from loguru import logger +from common import config from common.logger import UVICORN_LOG_CONFIG from common.networking import get_global_depends +from common.utils import unwrap from endpoints.core.router import router as CoreRouter from endpoints.OAI.router import router as OAIRouter @@ -31,7 +34,19 @@ def setup_app(): allow_headers=["*"], ) - app.include_router(OAIRouter) + api_servers: List[str] = unwrap(config.network_config().get("api_servers"), []) + + # Map for API id to server router + router_mapping = {"oai": OAIRouter} + + # Include the OAI api by default + if api_servers: + for server in api_servers: + server_name = server.lower() + if server_name in router_mapping: + app.include_router(router_mapping[server_name]) + else: + app.include_router(OAIRouter) # Include core API request paths app.include_router(CoreRouter)