API: Add config option to select servers
Always enable the core endpoints and allow servers to be selected as needed. Use the OAI server by default. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
9ad69e8ab6
commit
300f034233
2 changed files with 20 additions and 1 deletions
|
|
@ -23,6 +23,10 @@ network:
|
||||||
# NOTE: Only enable this for debug purposes
|
# NOTE: Only enable this for debug purposes
|
||||||
send_tracebacks: False
|
send_tracebacks: False
|
||||||
|
|
||||||
|
# Select API servers to enable (default: ["OAI"])
|
||||||
|
# Possible values: OAI
|
||||||
|
api_servers: ["OAI"]
|
||||||
|
|
||||||
# Options for logging
|
# Options for logging
|
||||||
logging:
|
logging:
|
||||||
# Enable prompt logging (default: False)
|
# Enable prompt logging (default: False)
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,13 @@
|
||||||
|
from typing import List
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from loguru import logger
|
from loguru import logger
|
||||||
|
|
||||||
|
from common import config
|
||||||
from common.logger import UVICORN_LOG_CONFIG
|
from common.logger import UVICORN_LOG_CONFIG
|
||||||
from common.networking import get_global_depends
|
from common.networking import get_global_depends
|
||||||
|
from common.utils import unwrap
|
||||||
from endpoints.core.router import router as CoreRouter
|
from endpoints.core.router import router as CoreRouter
|
||||||
from endpoints.OAI.router import router as OAIRouter
|
from endpoints.OAI.router import router as OAIRouter
|
||||||
|
|
||||||
|
|
@ -31,7 +34,19 @@ def setup_app():
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
app.include_router(OAIRouter)
|
api_servers: List[str] = unwrap(config.network_config().get("api_servers"), [])
|
||||||
|
|
||||||
|
# Map for API id to server router
|
||||||
|
router_mapping = {"oai": OAIRouter}
|
||||||
|
|
||||||
|
# Include the OAI api by default
|
||||||
|
if api_servers:
|
||||||
|
for server in api_servers:
|
||||||
|
server_name = server.lower()
|
||||||
|
if server_name in router_mapping:
|
||||||
|
app.include_router(router_mapping[server_name])
|
||||||
|
else:
|
||||||
|
app.include_router(OAIRouter)
|
||||||
|
|
||||||
# Include core API request paths
|
# Include core API request paths
|
||||||
app.include_router(CoreRouter)
|
app.include_router(CoreRouter)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue