Endpoints: Add key permission checker

This is a definite way to check if an authorized key is API or admin.
The endpoint only runs if the key is valid in the first place to keep
inline with the API's security model.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-03-18 00:45:40 -04:00
parent c9a6d9ae1f
commit 3c08f46c51
3 changed files with 51 additions and 2 deletions

View file

@ -10,6 +10,8 @@ from pydantic import BaseModel
from loguru import logger
from typing import Optional
from endpoints.OAI.types.auth import AuthPermissionResponse
class AuthKeys(BaseModel):
"""
@ -75,6 +77,18 @@ def load_auth_keys(disable_from_config: bool):
)
async def validate_key_permission(test_key: str):
if test_key.lower().startswith("bearer"):
test_key = test_key.split(" ")[1]
if AUTH_KEYS.verify_key(test_key, "admin_key"):
return AuthPermissionResponse(permission="admin")
elif AUTH_KEYS.verify_key(test_key, "api_key"):
return AuthPermissionResponse(permission="api")
else:
raise ValueError("The provided authentication key is invalid.")
async def check_api_key(
x_api_key: str = Header(None), authorization: str = Header(None)
):

View file

@ -2,15 +2,16 @@ import pathlib
import signal
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi import FastAPI, Depends, HTTPException, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from functools import partial
from loguru import logger
from sse_starlette import EventSourceResponse
from sys import maxsize
from typing import Optional
from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key
from common.auth import check_admin_key, check_api_key, validate_key_permission
from common.concurrency import (
call_with_semaphore,
generate_with_semaphore,
@ -22,6 +23,7 @@ from common.templating import (
get_template_from_file,
)
from common.utils import (
coalesce,
handle_request_error,
unwrap,
)
@ -399,6 +401,32 @@ async def decode_tokens(data: TokenDecodeRequest):
return response
@app.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
async def get_key_permission(
x_admin_key: Optional[str] = Header(None),
x_api_key: Optional[str] = Header(None),
authorization: Optional[str] = Header(None),
):
"""
Gets the access level/permission of a provided key in headers.
Priority:
- X-api-key
- X-admin-key
- Authorization
"""
test_key = coalesce(x_admin_key, x_api_key, authorization)
try:
response = await validate_key_permission(test_key)
return response
except ValueError as exc:
error_message = handle_request_error(str(exc)).error.message
raise HTTPException(400, error_message) from exc
# Completions endpoint
@app.post(
"/v1/completions",

View file

@ -0,0 +1,7 @@
"""Types for auth requests."""
from pydantic import BaseModel
class AuthPermissionResponse(BaseModel):
permission: str