OAI: Restrict list permissions for API keys

API keys are not allowed to view all of the admin's models, templates,
draft models, LoRAs, etc. Anything that can be discovered on the
filesystem, other than what is currently loaded, is not returned
unless an admin key is present.

This change helps preserve user privacy while not erroring out on
list endpoints that the OAI spec requires.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-07-11 14:06:03 -04:00
parent 10890913b8
commit 1f46a1130c
5 changed files with 119 additions and 60 deletions

View file

@ -405,6 +405,9 @@ class ExllamaV2Container:
def get_model_path(self, is_draft: bool = False):
"""Get the path for this model."""
if is_draft and not self.draft_config:
return None
model_path = pathlib.Path(
self.draft_config.model_dir if is_draft else self.config.model_dir
)

View file

@ -106,8 +106,7 @@ def get_key_permission(request: Request):
async def check_api_key(
x_api_key: str = Header(None),
authorization: str = Header(None)
x_api_key: str = Header(None), authorization: str = Header(None)
):
"""Check if the API key is valid."""

View file

@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
from common import config, model, gen_logging, sampling
from common import config, model, sampling
from common.auth import check_admin_key, check_api_key, get_key_permission
from common.downloader import hf_repo_download
from common.networking import handle_request_error, run_with_request_disconnect
@ -18,7 +18,6 @@ from endpoints.OAI.types.chat_completion import (
)
from endpoints.OAI.types.download import DownloadRequest, DownloadResponse
from endpoints.OAI.types.lora import (
LoraCard,
LoraList,
LoraLoadRequest,
LoraLoadResponse,
@ -27,7 +26,6 @@ from endpoints.OAI.types.model import (
ModelCard,
ModelList,
ModelLoadRequest,
ModelCardParameters,
ModelLoadResponse,
)
from endpoints.OAI.types.sampler_overrides import (
@ -50,8 +48,13 @@ from endpoints.OAI.utils.completion import (
generate_completion,
stream_generate_completion,
)
from endpoints.OAI.utils.model import get_model_list, stream_model_load
from endpoints.OAI.utils.lora import get_lora_list
from endpoints.OAI.utils.model import (
get_current_model,
get_current_model_list,
get_model_list,
stream_model_load,
)
from endpoints.OAI.utils.lora import get_active_loras, get_lora_list
router = APIRouter()
@ -172,7 +175,7 @@ async def chat_completion_request(
# Model list endpoint
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models() -> ModelList:
async def list_models(request: Request) -> ModelList:
"""Lists all models in the model directory."""
model_config = config.model_config()
model_dir = unwrap(model_config.get("model_dir"), "models")
@ -180,7 +183,11 @@ async def list_models() -> ModelList:
draft_model_dir = config.draft_model_config().get("draft_model_dir")
models = get_model_list(model_path.resolve(), draft_model_dir)
if get_key_permission(request) == "admin":
models = get_model_list(model_path.resolve(), draft_model_dir)
else:
models = await get_current_model_list()
if unwrap(model_config.get("use_dummy_models"), False):
models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
@ -194,43 +201,23 @@ async def list_models() -> ModelList:
)
async def current_model() -> ModelCard:
"""Returns the currently loaded model."""
model_params = model.container.get_model_parameters()
draft_model_params = model_params.pop("draft", {})
if draft_model_params:
model_params["draft"] = ModelCard(
id=unwrap(draft_model_params.get("name"), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
else:
draft_model_params = None
model_card = ModelCard(
id=unwrap(model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(model_params),
logging=gen_logging.PREFERENCES,
)
if draft_model_params:
draft_card = ModelCard(
id=unwrap(draft_model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
model_card.parameters.draft = draft_card
return model_card
return get_current_model()
@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models() -> ModelList:
async def list_draft_models(request: Request) -> ModelList:
"""Lists all draft models in the model directory."""
draft_model_dir = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
if get_key_permission(request) == "admin":
draft_model_dir = unwrap(
config.draft_model_config().get("draft_model_dir"), "models"
)
draft_model_path = pathlib.Path(draft_model_dir)
models = get_model_list(draft_model_path.resolve())
else:
models = await get_current_model_list(is_draft=True)
return models
@ -313,10 +300,14 @@ async def download_model(request: Request, data: DownloadRequest) -> DownloadRes
# Lora list endpoint
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def list_all_loras() -> LoraList:
async def list_all_loras(request: Request) -> LoraList:
"""Lists all LoRAs in the lora directory."""
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
if get_key_permission(request) == "admin":
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
loras = get_lora_list(lora_path.resolve())
else:
loras = get_active_loras()
return loras
@ -328,17 +319,8 @@ async def list_all_loras() -> LoraList:
)
async def active_loras() -> LoraList:
"""Returns the currently loaded loras."""
loras = LoraList(
data=[
LoraCard(
id=pathlib.Path(lora.lora_path).parent.name,
scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
)
for lora in model.container.get_loras()
]
)
return loras
return get_active_loras()
# Load lora endpoint
@ -452,9 +434,17 @@ async def key_permission(request: Request) -> AuthPermissionResponse:
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def list_templates() -> TemplateList:
templates = get_all_templates()
template_strings = [template.stem for template in templates]
async def list_templates(request: Request) -> TemplateList:
"""Get a list of all templates."""
template_strings = []
if get_key_permission(request) == "admin":
templates = get_all_templates()
template_strings = [template.stem for template in templates]
else:
if model.container and model.container.prompt_template:
template_strings.append(model.container.prompt_template.name)
return TemplateList(data=template_strings)
@ -464,6 +454,7 @@ async def list_templates() -> TemplateList:
)
async def switch_template(data: TemplateSwitchRequest):
"""Switch the currently loaded template"""
if not data.name:
error_message = handle_request_error(
"New template name not found.",
@ -496,11 +487,16 @@ async def unload_template():
# Sampler override endpoints
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides() -> SamplerOverrideListResponse:
async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse:
"""API wrapper to list all currently applied sampler overrides"""
if get_key_permission(request) == "admin":
presets = sampling.get_all_presets()
else:
presets = []
return SamplerOverrideListResponse(
presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump()
presets=presets, **sampling.overrides_container.model_dump()
)

View file

@ -1,5 +1,6 @@
import pathlib
from common import model
from endpoints.OAI.types.lora import LoraCard, LoraList
@ -12,3 +13,18 @@ def get_lora_list(lora_path: pathlib.Path):
lora_list.data.append(lora_card)
return lora_list
def get_active_loras():
    """Build a LoraList describing the LoRAs reported by the model container.

    Returns a LoraList with an empty data list when no model container exists.
    """
    container = model.container
    if not container:
        return LoraList(data=[])

    cards = []
    for loaded in container.get_loras():
        # The card id is the name of the parent directory of lora_path
        lora_name = pathlib.Path(loaded.lora_path).parent.name
        scaling = loaded.lora_scaling * loaded.lora_r / loaded.lora_alpha
        cards.append(LoraCard(id=lora_name, scaling=scaling))

    return LoraList(data=cards)

View file

@ -2,11 +2,12 @@ import pathlib
from asyncio import CancelledError
from typing import Optional
from common import model
from common import gen_logging, model
from common.networking import get_generator_error, handle_request_disconnect
from common.utils import unwrap
from endpoints.OAI.types.model import (
ModelCard,
ModelCardParameters,
ModelList,
ModelLoadRequest,
ModelLoadResponse,
@ -31,6 +32,50 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = N
return model_card_list
async def get_current_model_list(is_draft: bool = False):
    """Return a ModelList holding at most one card for the loaded model.

    The card id is the final component of the path returned by the
    container's get_model_path(is_draft). The list is empty when no model
    container exists or when the container returns no path.
    """
    container = model.container
    if not container:
        return ModelList(data=[])

    loaded_path = container.get_model_path(is_draft)
    cards = [ModelCard(id=loaded_path.name)] if loaded_path else []
    return ModelList(data=cards)
def get_current_model():
    """Build the ModelCard for the current model, including its parameters.

    The parameters come from the container's get_model_parameters(). When a
    non-empty "draft" entry is present, a nested draft ModelCard is attached
    to the returned card's parameters.
    """
    # NOTE(review): model.container is dereferenced without a None check;
    # presumably callers guarantee a container exists - confirm.
    params = model.container.get_model_parameters()
    draft_params = params.pop("draft", {}) or None

    if draft_params:
        params["draft"] = ModelCard(
            id=unwrap(draft_params.get("name"), "unknown"),
            parameters=ModelCardParameters.model_validate(draft_params),
        )

    # Pop the name before validation so it is not part of the parameters
    primary_name = unwrap(params.pop("name", None), "unknown")
    card = ModelCard(
        id=primary_name,
        parameters=ModelCardParameters.model_validate(params),
        logging=gen_logging.PREFERENCES,
    )

    if not draft_params:
        return card

    # Rebuild the draft card with the name removed from its parameters
    draft_name = unwrap(draft_params.pop("name", None), "unknown")
    card.parameters.draft = ModelCard(
        id=draft_name,
        parameters=ModelCardParameters.model_validate(draft_params),
    )
    return card
async def stream_model_load(
data: ModelLoadRequest,
model_path: pathlib.Path,