API: Add setup function to routers
This helps prepare the router before exposing it to the parent app. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
6365427d38
commit
2773517a16
3 changed files with 33 additions and 18 deletions
|
|
@ -29,9 +29,20 @@ urls = {
|
|||
"Streaming": "http://{host}:{port}/api/extra/generate/stream",
|
||||
}
|
||||
|
||||
kai_router = APIRouter()
|
||||
extra_kai_router = APIRouter()
|
||||
|
||||
@router.post(
|
||||
"/v1/generate",
|
||||
|
||||
def setup():
|
||||
router.include_router(kai_router, prefix="/v1")
|
||||
router.include_router(kai_router, prefix="/latest", include_in_schema=False)
|
||||
router.include_router(extra_kai_router, prefix="/extra")
|
||||
|
||||
return router
|
||||
|
||||
|
||||
@kai_router.post(
|
||||
"/generate",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
|
||||
|
|
@ -40,8 +51,8 @@ async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
|
|||
return response
|
||||
|
||||
|
||||
@router.post(
|
||||
"/extra/generate/stream",
|
||||
@extra_kai_router.post(
|
||||
"/generate/stream",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse:
|
||||
|
|
@ -50,8 +61,8 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateRe
|
|||
return response
|
||||
|
||||
|
||||
@router.post(
|
||||
"/extra/abort",
|
||||
@extra_kai_router.post(
|
||||
"/abort",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def abort_generate(data: AbortRequest):
|
||||
|
|
@ -60,12 +71,12 @@ async def abort_generate(data: AbortRequest):
|
|||
return response
|
||||
|
||||
|
||||
@router.get(
|
||||
"/extra/generate/check",
|
||||
@extra_kai_router.get(
|
||||
"/generate/check",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
@router.post(
|
||||
"/extra/generate/check",
|
||||
@extra_kai_router.post(
|
||||
"/generate/check",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
|
||||
|
|
@ -74,8 +85,8 @@ async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
|
|||
return response
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v1/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
|
||||
@kai_router.get(
|
||||
"/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
|
||||
)
|
||||
async def current_model():
|
||||
"""Fetches the current model and who owns it."""
|
||||
|
|
@ -84,8 +95,8 @@ async def current_model():
|
|||
return {"result": f"{current_model_card.owned_by}/{current_model_card.id}"}
|
||||
|
||||
|
||||
@router.post(
|
||||
"/extra/tokencount",
|
||||
@extra_kai_router.post(
|
||||
"/tokencount",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def get_tokencount(data: TokenCountRequest):
|
||||
|
|
@ -94,14 +105,14 @@ async def get_tokencount(data: TokenCountRequest):
|
|||
return TokenCountResponse(value=len(tokens), ids=tokens)
|
||||
|
||||
|
||||
@router.get("/v1/info/version")
|
||||
@kai_router.get("/info/version")
|
||||
async def get_version():
|
||||
"""Impersonate KAI United."""
|
||||
|
||||
return {"result": "1.2.5"}
|
||||
|
||||
|
||||
@router.get("/extra/version")
|
||||
@extra_kai_router.get("/version")
|
||||
async def get_extra_version():
|
||||
"""Impersonate Koboldcpp."""
|
||||
|
||||
|
|
|
|||
|
|
@ -32,6 +32,10 @@ urls = {
|
|||
}
|
||||
|
||||
|
||||
def setup():
|
||||
return router
|
||||
|
||||
|
||||
# Completions endpoint
|
||||
@router.post(
|
||||
"/v1/completions",
|
||||
|
|
|
|||
|
|
@ -47,14 +47,14 @@ def setup_app(host: Optional[str] = None, port: Optional[int] = None):
|
|||
selected_server = router_mapping.get(server.lower())
|
||||
|
||||
if selected_server:
|
||||
app.include_router(selected_server.router)
|
||||
app.include_router(selected_server.setup())
|
||||
|
||||
logger.info(f"Starting {selected_server.api_name} API")
|
||||
for path, url in selected_server.urls.items():
|
||||
formatted_url = url.format(host=host, port=port)
|
||||
logger.info(f"{path}: {formatted_url}")
|
||||
else:
|
||||
app.include_router(OAIRouter.router)
|
||||
app.include_router(OAIRouter.setup())
|
||||
for path, url in OAIRouter.urls.items():
|
||||
formatted_url = url.format(host=host, port=port)
|
||||
logger.info(f"{path}: {formatted_url}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue