Auth: Make key permission check work on Requests

Pass a request and internally unwrap the headers. In addition, allow
X-admin-key to get checked in an API key request.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-07-11 11:16:24 -04:00
parent ff15eed85d
commit b9a58ff01b
2 changed files with 35 additions and 17 deletions

View file

@ -5,11 +5,13 @@ application, it should be fine.
import secrets
import yaml
from fastapi import Header, HTTPException
from fastapi import Header, HTTPException, Request
from pydantic import BaseModel
from loguru import logger
from typing import Optional
from common.utils import coalesce
class AuthKeys(BaseModel):
"""
@ -75,7 +77,23 @@ def load_auth_keys(disable_from_config: bool):
)
async def validate_key_permission(test_key: str):
def get_key_permission(request: Request):
"""
Gets the key permission from a request.
Internal only! Use the depends functions for incoming requests.
"""
# Hyphens are okay here
test_key = coalesce(
request.headers.get("authorization"),
request.headers.get("x-admin-key"),
request.headers.get("x-api-key"),
)
if test_key is None:
raise ValueError("The provided authentication key is missing.")
if test_key.lower().startswith("bearer"):
test_key = test_key.split(" ")[1]
@ -88,7 +106,9 @@ async def validate_key_permission(test_key: str):
async def check_api_key(
x_api_key: str = Header(None), authorization: str = Header(None)
x_api_key: str = Header(None),
x_admin_key: str = Header(None),
authorization: str = Header(None),
):
"""Check if the API key is valid."""
@ -101,6 +121,11 @@ async def check_api_key(
raise HTTPException(401, "Invalid API key")
return x_api_key
if x_admin_key:
if not AUTH_KEYS.verify_key(x_admin_key, "admin_key"):
raise HTTPException(401, "Invalid API key")
return x_admin_key
if authorization:
split_key = authorization.split(" ")
if len(split_key) < 2:

View file

@ -1,16 +1,15 @@
import asyncio
import pathlib
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
from typing import Optional
from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key, validate_key_permission
from common.auth import check_admin_key, check_api_key, get_key_permission
from common.downloader import hf_repo_download
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
from common.utils import coalesce, unwrap
from common.utils import unwrap
from endpoints.OAI.types.auth import AuthPermissionResponse
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.chat_completion import (
@ -432,24 +431,18 @@ async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse:
@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
async def get_key_permission(
x_admin_key: Optional[str] = Header(None),
x_api_key: Optional[str] = Header(None),
authorization: Optional[str] = Header(None),
) -> AuthPermissionResponse:
async def key_permission(request: Request) -> AuthPermissionResponse:
"""
Gets the access level/permission of a provided key in headers.
Priority:
- X-api-key
- X-admin-key
- Authorization
- X-admin-key
- X-api-key
"""
test_key = coalesce(x_admin_key, x_api_key, authorization)
try:
permission = await validate_key_permission(test_key)
permission = get_key_permission(request)
return AuthPermissionResponse(permission=permission)
except ValueError as exc:
error_message = handle_request_error(str(exc)).error.message