OAI: Add return types for docs

Adding return type annotations allows response schemas to be included
in the autogenerated docs.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date:   2024-07-08 15:23:41 -04:00
Parent: 62e495fc13
Commit: 521d21b9f2


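For readers unfamiliar with the mechanism: since FastAPI 0.89, a handler's
return annotation is treated like an explicit response_model, so annotating
these endpoints is enough for their response schemas to show up in /docs and
/openapi.json. A minimal sketch of the idea, assuming a current FastAPI
install (the Greeting model and route are illustrative only, not tabbyAPI
code):

from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()

class Greeting(BaseModel):
    message: str

# Without the annotation, the autogenerated docs show a bare 200 response.
# With it, FastAPI treats "-> Greeting" like response_model=Greeting and
# documents (and validates) the JSON body.
@app.get("/greet")
async def greet() -> Greeting:
    return Greeting(message="hello")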
@@ -12,8 +12,11 @@ from common.networking import handle_request_error, run_with_request_disconnect
 from common.templating import PromptTemplate, get_all_templates
 from common.utils import coalesce, unwrap
 from endpoints.OAI.types.auth import AuthPermissionResponse
-from endpoints.OAI.types.completion import CompletionRequest
-from endpoints.OAI.types.chat_completion import ChatCompletionRequest
+from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
+from endpoints.OAI.types.chat_completion import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+)
 from endpoints.OAI.types.download import DownloadRequest, DownloadResponse
 from endpoints.OAI.types.lora import (
     LoraCard,
@@ -23,8 +26,10 @@ from endpoints.OAI.types.lora import (
 )
 from endpoints.OAI.types.model import (
     ModelCard,
     ModelList,
     ModelLoadRequest,
+    ModelCardParameters,
+    ModelLoadResponse,
 )
 from endpoints.OAI.types.sampler_overrides import (
     SamplerOverrideListResponse,
@@ -70,7 +75,7 @@ async def check_model_container():
 # Model list endpoint
 @router.get("/v1/models", dependencies=[Depends(check_api_key)])
 @router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
-async def list_models():
+async def list_models() -> ModelList:
     """Lists all models in the model directory."""
     model_config = config.model_config()
     model_dir = unwrap(model_config.get("model_dir"), "models")
@@ -90,7 +95,7 @@ async def list_models():
     "/v1/model",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def get_current_model():
+async def get_current_model() -> ModelCard:
     """Returns the currently loaded model."""
     model_params = model.container.get_model_parameters()
     draft_model_params = model_params.pop("draft", {})
@@ -121,7 +126,7 @@ async def get_current_model():
 @router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
-async def list_draft_models():
+async def list_draft_models() -> ModelList:
     """Lists all draft models in the model directory."""
     draft_model_dir = unwrap(
         config.draft_model_config().get("draft_model_dir"), "models"
@@ -135,8 +140,8 @@ async def list_draft_models():
 # Load model endpoint
 @router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
-async def load_model(data: ModelLoadRequest):
-    """Loads a model into the model container."""
+async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
+    """Loads a model into the model container. This returns an SSE stream."""
     # Verify request parameters
     if not data.name:
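Worth noting for the hunk above: load_model is annotated -> ModelLoadResponse
even though its docstring says it returns an SSE stream. At runtime, FastAPI
passes any returned Response subclass (such as StreamingResponse) through
untouched, while the annotation still feeds the documented schema, so the docs
describe the per-event payload. A rough sketch of that pattern, assuming
pydantic v2 (LoadProgress and the event framing are illustrative, not
tabbyAPI's actual types):

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()

class LoadProgress(BaseModel):
    module: int
    modules: int
    status: str

@app.post("/v1/model/load")
async def load_model() -> LoadProgress:
    # The annotation documents each event's JSON shape; the returned
    # StreamingResponse is sent as-is, skipping response-model validation.
    async def event_stream():
        for i in range(1, 3):
            progress = LoadProgress(module=i, modules=2, status="processing")
            yield f"data: {progress.model_dump_json()}\n\n"

    return StreamingResponse(event_stream(), media_type="text/event-stream")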
@@ -189,7 +194,7 @@ async def unload_model():
 @router.get("/v1/templates", dependencies=[Depends(check_api_key)])
 @router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
-async def get_templates():
+async def get_templates() -> TemplateList:
     templates = get_all_templates()
     template_strings = [template.stem for template in templates]
     return TemplateList(data=template_strings)
@@ -233,7 +238,7 @@ async def unload_template():
 # Sampler override endpoints
 @router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
 @router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
-async def list_sampler_overrides():
+async def list_sampler_overrides() -> SamplerOverrideListResponse:
     """API wrapper to list all currently applied sampler overrides"""
     return SamplerOverrideListResponse(
@@ -281,7 +286,7 @@ async def unload_sampler_override():
 @router.post("/v1/download", dependencies=[Depends(check_admin_key)])
-async def download_model(request: Request, data: DownloadRequest):
+async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
     """Downloads a model from HuggingFace."""
     try:
@@ -304,7 +309,7 @@ async def download_model(request: Request, data: DownloadRequest):
 # Lora list endpoint
 @router.get("/v1/loras", dependencies=[Depends(check_api_key)])
 @router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
-async def get_all_loras():
+async def get_all_loras() -> LoraList:
     """Lists all LoRAs in the lora directory."""
     lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
     loras = get_lora_list(lora_path.resolve())
@ -317,7 +322,7 @@ async def get_all_loras():
"/v1/lora",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_active_loras():
async def get_active_loras() -> LoraList:
"""Returns the currently loaded loras."""
active_loras = LoraList(
data=[
@@ -337,7 +342,7 @@ async def get_active_loras():
     "/v1/lora/load",
     dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
-async def load_lora(data: LoraLoadRequest):
+async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
     """Loads a LoRA into the model container."""
     if not data.loras:
@@ -383,7 +388,7 @@ async def unload_loras():
     "/v1/token/encode",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def encode_tokens(data: TokenEncodeRequest):
+async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
     """Encodes a string or chat completion messages into tokens."""
     if isinstance(data.text, str):
@@ -413,7 +418,7 @@ async def encode_tokens(data: TokenEncodeRequest):
     "/v1/token/decode",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def decode_tokens(data: TokenDecodeRequest):
+async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse:
     """Decodes tokens into a string."""
     message = model.container.decode_tokens(data.tokens, **data.get_params())
     response = TokenDecodeResponse(text=unwrap(message, ""))
@@ -426,7 +431,7 @@ async def get_key_permission(
     x_admin_key: Optional[str] = Header(None),
     x_api_key: Optional[str] = Header(None),
     authorization: Optional[str] = Header(None),
-):
+) -> AuthPermissionResponse:
     """
     Gets the access level/permission of a provided key in headers.
@ -452,8 +457,15 @@ async def get_key_permission(
"/v1/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def completion_request(request: Request, data: CompletionRequest):
"""Generates a completion from a prompt."""
async def completion_request(
request: Request, data: CompletionRequest
) -> CompletionResponse:
"""
Generates a completion from a prompt.
If stream = true, this returns an SSE stream.
"""
model_path = model.container.get_model_path()
if isinstance(data.prompt, list):
@ -488,8 +500,14 @@ async def completion_request(request: Request, data: CompletionRequest):
"/v1/chat/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def chat_completion_request(request: Request, data: ChatCompletionRequest):
"""Generates a chat completion from a prompt."""
async def chat_completion_request(
request: Request, data: ChatCompletionRequest
) -> ChatCompletionResponse:
"""
Generates a chat completion from a prompt.
If stream = true, this returns an SSE stream.
"""
if model.container.prompt_template is None:
error_message = handle_request_error(
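A quick way to confirm the effect once a server is running: fetch the
autogenerated OpenAPI schema and check that an endpoint's 200 response now
references the annotated model. A sketch assuming tabbyAPI's default local
address (host and port may differ in your config):

import httpx

# The 200 entry should now reference CompletionResponse instead of an
# empty schema.
schema = httpx.get("http://127.0.0.1:5000/openapi.json").json()
ok = schema["paths"]["/v1/completions"]["post"]["responses"]["200"]
print(ok["content"]["application/json"]["schema"])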