API: Add setup function to routers

This helps prepare the router before exposing it to the parent app.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-07-26 22:24:33 -04:00
parent 6365427d38
commit 2773517a16
3 changed files with 33 additions and 18 deletions

View file

@ -29,9 +29,20 @@ urls = {
"Streaming": "http://{host}:{port}/api/extra/generate/stream",
}
kai_router = APIRouter()
extra_kai_router = APIRouter()
@router.post(
"/v1/generate",
def setup():
router.include_router(kai_router, prefix="/v1")
router.include_router(kai_router, prefix="/latest", include_in_schema=False)
router.include_router(extra_kai_router, prefix="/extra")
return router
@kai_router.post(
"/generate",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
@ -40,8 +51,8 @@ async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
return response
@router.post(
"/extra/generate/stream",
@extra_kai_router.post(
"/generate/stream",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse:
@ -50,8 +61,8 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateRe
return response
@router.post(
"/extra/abort",
@extra_kai_router.post(
"/abort",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def abort_generate(data: AbortRequest):
@ -60,12 +71,12 @@ async def abort_generate(data: AbortRequest):
return response
@router.get(
"/extra/generate/check",
@extra_kai_router.get(
"/generate/check",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@router.post(
"/extra/generate/check",
@extra_kai_router.post(
"/generate/check",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
@ -74,8 +85,8 @@ async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
return response
@router.get(
"/v1/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
@kai_router.get(
"/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
)
async def current_model():
"""Fetches the current model and who owns it."""
@ -84,8 +95,8 @@ async def current_model():
return {"result": f"{current_model_card.owned_by}/{current_model_card.id}"}
@router.post(
"/extra/tokencount",
@extra_kai_router.post(
"/tokencount",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_tokencount(data: TokenCountRequest):
@ -94,14 +105,14 @@ async def get_tokencount(data: TokenCountRequest):
return TokenCountResponse(value=len(tokens), ids=tokens)
@router.get("/v1/info/version")
@kai_router.get("/info/version")
async def get_version():
"""Impersonate KAI United."""
return {"result": "1.2.5"}
@router.get("/extra/version")
@extra_kai_router.get("/version")
async def get_extra_version():
"""Impersonate Koboldcpp."""

View file

@ -32,6 +32,10 @@ urls = {
}
def setup():
return router
# Completions endpoint
@router.post(
"/v1/completions",

View file

@ -47,14 +47,14 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
selected_server = router_mapping.get(server.lower())
if selected_server:
app.include_router(selected_server.router)
app.include_router(selected_server.setup())
logger.info(f"Starting {selected_server.api_name} API")
for path, url in selected_server.urls.items():
formatted_url = url.format(host=host, port=port)
logger.info(f"{path}: {formatted_url}")
else:
app.include_router(OAIRouter.router)
app.include_router(OAIRouter.setup())
for path, url in OAIRouter.urls.items():
formatted_url = url.format(host=host, port=port)
logger.info(f"{path}: {formatted_url}")