Add on-the-fly model loading to requests
This commit is contained in:
parent
ff15eed85d
commit
279e900ea5
2 changed files with 49 additions and 1 deletions
|
|
@ -1,5 +1,6 @@
|
|||
import asyncio
|
||||
import pathlib
|
||||
from loguru import logger
|
||||
from fastapi import APIRouter, Depends, HTTPException, Header, Request
|
||||
from sse_starlette import EventSourceResponse
|
||||
from sys import maxsize
|
||||
|
|
@ -118,7 +119,7 @@ async def completion_request(
|
|||
# Chat completions endpoint
|
||||
@router.post(
|
||||
"/v1/chat/completions",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def chat_completion_request(
|
||||
request: Request, data: ChatCompletionRequest
|
||||
|
|
@ -129,6 +130,52 @@ async def chat_completion_request(
|
|||
If stream = true, this returns an SSE stream.
|
||||
"""
|
||||
|
||||
if data.model is not None and (
|
||||
model.container is None or model.container.get_model_path().name != data.model
|
||||
):
|
||||
adminValid = False
|
||||
if "x_admin_key" in request.headers.keys():
|
||||
try:
|
||||
await check_admin_key(
|
||||
x_admin_key=request.headers.get("x_admin_key"), authorization=None
|
||||
)
|
||||
adminValid = True
|
||||
except HTTPException:
|
||||
pass
|
||||
|
||||
if not adminValid and "authorization" in request.headers.keys():
|
||||
try:
|
||||
await check_admin_key(
|
||||
x_admin_key=None, authorization=request.headers.get("authorization")
|
||||
)
|
||||
adminValid = True
|
||||
except HTTPException:
|
||||
pass
|
||||
|
||||
if adminValid:
|
||||
logger.info(
|
||||
f"New request for {data.model} which is not loaded, proper admin key provided, loading new model"
|
||||
)
|
||||
|
||||
model_path = pathlib.Path(
|
||||
unwrap(config.model_config().get("model_dir"), "models")
|
||||
)
|
||||
model_path = model_path / data.model
|
||||
|
||||
if not model_path.exists():
|
||||
error_message = handle_request_error(
|
||||
"Could not find the model path for load. Check model name.",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
await model.load_model(model_path)
|
||||
else:
|
||||
logger.info(f"No valid admin key found to change loaded model, ignoring")
|
||||
else:
|
||||
await check_model_container()
|
||||
|
||||
if model.container.prompt_template is None:
|
||||
error_message = handle_request_error(
|
||||
"Chat completions are disabled because a prompt template is not set.",
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ class ChatCompletionRequest(CommonCompletionRequest):
|
|||
add_generation_prompt: Optional[bool] = True
|
||||
template_vars: Optional[dict] = {}
|
||||
response_prefix: Optional[str] = None
|
||||
model: Optional[str] = None
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue