diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 5f6d1d1..f3ab99f 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -176,7 +176,12 @@ async def chat_completion_request( @router.get("/v1/models", dependencies=[Depends(check_api_key)]) @router.get("/v1/model/list", dependencies=[Depends(check_api_key)]) async def list_models(request: Request) -> ModelList: - """Lists all models in the model directory.""" + """ + Lists all models in the model directory. + + Requires an admin key to see all models. + """ + model_config = config.model_config() model_dir = unwrap(model_config.get("model_dir"), "models") model_path = pathlib.Path(model_dir) @@ -207,7 +212,11 @@ async def current_model() -> ModelCard: @router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)]) async def list_draft_models(request: Request) -> ModelList: - """Lists all draft models in the model directory.""" + """ + Lists all draft models in the model directory. + + Requires an admin key to see all draft models. + """ if get_key_permission(request) == "admin": draft_model_dir = unwrap( @@ -301,7 +310,11 @@ async def download_model(request: Request, data: DownloadRequest) -> DownloadRes @router.get("/v1/loras", dependencies=[Depends(check_api_key)]) @router.get("/v1/lora/list", dependencies=[Depends(check_api_key)]) async def list_all_loras(request: Request) -> LoraList: - """Lists all LoRAs in the lora directory.""" + """ + Lists all LoRAs in the lora directory. + + Requires an admin key to see all LoRAs. + """ if get_key_permission(request) == "admin": lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras")) @@ -406,6 +419,7 @@ async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse: ) async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse: """Decodes tokens into a string.""" + message = model.container.decode_tokens(data.tokens, **data.get_params()) response = TokenDecodeResponse(text=unwrap(message, "")) @@ -435,7 +449,11 @@ async def key_permission(request: Request) -> AuthPermissionResponse: @router.get("/v1/templates", dependencies=[Depends(check_api_key)]) @router.get("/v1/template/list", dependencies=[Depends(check_api_key)]) async def list_templates(request: Request) -> TemplateList: - """Get a list of all templates.""" + """ + Get a list of all templates. + + Requires an admin key to see all templates. + """ template_strings = [] if get_key_permission(request) == "admin": @@ -453,7 +471,7 @@ async def list_templates(request: Request) -> TemplateList: dependencies=[Depends(check_admin_key), Depends(check_model_container)], ) async def switch_template(data: TemplateSwitchRequest): - """Switch the currently loaded template""" + """Switch the currently loaded template.""" if not data.name: error_message = handle_request_error( @@ -488,7 +506,11 @@ async def unload_template(): @router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)]) @router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)]) async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse: - """API wrapper to list all currently applied sampler overrides""" + """ + List all currently applied sampler overrides. + + Requires an admin key to see all override presets. + """ if get_key_permission(request) == "admin": presets = sampling.get_all_presets()