API: Add config option to select servers

Always enable the core endpoints and allow servers to be selected
as needed. Use the OAI server by default.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-07-23 14:26:15 -04:00
parent 9ad69e8ab6
commit 300f034233
2 changed files with 20 additions and 1 deletions

View file

@ -23,6 +23,10 @@ network:
# NOTE: Only enable this for debug purposes # NOTE: Only enable this for debug purposes
send_tracebacks: False send_tracebacks: False
# Select API servers to enable (default: ["OAI"])
# Possible values: OAI
api_servers: ["OAI"]
# Options for logging # Options for logging
logging: logging:
# Enable prompt logging (default: False) # Enable prompt logging (default: False)

View file

@ -1,10 +1,13 @@
from typing import List
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from loguru import logger from loguru import logger
from common import config
from common.logger import UVICORN_LOG_CONFIG from common.logger import UVICORN_LOG_CONFIG
from common.networking import get_global_depends from common.networking import get_global_depends
from common.utils import unwrap
from endpoints.core.router import router as CoreRouter from endpoints.core.router import router as CoreRouter
from endpoints.OAI.router import router as OAIRouter from endpoints.OAI.router import router as OAIRouter
@ -31,7 +34,19 @@ def setup_app():
allow_headers=["*"], allow_headers=["*"],
) )
app.include_router(OAIRouter) api_servers: List[str] = unwrap(config.network_config().get("api_servers"), [])
# Map for API id to server router
router_mapping = {"oai": OAIRouter}
# Include the OAI api by default
if api_servers:
for server in api_servers:
server_name = server.lower()
if server_name in router_mapping:
app.include_router(router_mapping[server_name])
else:
app.include_router(OAIRouter)
# Include core API request paths # Include core API request paths
app.include_router(CoreRouter) app.include_router(CoreRouter)