OAI: Restrict list permissions for API keys
API keys are no longer allowed to view all of the admin's models, templates, draft models, loras, etc. Anything viewable on the filesystem beyond what is currently loaded is not returned unless an admin key is present. This change helps preserve user privacy while still not erroring out on the list endpoints that the OAI spec requires.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 10890913b8
commit 1f46a1130c
5 changed files with 119 additions and 60 deletions
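Every list endpoint in this diff applies the same gate: resolve the caller's permission level, then enumerate the filesystem only for admin keys, falling back to the currently loaded items otherwise. A minimal, self-contained sketch of the pattern — the helper and the two list builders below are stand-ins for the real ones in common.auth and endpoints.OAI.utils, not the project's code:

```python
from fastapi import FastAPI, Request

app = FastAPI()


def get_key_permission(request: Request) -> str:
    """Stand-in: classify the caller as "admin" or "api" by its key."""
    # The real helper in common.auth validates keys against the auth config.
    return "admin" if request.headers.get("x-admin-key") == "secret" else "api"


def get_model_list() -> list[str]:
    """Stand-in: every model folder on disk, private finetunes included."""
    return ["mistral-7b", "llama-13b", "private-finetune"]


def get_current_model_list() -> list[str]:
    """Stand-in: only the model that is currently loaded."""
    return ["mistral-7b"]


@app.get("/v1/models")
async def list_models(request: Request) -> dict:
    # Admin keys may enumerate the filesystem; API keys only see what is
    # already loaded, so the endpoint still answers instead of erroring.
    if get_key_permission(request) == "admin":
        models = get_model_list()
    else:
        models = get_current_model_list()
    return {"object": "list", "data": [{"id": m} for m in models]}
```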
@@ -405,6 +405,9 @@ class ExllamaV2Container:
     def get_model_path(self, is_draft: bool = False):
         """Get the path for this model."""
 
+        if is_draft and not self.draft_config:
+            return None
+
         model_path = pathlib.Path(
             self.draft_config.model_dir if is_draft else self.config.model_dir
         )
@@ -106,8 +106,7 @@ def get_key_permission(request: Request):
 
 async def check_api_key(
-    x_api_key: str = Header(None),
-    authorization: str = Header(None)
+    x_api_key: str = Header(None), authorization: str = Header(None)
 ):
     """Check if the API key is valid."""
 
 
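For reference, FastAPI derives the header names from the parameters merged above: x_api_key binds to an X-Api-Key request header and authorization to Authorization. A sketch of the shape such a dependency takes — the key store and the bearer-token handling are illustrative assumptions, not the body of the real check_api_key:

```python
from fastapi import Header, HTTPException

# Hypothetical key store; the real server reads its keys from the auth config.
VALID_API_KEYS = {"example-api-key"}


async def check_api_key(
    x_api_key: str = Header(None), authorization: str = Header(None)
):
    """Check if the API key is valid (signature from the diff; body assumed)."""
    # Accept the key from X-Api-Key or from an Authorization bearer token.
    key = x_api_key
    if not key and authorization:
        key = authorization.removeprefix("Bearer ").strip()
    if key not in VALID_API_KEYS:
        raise HTTPException(status_code=401, detail="Invalid API key")
    return key
```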
@@ -4,7 +4,7 @@ from fastapi import APIRouter, Depends, HTTPException, Request
 from sse_starlette import EventSourceResponse
 from sys import maxsize
 
-from common import config, model, gen_logging, sampling
+from common import config, model, sampling
 from common.auth import check_admin_key, check_api_key, get_key_permission
 from common.downloader import hf_repo_download
 from common.networking import handle_request_error, run_with_request_disconnect
@@ -18,7 +18,6 @@ from endpoints.OAI.types.chat_completion import (
 )
 from endpoints.OAI.types.download import DownloadRequest, DownloadResponse
 from endpoints.OAI.types.lora import (
-    LoraCard,
     LoraList,
     LoraLoadRequest,
     LoraLoadResponse,
@@ -27,7 +26,6 @@ from endpoints.OAI.types.model import (
     ModelCard,
     ModelList,
     ModelLoadRequest,
-    ModelCardParameters,
     ModelLoadResponse,
 )
 from endpoints.OAI.types.sampler_overrides import (
@@ -50,8 +48,13 @@ from endpoints.OAI.utils.completion import (
     generate_completion,
     stream_generate_completion,
 )
-from endpoints.OAI.utils.model import get_model_list, stream_model_load
-from endpoints.OAI.utils.lora import get_lora_list
+from endpoints.OAI.utils.model import (
+    get_current_model,
+    get_current_model_list,
+    get_model_list,
+    stream_model_load,
+)
+from endpoints.OAI.utils.lora import get_active_loras, get_lora_list
 
 
 router = APIRouter()
@@ -172,7 +175,7 @@ async def chat_completion_request(
 # Model list endpoint
 @router.get("/v1/models", dependencies=[Depends(check_api_key)])
 @router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
-async def list_models() -> ModelList:
+async def list_models(request: Request) -> ModelList:
     """Lists all models in the model directory."""
     model_config = config.model_config()
     model_dir = unwrap(model_config.get("model_dir"), "models")
@@ -180,7 +183,11 @@ async def list_models() -> ModelList:
 
     draft_model_dir = config.draft_model_config().get("draft_model_dir")
 
-    models = get_model_list(model_path.resolve(), draft_model_dir)
+    if get_key_permission(request) == "admin":
+        models = get_model_list(model_path.resolve(), draft_model_dir)
+    else:
+        models = await get_current_model_list()
 
     if unwrap(model_config.get("use_dummy_models"), False):
         models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))
@@ -194,43 +201,23 @@ async def list_models() -> ModelList:
 )
 async def current_model() -> ModelCard:
     """Returns the currently loaded model."""
-    model_params = model.container.get_model_parameters()
-    draft_model_params = model_params.pop("draft", {})
-
-    if draft_model_params:
-        model_params["draft"] = ModelCard(
-            id=unwrap(draft_model_params.get("name"), "unknown"),
-            parameters=ModelCardParameters.model_validate(draft_model_params),
-        )
-    else:
-        draft_model_params = None
-
-    model_card = ModelCard(
-        id=unwrap(model_params.pop("name", None), "unknown"),
-        parameters=ModelCardParameters.model_validate(model_params),
-        logging=gen_logging.PREFERENCES,
-    )
-
-    if draft_model_params:
-        draft_card = ModelCard(
-            id=unwrap(draft_model_params.pop("name", None), "unknown"),
-            parameters=ModelCardParameters.model_validate(draft_model_params),
-        )
-
-        model_card.parameters.draft = draft_card
-
-    return model_card
+    return get_current_model()
 
 
 @router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
-async def list_draft_models() -> ModelList:
+async def list_draft_models(request: Request) -> ModelList:
     """Lists all draft models in the model directory."""
-    draft_model_dir = unwrap(
-        config.draft_model_config().get("draft_model_dir"), "models"
-    )
-    draft_model_path = pathlib.Path(draft_model_dir)
-
-    models = get_model_list(draft_model_path.resolve())
+    if get_key_permission(request) == "admin":
+        draft_model_dir = unwrap(
+            config.draft_model_config().get("draft_model_dir"), "models"
+        )
+        draft_model_path = pathlib.Path(draft_model_dir)
+
+        models = get_model_list(draft_model_path.resolve())
+    else:
+        models = await get_current_model_list(is_draft=True)
 
     return models
 
@@ -313,10 +300,14 @@ async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
 # Lora list endpoint
 @router.get("/v1/loras", dependencies=[Depends(check_api_key)])
 @router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
-async def list_all_loras() -> LoraList:
+async def list_all_loras(request: Request) -> LoraList:
     """Lists all LoRAs in the lora directory."""
-    lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
-    loras = get_lora_list(lora_path.resolve())
+
+    if get_key_permission(request) == "admin":
+        lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
+        loras = get_lora_list(lora_path.resolve())
+    else:
+        loras = get_active_loras()
 
     return loras
 
@@ -328,17 +319,8 @@ async def list_all_loras() -> LoraList:
 )
 async def active_loras() -> LoraList:
     """Returns the currently loaded loras."""
-    loras = LoraList(
-        data=[
-            LoraCard(
-                id=pathlib.Path(lora.lora_path).parent.name,
-                scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
-            )
-            for lora in model.container.get_loras()
-        ]
-    )
-
-    return loras
+    return get_active_loras()
 
 
 # Load lora endpoint
@@ -452,9 +434,17 @@ async def key_permission(request: Request) -> AuthPermissionResponse:
 
 @router.get("/v1/templates", dependencies=[Depends(check_api_key)])
 @router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
-async def list_templates() -> TemplateList:
-    templates = get_all_templates()
-    template_strings = [template.stem for template in templates]
+async def list_templates(request: Request) -> TemplateList:
+    """Get a list of all templates."""
+
+    template_strings = []
+    if get_key_permission(request) == "admin":
+        templates = get_all_templates()
+        template_strings = [template.stem for template in templates]
+    else:
+        if model.container and model.container.prompt_template:
+            template_strings.append(model.container.prompt_template.name)
 
     return TemplateList(data=template_strings)
 
@@ -464,6 +454,7 @@ async def list_templates() -> TemplateList:
 )
 async def switch_template(data: TemplateSwitchRequest):
     """Switch the currently loaded template"""
+
     if not data.name:
         error_message = handle_request_error(
             "New template name not found.",
@@ -496,11 +487,16 @@ async def unload_template():
 # Sampler override endpoints
 @router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
 @router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
-async def list_sampler_overrides() -> SamplerOverrideListResponse:
+async def list_sampler_overrides(request: Request) -> SamplerOverrideListResponse:
     """API wrapper to list all currently applied sampler overrides"""
 
+    if get_key_permission(request) == "admin":
+        presets = sampling.get_all_presets()
+    else:
+        presets = []
+
     return SamplerOverrideListResponse(
-        presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump()
+        presets=presets, **sampling.overrides_container.model_dump()
     )
 
@@ -1,5 +1,6 @@
 import pathlib
 
+from common import model
 from endpoints.OAI.types.lora import LoraCard, LoraList
 
 
@@ -12,3 +13,18 @@ def get_lora_list(lora_path: pathlib.Path):
         lora_list.data.append(lora_card)
 
     return lora_list
+
+
+def get_active_loras():
+    if model.container:
+        active_loras = [
+            LoraCard(
+                id=pathlib.Path(lora.lora_path).parent.name,
+                scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
+            )
+            for lora in model.container.get_loras()
+        ]
+    else:
+        active_loras = []
+
+    return LoraList(data=active_loras)
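A note on the scaling reported by get_active_loras: the card multiplies the loaded adapter's lora_scaling by r / alpha. Assuming exllamav2 folds alpha / r into lora_scaling when the adapter is loaded (an assumption, not confirmed by this diff), the division recovers the user-requested multiplier. The arithmetic:

```python
# Worked example of the scaling expression in the LoraCard above.
lora_scaling = 2.0  # multiplier stored on the loaded adapter
lora_r = 16         # adapter rank from its config
lora_alpha = 32     # adapter alpha from its config

scaling = lora_scaling * lora_r / lora_alpha
print(scaling)  # 1.0
```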
@@ -2,11 +2,12 @@ import pathlib
 from asyncio import CancelledError
 from typing import Optional
 
-from common import model
+from common import gen_logging, model
 from common.networking import get_generator_error, handle_request_disconnect
 from common.utils import unwrap
 from endpoints.OAI.types.model import (
     ModelCard,
+    ModelCardParameters,
     ModelList,
     ModelLoadRequest,
     ModelLoadResponse,
@@ -31,6 +32,50 @@ def get_model_list(model_path: pathlib.Path, draft_model_path: Optional[str] = None):
     return model_card_list
 
 
+async def get_current_model_list(is_draft: bool = False):
+    """Gets the current model in list format and with path only."""
+    current_models = []
+
+    # Make sure the model container exists
+    if model.container:
+        model_path = model.container.get_model_path(is_draft)
+        if model_path:
+            current_models.append(ModelCard(id=model_path.name))
+
+    return ModelList(data=current_models)
+
+
+def get_current_model():
+    """Gets the current model with all parameters."""
+
+    model_params = model.container.get_model_parameters()
+    draft_model_params = model_params.pop("draft", {})
+
+    if draft_model_params:
+        model_params["draft"] = ModelCard(
+            id=unwrap(draft_model_params.get("name"), "unknown"),
+            parameters=ModelCardParameters.model_validate(draft_model_params),
+        )
+    else:
+        draft_model_params = None
+
+    model_card = ModelCard(
+        id=unwrap(model_params.pop("name", None), "unknown"),
+        parameters=ModelCardParameters.model_validate(model_params),
+        logging=gen_logging.PREFERENCES,
+    )
+
+    if draft_model_params:
+        draft_card = ModelCard(
+            id=unwrap(draft_model_params.pop("name", None), "unknown"),
+            parameters=ModelCardParameters.model_validate(draft_model_params),
+        )
+
+        model_card.parameters.draft = draft_card
+
+    return model_card
+
+
 async def stream_model_load(
     data: ModelLoadRequest,
     model_path: pathlib.Path,
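Taken together, the same list request now yields different payloads depending on the key that signs it. A usage sketch with httpx — the header names, port, and key values are assumptions extrapolated from check_api_key's parameters, not confirmed by this diff:

```python
import httpx

BASE_URL = "http://localhost:5000"

# An API key only sees the currently loaded model...
resp = httpx.get(f"{BASE_URL}/v1/models", headers={"X-Api-Key": "user-key"})
print([card["id"] for card in resp.json()["data"]])  # e.g. ["mistral-7b"]

# ...while an admin key still enumerates everything under model_dir.
resp = httpx.get(f"{BASE_URL}/v1/models", headers={"X-Admin-Key": "admin-key"})
print([card["id"] for card in resp.json()["data"]])
```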