OAI: Add return types for docs
Adding return types allows responses to be included in the autogenerated docs.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 62e495fc13
commit 521d21b9f2
1 changed file with 38 additions and 20 deletions
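For context on why the annotations matter: FastAPI treats a path operation's return annotation as its response_model, and that model is what produces the response schema in the autogenerated OpenAPI docs. Below is a minimal sketch of the effect, using stand-in ModelCard/ModelList classes rather than tabbyAPI's actual types.

# Minimal sketch (stand-in models, not tabbyAPI code): the return annotation is
# picked up as the response_model, so the 200 response schema appears at /docs.
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()


class ModelCard(BaseModel):
    id: str
    object: str = "model"


class ModelList(BaseModel):
    object: str = "list"
    data: list[ModelCard] = []


@app.get("/v1/models")
async def list_models() -> ModelList:
    # Without the "-> ModelList" annotation (or an explicit response_model),
    # the docs show no response schema for this endpoint.
    return ModelList(data=[ModelCard(id="example")])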
@@ -12,8 +12,11 @@ from common.networking import handle_request_error, run_with_request_disconnect
 from common.templating import PromptTemplate, get_all_templates
 from common.utils import coalesce, unwrap
 from endpoints.OAI.types.auth import AuthPermissionResponse
-from endpoints.OAI.types.completion import CompletionRequest
-from endpoints.OAI.types.chat_completion import ChatCompletionRequest
+from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
+from endpoints.OAI.types.chat_completion import (
+    ChatCompletionRequest,
+    ChatCompletionResponse,
+)
 from endpoints.OAI.types.download import DownloadRequest, DownloadResponse
 from endpoints.OAI.types.lora import (
     LoraCard,
@@ -23,8 +26,10 @@ from endpoints.OAI.types.lora import (
 )
 from endpoints.OAI.types.model import (
     ModelCard,
+    ModelList,
     ModelLoadRequest,
     ModelCardParameters,
+    ModelLoadResponse,
 )
 from endpoints.OAI.types.sampler_overrides import (
     SamplerOverrideListResponse,
@@ -70,7 +75,7 @@ async def check_model_container():
 # Model list endpoint
 @router.get("/v1/models", dependencies=[Depends(check_api_key)])
 @router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
-async def list_models():
+async def list_models() -> ModelList:
     """Lists all models in the model directory."""
     model_config = config.model_config()
     model_dir = unwrap(model_config.get("model_dir"), "models")
@@ -90,7 +95,7 @@ async def list_models():
     "/v1/model",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def get_current_model():
+async def get_current_model() -> ModelCard:
     """Returns the currently loaded model."""
     model_params = model.container.get_model_parameters()
     draft_model_params = model_params.pop("draft", {})
@@ -121,7 +126,7 @@ async def get_current_model():
 
 
 @router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
-async def list_draft_models():
+async def list_draft_models() -> ModelList:
     """Lists all draft models in the model directory."""
     draft_model_dir = unwrap(
         config.draft_model_config().get("draft_model_dir"), "models"
@@ -135,8 +140,8 @@ async def list_draft_models():
 
 # Load model endpoint
 @router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
-async def load_model(data: ModelLoadRequest):
-    """Loads a model into the model container."""
+async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
+    """Loads a model into the model container. This returns an SSE stream."""
 
     # Verify request parameters
     if not data.name:
@@ -189,7 +194,7 @@ async def unload_model():
 
 @router.get("/v1/templates", dependencies=[Depends(check_api_key)])
 @router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
-async def get_templates():
+async def get_templates() -> TemplateList:
     templates = get_all_templates()
     template_strings = [template.stem for template in templates]
     return TemplateList(data=template_strings)
@@ -233,7 +238,7 @@ async def unload_template():
 # Sampler override endpoints
 @router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
 @router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
-async def list_sampler_overrides():
+async def list_sampler_overrides() -> SamplerOverrideListResponse:
     """API wrapper to list all currently applied sampler overrides"""
 
     return SamplerOverrideListResponse(
@@ -281,7 +286,7 @@ async def unload_sampler_override():
 
 
 @router.post("/v1/download", dependencies=[Depends(check_admin_key)])
-async def download_model(request: Request, data: DownloadRequest):
+async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
     """Downloads a model from HuggingFace."""
 
     try:
@@ -304,7 +309,7 @@ async def download_model(request: Request, data: DownloadRequest):
 # Lora list endpoint
 @router.get("/v1/loras", dependencies=[Depends(check_api_key)])
 @router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
-async def get_all_loras():
+async def get_all_loras() -> LoraList:
     """Lists all LoRAs in the lora directory."""
     lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
     loras = get_lora_list(lora_path.resolve())
@@ -317,7 +322,7 @@ async def get_all_loras():
     "/v1/lora",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def get_active_loras():
+async def get_active_loras() -> LoraList:
     """Returns the currently loaded loras."""
     active_loras = LoraList(
         data=[
@@ -337,7 +342,7 @@ async def get_active_loras():
     "/v1/lora/load",
     dependencies=[Depends(check_admin_key), Depends(check_model_container)],
 )
-async def load_lora(data: LoraLoadRequest):
+async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
     """Loads a LoRA into the model container."""
 
     if not data.loras:
@@ -383,7 +388,7 @@ async def unload_loras():
     "/v1/token/encode",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def encode_tokens(data: TokenEncodeRequest):
+async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
     """Encodes a string or chat completion messages into tokens."""
 
     if isinstance(data.text, str):
@@ -413,7 +418,7 @@ async def encode_tokens(data: TokenEncodeRequest):
     "/v1/token/decode",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def decode_tokens(data: TokenDecodeRequest):
+async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse:
     """Decodes tokens into a string."""
     message = model.container.decode_tokens(data.tokens, **data.get_params())
     response = TokenDecodeResponse(text=unwrap(message, ""))
@@ -426,7 +431,7 @@ async def get_key_permission(
     x_admin_key: Optional[str] = Header(None),
     x_api_key: Optional[str] = Header(None),
     authorization: Optional[str] = Header(None),
-):
+) -> AuthPermissionResponse:
     """
     Gets the access level/permission of a provided key in headers.
 
@@ -452,8 +457,15 @@ async def get_key_permission(
     "/v1/completions",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def completion_request(request: Request, data: CompletionRequest):
-    """Generates a completion from a prompt."""
+async def completion_request(
+    request: Request, data: CompletionRequest
+) -> CompletionResponse:
+    """
+    Generates a completion from a prompt.
+
+    If stream = true, this returns an SSE stream.
+    """
+
     model_path = model.container.get_model_path()
 
     if isinstance(data.prompt, list):
@@ -488,8 +500,14 @@ async def completion_request(request: Request, data: CompletionRequest):
     "/v1/chat/completions",
     dependencies=[Depends(check_api_key), Depends(check_model_container)],
 )
-async def chat_completion_request(request: Request, data: ChatCompletionRequest):
-    """Generates a chat completion from a prompt."""
+async def chat_completion_request(
+    request: Request, data: ChatCompletionRequest
+) -> ChatCompletionResponse:
+    """
+    Generates a chat completion from a prompt.
+
+    If stream = true, this returns an SSE stream.
+    """
 
     if model.container.prompt_template is None:
         error_message = handle_request_error(
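A note on the two completion endpoints above: the new CompletionResponse and ChatCompletionResponse annotations describe the non-streaming JSON reply in the docs. When stream = true, the handler returns an SSE StreamingResponse directly, and FastAPI passes a Response instance through as-is rather than validating it against the annotated model. A rough sketch of that pattern follows, with hypothetical helper logic, not the actual tabbyAPI handler.

# Rough sketch: a Response instance (here StreamingResponse) returned from the
# handler bypasses the annotated response model; the annotation only feeds the docs.
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

app = FastAPI()


class CompletionRequest(BaseModel):
    prompt: str
    stream: bool = False


class CompletionResponse(BaseModel):
    text: str


@app.post("/v1/completions")
async def completion_request(data: CompletionRequest) -> CompletionResponse:
    if data.stream:
        async def event_stream():
            # Hypothetical chunking; real SSE payloads come from the generator.
            yield f"data: {json.dumps({'text': 'partial'})}\n\n"

        return StreamingResponse(event_stream(), media_type="text/event-stream")

    return CompletionResponse(text="full completion for " + data.prompt)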