Endpoints: Add key permission checker
This is a definite way to check if an authorized key is API or admin. The endpoint only runs if the key is valid in the first place to keep inline with the API's security model. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
c9a6d9ae1f
commit
3c08f46c51
3 changed files with 51 additions and 2 deletions
|
|
@ -10,6 +10,8 @@ from pydantic import BaseModel
|
|||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
from endpoints.OAI.types.auth import AuthPermissionResponse
|
||||
|
||||
|
||||
class AuthKeys(BaseModel):
|
||||
"""
|
||||
|
|
@ -75,6 +77,18 @@ def load_auth_keys(disable_from_config: bool):
|
|||
)
|
||||
|
||||
|
||||
async def validate_key_permission(test_key: str):
|
||||
if test_key.lower().startswith("bearer"):
|
||||
test_key = test_key.split(" ")[1]
|
||||
|
||||
if AUTH_KEYS.verify_key(test_key, "admin_key"):
|
||||
return AuthPermissionResponse(permission="admin")
|
||||
elif AUTH_KEYS.verify_key(test_key, "api_key"):
|
||||
return AuthPermissionResponse(permission="api")
|
||||
else:
|
||||
raise ValueError("The provided authentication key is invalid.")
|
||||
|
||||
|
||||
async def check_api_key(
|
||||
x_api_key: str = Header(None), authorization: str = Header(None)
|
||||
):
|
||||
|
|
|
|||
|
|
@ -2,15 +2,16 @@ import pathlib
|
|||
import signal
|
||||
import uvicorn
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, Depends, HTTPException, Request
|
||||
from fastapi import FastAPI, Depends, HTTPException, Header, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from functools import partial
|
||||
from loguru import logger
|
||||
from sse_starlette import EventSourceResponse
|
||||
from sys import maxsize
|
||||
from typing import Optional
|
||||
|
||||
from common import config, model, gen_logging, sampling
|
||||
from common.auth import check_admin_key, check_api_key
|
||||
from common.auth import check_admin_key, check_api_key, validate_key_permission
|
||||
from common.concurrency import (
|
||||
call_with_semaphore,
|
||||
generate_with_semaphore,
|
||||
|
|
@ -22,6 +23,7 @@ from common.templating import (
|
|||
get_template_from_file,
|
||||
)
|
||||
from common.utils import (
|
||||
coalesce,
|
||||
handle_request_error,
|
||||
unwrap,
|
||||
)
|
||||
|
|
@ -399,6 +401,32 @@ async def decode_tokens(data: TokenDecodeRequest):
|
|||
return response
|
||||
|
||||
|
||||
@app.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
|
||||
async def get_key_permission(
|
||||
x_admin_key: Optional[str] = Header(None),
|
||||
x_api_key: Optional[str] = Header(None),
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""
|
||||
Gets the access level/permission of a provided key in headers.
|
||||
|
||||
Priority:
|
||||
- X-api-key
|
||||
- X-admin-key
|
||||
- Authorization
|
||||
"""
|
||||
|
||||
test_key = coalesce(x_admin_key, x_api_key, authorization)
|
||||
|
||||
try:
|
||||
response = await validate_key_permission(test_key)
|
||||
return response
|
||||
except ValueError as exc:
|
||||
error_message = handle_request_error(str(exc)).error.message
|
||||
|
||||
raise HTTPException(400, error_message) from exc
|
||||
|
||||
|
||||
# Completions endpoint
|
||||
@app.post(
|
||||
"/v1/completions",
|
||||
|
|
|
|||
7
endpoints/OAI/types/auth.py
Normal file
7
endpoints/OAI/types/auth.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""Types for auth requests."""
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AuthPermissionResponse(BaseModel):
|
||||
permission: str
|
||||
Loading…
Add table
Add a link
Reference in a new issue