API: Move OAI to APIRouter
This makes the API more modular for other API implementations in the future.

Signed-off-by: kingbri <bdashore3@proton.me>
commit 5bb4995a7c (parent 8bdc19124f)
3 changed files with 74 additions and 64 deletions
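For context, the change follows FastAPI's stock `APIRouter` composition pattern: the endpoint module registers its routes on a router object, and the server module owns the `FastAPI` app and mounts that router with `include_router`. A minimal sketch of the idea, trimmed to one illustrative route (the handler body is a placeholder, not TabbyAPI code):

from fastapi import APIRouter, FastAPI

# Endpoint module: routes are registered on a router, not on the app itself
router = APIRouter()


@router.get("/v1/models")
async def list_models():
    # Placeholder body; the real endpoint builds a model list from the model directory
    return {"object": "list", "data": []}


# Server module: owns the FastAPI app and mounts whichever routers it needs
app = FastAPI(title="TabbyAPI")
app.include_router(router)

Because the app only sees routers through `include_router`, other API implementations can later be added as sibling routers without touching the OAI endpoint code, which is the modularity the commit message refers to.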
endpoints/OAI/app.py → endpoints/OAI/router.py

@@ -1,8 +1,6 @@
 import asyncio
 import pathlib
-import uvicorn
-from fastapi import FastAPI, Depends, HTTPException, Header, Request
-from fastapi.middleware.cors import CORSMiddleware
+from fastapi import APIRouter, Depends, HTTPException, Header, Request
 from functools import partial
 from loguru import logger
 from sse_starlette import EventSourceResponse
@@ -15,7 +13,6 @@ from common.concurrency import (
     call_with_semaphore,
     generate_with_semaphore,
 )
-from common.logger import UVICORN_LOG_CONFIG
 from common.networking import handle_request_error, run_with_request_disconnect
 from common.templating import (
     get_all_templates,
@@ -56,23 +53,8 @@ from endpoints.OAI.utils.completion import (
 from endpoints.OAI.utils.model import get_model_list, stream_model_load
 from endpoints.OAI.utils.lora import get_lora_list

-app = FastAPI(
-    title="TabbyAPI",
-    summary="An OAI compatible exllamav2 API that's both lightweight and fast",
-    description=(
-        "This docs page is not meant to send requests! Please use a service "
-        "like Postman or a frontend UI."
-    ),
-)
-
-# ALlow CORS requests
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
+router = APIRouter()
+


 async def check_model_container():
@@ -90,8 +72,8 @@ async def check_model_container():


 # Model list endpoint
-@app.get("/v1/models", dependencies=[Depends(check_api_key)])
-@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
+@router.get("/v1/models", dependencies=[Depends(check_api_key)])
+@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
 async def list_models():
     """Lists all models in the model directory."""
     model_config = config.model_config()
@@ -108,7 +90,7 @@ async def list_models():


 # Currently loaded model endpoint
-@app.get(
+@router.get(
     "/v1/model",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
@@ -142,7 +124,7 @@ async def get_current_model():
     return model_card


-@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
+@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
 async def list_draft_models():
     """Lists all draft models in the model directory."""
     draft_model_dir = unwrap(
@@ -156,7 +138,7 @@ async def list_draft_models():


 # Load model endpoint
-@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
+@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
 async def load_model(request: Request, data: ModelLoadRequest):
     """Loads a model into the model container."""

@@ -209,7 +191,7 @@ async def load_model(request: Request, data: ModelLoadRequest):


 # Unload model endpoint
-@app.post(
+@router.post(
     "/v1/model/unload",
     dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
@@ -218,15 +200,15 @@ async def unload_model():
     await model.unload_model()


-@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
-@app.get("/v1/template/list", dependencies=[Depends(check_api_key)])
+@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
+@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
 async def get_templates():
     templates = get_all_templates()
     template_strings = list(map(lambda template: template.stem, templates))
     return TemplateList(data=template_strings)


-@app.post(
+@router.post(
     "/v1/template/switch",
     dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
@@ -252,7 +234,7 @@ async def switch_template(data: TemplateSwitchRequest):
         raise HTTPException(400, error_message) from e


-@app.post(
+@router.post(
     "/v1/template/unload",
     dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
@@ -263,15 +245,15 @@ async def unload_template():


 # Sampler override endpoints
-@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
-@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
+@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
+@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
 async def list_sampler_overrides():
     """API wrapper to list all currently applied sampler overrides"""

     return sampling.overrides


-@app.post(
+@router.post(
     "/v1/sampling/override/switch",
     dependencies=[Depends(check_admin_key)],
 )
@@ -300,7 +282,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
         raise HTTPException(400, error_message)


-@app.post(
+@router.post(
     "/v1/sampling/override/unload",
     dependencies=[Depends(check_admin_key)],
 )
@@ -311,8 +293,8 @@ async def unload_sampler_override():


 # Lora list endpoint
-@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
-@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
+@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
+@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
 async def get_all_loras():
     """Lists all LoRAs in the lora directory."""
     lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
@@ -322,7 +304,7 @@ async def get_all_loras():


 # Currently loaded loras endpoint
-@app.get(
+@router.get(
     "/v1/lora",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
@@ -344,7 +326,7 @@ async def get_active_loras():


 # Load lora endpoint
-@app.post(
+@router.post(
     "/v1/lora/load",
     dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
@@ -388,7 +370,7 @@ async def load_lora(data: LoraLoadRequest):


 # Unload lora endpoint
-@app.post(
+@router.post(
     "/v1/lora/unload",
     dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
@@ -399,7 +381,7 @@ async def unload_loras():


 # Encode tokens endpoint
-@app.post(
+@router.post(
     "/v1/token/encode",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
@@ -413,7 +395,7 @@ async def encode_tokens(data: TokenEncodeRequest):


 # Decode tokens endpoint
-@app.post(
+@router.post(
     "/v1/token/decode",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
@@ -425,7 +407,7 @@ async def decode_tokens(data: TokenDecodeRequest):
     return response


-@app.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
+@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
 async def get_key_permission(
     x_admin_key: Optional[str] = Header(None),
     x_api_key: Optional[str] = Header(None),
@@ -452,7 +434,7 @@ async def get_key_permission(


 # Completions endpoint
-@app.post(
+@router.post(
     "/v1/completions",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
@@ -488,7 +470,7 @@ async def completion_request(request: Request, data: CompletionRequest):


 # Chat completions endpoint
-@app.post(
+@router.post(
     "/v1/chat/completions",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
@@ -536,22 +518,3 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
         disconnect_message="Chat completion generation cancelled by user.",
     )
     return response
-
-
-async def start_api(host: str, port: int):
-    """Isolated function to start the API server"""
-
-    # TODO: Move OAI API to a separate folder
-    logger.info(f"Developer documentation: http://{host}:{port}/redoc")
-    logger.info(f"Completions: http://{host}:{port}/v1/completions")
-    logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
-
-    config = uvicorn.Config(
-        app,
-        host=host,
-        port=port,
-        log_config=UVICORN_LOG_CONFIG,
-    )
-    server = uvicorn.Server(config)
-
-    await server.serve()
endpoints/server.py (new file, 47 additions)
@@ -0,0 +1,47 @@
+import uvicorn
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from loguru import logger
+
+from common.logger import UVICORN_LOG_CONFIG
+from endpoints.OAI.router import router as OAIRouter
+
+app = FastAPI(
+    title="TabbyAPI",
+    summary="An OAI compatible exllamav2 API that's both lightweight and fast",
+    description=(
+        "This docs page is not meant to send requests! Please use a service "
+        "like Postman or a frontend UI."
+    ),
+)
+
+# ALlow CORS requests
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+
+async def start_api(host: str, port: int):
+    """Isolated function to start the API server"""
+
+    # TODO: Move OAI API to a separate folder
+    logger.info(f"Developer documentation: http://{host}:{port}/redoc")
+    logger.info(f"Completions: http://{host}:{port}/v1/completions")
+    logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
+
+    # Add OAI router
+    app.include_router(OAIRouter)
+
+    config = uvicorn.Config(
+        app,
+        host=host,
+        port=port,
+        log_config=UVICORN_LOG_CONFIG,
+    )
+    server = uvicorn.Server(config)
+
+    await server.serve()
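One property worth noting in the new endpoints/server.py: the OAI router is created with a bare `APIRouter()` and mounted without a prefix, so every route keeps the URL it had when it was registered directly on the app. A small self-contained sketch (not TabbyAPI code; the route and payload are stand-ins) showing that an included router serves its paths unchanged:

from fastapi import APIRouter, FastAPI
from fastapi.testclient import TestClient

router = APIRouter()  # no prefix, so paths stay exactly as written


@router.get("/v1/model/list")
async def list_models():
    # Stand-in payload for the illustration
    return {"data": []}


app = FastAPI()
app.include_router(router)

# The route is reachable at the same URL as before the refactor
client = TestClient(app)
assert client.get("/v1/model/list").status_code == 200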
main.py (2 changed lines: 1 addition, 1 deletion)
@@ -15,7 +15,7 @@ from common.logger import setup_logger
 from common.networking import is_port_in_use
 from common.signals import signal_handler
 from common.utils import unwrap
-from endpoints.OAI.app import start_api
+from endpoints.server import start_api


 async def entrypoint(args: Optional[dict] = None):