API: Move OAI to APIRouter

This makes the API more modular for other API implementations in the
future.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-04-06 01:24:54 -04:00
parent 8bdc19124f
commit 5bb4995a7c
3 changed files with 74 additions and 64 deletions

View file

@ -1,8 +1,6 @@
import asyncio
import pathlib
import uvicorn
from fastapi import FastAPI, Depends, HTTPException, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from functools import partial
from loguru import logger
from sse_starlette import EventSourceResponse
@ -15,7 +13,6 @@ from common.concurrency import (
call_with_semaphore,
generate_with_semaphore,
)
from common.logger import UVICORN_LOG_CONFIG
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import (
get_all_templates,
@ -56,23 +53,8 @@ from endpoints.OAI.utils.completion import (
from endpoints.OAI.utils.model import get_model_list, stream_model_load
from endpoints.OAI.utils.lora import get_lora_list
app = FastAPI(
title="TabbyAPI",
summary="An OAI compatible exllamav2 API that's both lightweight and fast",
description=(
"This docs page is not meant to send requests! Please use a service "
"like Postman or a frontend UI."
),
)
# ALlow CORS requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
router = APIRouter()
async def check_model_container():
@ -90,8 +72,8 @@ async def check_model_container():
# Model list endpoint
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
"""Lists all models in the model directory."""
model_config = config.model_config()
@ -108,7 +90,7 @@ async def list_models():
# Currently loaded model endpoint
@app.get(
@router.get(
"/v1/model",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@ -142,7 +124,7 @@ async def get_current_model():
return model_card
@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models():
"""Lists all draft models in the model directory."""
draft_model_dir = unwrap(
@ -156,7 +138,7 @@ async def list_draft_models():
# Load model endpoint
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(request: Request, data: ModelLoadRequest):
"""Loads a model into the model container."""
@ -209,7 +191,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
# Unload model endpoint
@app.post(
@router.post(
"/v1/model/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
@ -218,15 +200,15 @@ async def unload_model():
await model.unload_model()
@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
@app.get("/v1/template/list", dependencies=[Depends(check_api_key)])
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def get_templates():
templates = get_all_templates()
template_strings = list(map(lambda template: template.stem, templates))
return TemplateList(data=template_strings)
@app.post(
@router.post(
"/v1/template/switch",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
@ -252,7 +234,7 @@ async def switch_template(data: TemplateSwitchRequest):
raise HTTPException(400, error_message) from e
@app.post(
@router.post(
"/v1/template/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
@ -263,15 +245,15 @@ async def unload_template():
# Sampler override endpoints
@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides():
"""API wrapper to list all currently applied sampler overrides"""
return sampling.overrides
@app.post(
@router.post(
"/v1/sampling/override/switch",
dependencies=[Depends(check_admin_key)],
)
@ -300,7 +282,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
raise HTTPException(400, error_message)
@app.post(
@router.post(
"/v1/sampling/override/unload",
dependencies=[Depends(check_admin_key)],
)
@ -311,8 +293,8 @@ async def unload_sampler_override():
# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def get_all_loras():
"""Lists all LoRAs in the lora directory."""
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
@ -322,7 +304,7 @@ async def get_all_loras():
# Currently loaded loras endpoint
@app.get(
@router.get(
"/v1/lora",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@ -344,7 +326,7 @@ async def get_active_loras():
# Load lora endpoint
@app.post(
@router.post(
"/v1/lora/load",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
@ -388,7 +370,7 @@ async def load_lora(data: LoraLoadRequest):
# Unload lora endpoint
@app.post(
@router.post(
"/v1/lora/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
@ -399,7 +381,7 @@ async def unload_loras():
# Encode tokens endpoint
@app.post(
@router.post(
"/v1/token/encode",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@ -413,7 +395,7 @@ async def encode_tokens(data: TokenEncodeRequest):
# Decode tokens endpoint
@app.post(
@router.post(
"/v1/token/decode",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@ -425,7 +407,7 @@ async def decode_tokens(data: TokenDecodeRequest):
return response
@app.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
async def get_key_permission(
x_admin_key: Optional[str] = Header(None),
x_api_key: Optional[str] = Header(None),
@ -452,7 +434,7 @@ async def get_key_permission(
# Completions endpoint
@app.post(
@router.post(
"/v1/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@ -488,7 +470,7 @@ async def completion_request(request: Request, data: CompletionRequest):
# Chat completions endpoint
@app.post(
@router.post(
"/v1/chat/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@ -536,22 +518,3 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
disconnect_message="Chat completion generation cancelled by user.",
)
return response
async def start_api(host: str, port: int):
"""Isolated function to start the API server"""
# TODO: Move OAI API to a separate folder
logger.info(f"Developer documentation: http://{host}:{port}/redoc")
logger.info(f"Completions: http://{host}:{port}/v1/completions")
logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
config = uvicorn.Config(
app,
host=host,
port=port,
log_config=UVICORN_LOG_CONFIG,
)
server = uvicorn.Server(config)
await server.serve()

47
endpoints/server.py Normal file
View file

@ -0,0 +1,47 @@
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from common.logger import UVICORN_LOG_CONFIG
from endpoints.OAI.router import router as OAIRouter
app = FastAPI(
title="TabbyAPI",
summary="An OAI compatible exllamav2 API that's both lightweight and fast",
description=(
"This docs page is not meant to send requests! Please use a service "
"like Postman or a frontend UI."
),
)
# ALlow CORS requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
async def start_api(host: str, port: int):
"""Isolated function to start the API server"""
# TODO: Move OAI API to a separate folder
logger.info(f"Developer documentation: http://{host}:{port}/redoc")
logger.info(f"Completions: http://{host}:{port}/v1/completions")
logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")
# Add OAI router
app.include_router(OAIRouter)
config = uvicorn.Config(
app,
host=host,
port=port,
log_config=UVICORN_LOG_CONFIG,
)
server = uvicorn.Server(config)
await server.serve()

View file

@ -15,7 +15,7 @@ from common.logger import setup_logger
from common.networking import is_port_in_use
from common.signals import signal_handler
from common.utils import unwrap
from endpoints.OAI.app import start_api
from endpoints.server import start_api
async def entrypoint(args: Optional[dict] = None):