Add on-the-fly model loading to requests

Colin Kealty 2024-06-04 13:35:48 -04:00
parent ff15eed85d
commit 279e900ea5
2 changed files with 49 additions and 1 deletion
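For context, a minimal sketch of how a client might exercise the new behavior; the server URL, key value, and model name below are placeholders, and only the x-admin-key header and the model field come from the diff itself.

# Hypothetical client call; URL, key, and model name are placeholders.
# Naming a model that is not currently loaded, with a valid admin key,
# makes the server load it from model_dir before generating.
import requests

resp = requests.post(
    "http://localhost:5000/v1/chat/completions",
    headers={"x-admin-key": "example-admin-key"},
    json={
        "model": "MyModel-exl2",
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
print(resp.json())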


@@ -1,5 +1,6 @@
import asyncio
import pathlib
from loguru import logger
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
@@ -118,7 +119,7 @@ async def completion_request(
# Chat completions endpoint
@router.post(
"/v1/chat/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
dependencies=[Depends(check_api_key)],
)
async def chat_completion_request(
request: Request, data: ChatCompletionRequest
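Dropping Depends(check_model_container) from the route means the no-model guard no longer runs before the handler; the handler below now decides whether to call it or to load the requested model first. For reference, a sketch of what such a guard plausibly looks like (an assumption, not the project's actual helper):

from fastapi import HTTPException

from common import model  # assumed location of the model manager module

async def check_model_container():
    # Reject early when no model has been loaded into the container yet.
    if model.container is None:
        raise HTTPException(400, "A model must be loaded before completions can run.")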
@@ -129,6 +130,52 @@ async def chat_completion_request(
    If stream = true, this returns an SSE stream.
    """
    if data.model is not None and (
        model.container is None or model.container.get_model_path().name != data.model
    ):
        # A model switch was requested: only honor it when the caller holds a
        # valid admin key, checked against both supported auth headers.
        # Note: header names use hyphens on the wire, and Starlette does not
        # convert underscores when reading request.headers.
        admin_valid = False
        if "x-admin-key" in request.headers:
            try:
                await check_admin_key(
                    x_admin_key=request.headers.get("x-admin-key"), authorization=None
                )
                admin_valid = True
            except HTTPException:
                pass

        if not admin_valid and "authorization" in request.headers:
            try:
                await check_admin_key(
                    x_admin_key=None, authorization=request.headers.get("authorization")
                )
                admin_valid = True
            except HTTPException:
                pass

        if admin_valid:
            logger.info(
                f"Model {data.model} is not loaded and a valid admin key was "
                "provided. Loading the new model."
            )

            model_path = pathlib.Path(
                unwrap(config.model_config().get("model_dir"), "models")
            )
            model_path = model_path / data.model

            if not model_path.exists():
                error_message = handle_request_error(
                    "Could not find the model path for load. Check model name.",
                    exc_info=False,
                ).error.message

                raise HTTPException(400, error_message)

            await model.load_model(model_path)
        else:
            # Ignore the requested switch, but still require a loaded model
            # before continuing with the current one.
            logger.info("No valid admin key found to change loaded model, ignoring.")
            await check_model_container()
    else:
        await check_model_container()

    if model.container.prompt_template is None:
        error_message = handle_request_error(
            "Chat completions are disabled because a prompt template is not set.",


@@ -47,6 +47,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
    add_generation_prompt: Optional[bool] = True
    template_vars: Optional[dict] = {}
    response_prefix: Optional[str] = None
    model: Optional[str] = None


class ChatCompletionResponse(BaseModel):
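The new field defaults to None, so existing clients that omit it keep the current behavior. A quick check with a trimmed-down stand-in for the real request model:

from typing import Optional
from pydantic import BaseModel

class ChatCompletionRequest(BaseModel):  # trimmed-down stand-in
    model: Optional[str] = None

print(ChatCompletionRequest().model)                      # None -> keep loaded model
print(ChatCompletionRequest(model="MyModel-exl2").model)  # request a model switch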