API: Update inline load
- Add a config flag - Migrate support to /v1/completions - Unify the load function Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
dd30d6592a
commit
21f14d4318
3 changed files with 60 additions and 48 deletions
|
|
@ -83,6 +83,9 @@ model:
|
|||
# Enable this if the program is looking for a specific OAI model
|
||||
#use_dummy_models: False
|
||||
|
||||
# Allow direct loading of models from a completion or chat completion request
|
||||
inline_model_loading: False
|
||||
|
||||
# An initial model to load. Make sure the model is located in the model directory!
|
||||
# A model can be loaded later via the API.
|
||||
# REQUIRED: This must be filled out to load a model on startup!
|
||||
|
|
|
|||
|
|
@ -1,6 +1,4 @@
|
|||
import asyncio
|
||||
import pathlib
|
||||
from loguru import logger
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
from sse_starlette import EventSourceResponse
|
||||
from sys import maxsize
|
||||
|
|
@ -23,6 +21,7 @@ from endpoints.OAI.utils.chat_completion import (
|
|||
)
|
||||
from endpoints.OAI.utils.completion import (
|
||||
generate_completion,
|
||||
load_inline_model,
|
||||
stream_generate_completion,
|
||||
)
|
||||
from endpoints.OAI.utils.embeddings import get_embeddings
|
||||
|
|
@ -43,7 +42,7 @@ def setup():
|
|||
# Completions endpoint
|
||||
@router.post(
|
||||
"/v1/completions",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
dependencies=[Depends(check_api_key)],
|
||||
)
|
||||
async def completion_request(
|
||||
request: Request, data: CompletionRequest
|
||||
|
|
@ -54,6 +53,11 @@ async def completion_request(
|
|||
If stream = true, this returns an SSE stream.
|
||||
"""
|
||||
|
||||
if data.model:
|
||||
await load_inline_model(data.model, request)
|
||||
else:
|
||||
await check_model_container()
|
||||
|
||||
model_path = model.container.model_dir
|
||||
|
||||
if isinstance(data.prompt, list):
|
||||
|
|
@ -99,49 +103,8 @@ async def chat_completion_request(
|
|||
If stream = true, this returns an SSE stream.
|
||||
"""
|
||||
|
||||
if data.model is not None and (
|
||||
model.container is None or model.container.get_model_path().name != data.model
|
||||
):
|
||||
adminValid = False
|
||||
if "x_admin_key" in request.headers.keys():
|
||||
try:
|
||||
await check_admin_key(
|
||||
x_admin_key=request.headers.get("x_admin_key"), authorization=None
|
||||
)
|
||||
adminValid = True
|
||||
except HTTPException:
|
||||
pass
|
||||
|
||||
if not adminValid and "authorization" in request.headers.keys():
|
||||
try:
|
||||
await check_admin_key(
|
||||
x_admin_key=None, authorization=request.headers.get("authorization")
|
||||
)
|
||||
adminValid = True
|
||||
except HTTPException:
|
||||
pass
|
||||
|
||||
if adminValid:
|
||||
logger.info(
|
||||
f"New request for {data.model} which is not loaded, proper admin key provided, loading new model"
|
||||
)
|
||||
|
||||
model_path = pathlib.Path(
|
||||
unwrap(config.model_config().get("model_dir"), "models")
|
||||
)
|
||||
model_path = model_path / data.model
|
||||
|
||||
if not model_path.exists():
|
||||
error_message = handle_request_error(
|
||||
"Could not find the model path for load. Check model name.",
|
||||
exc_info=False,
|
||||
).error.message
|
||||
|
||||
raise HTTPException(400, error_message)
|
||||
|
||||
await model.load_model(model_path)
|
||||
else:
|
||||
logger.info(f"No valid admin key found to change loaded model, ignoring")
|
||||
if data.model:
|
||||
await load_inline_model(data.model, request)
|
||||
else:
|
||||
await check_model_container()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,8 @@
|
|||
"""Completion utilities for OAI server."""
|
||||
"""
|
||||
Completion utilities for OAI server.
|
||||
|
||||
Also serves as a common module for completions and chat completions.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import pathlib
|
||||
|
|
@ -9,7 +13,8 @@ from typing import List, Union
|
|||
|
||||
from loguru import logger
|
||||
|
||||
from common import model
|
||||
from common import config, model
|
||||
from common.auth import get_key_permission
|
||||
from common.networking import (
|
||||
get_generator_error,
|
||||
handle_request_disconnect,
|
||||
|
|
@ -173,6 +178,47 @@ async def stream_generate_completion(
|
|||
)
|
||||
|
||||
|
||||
async def load_inline_model(model_name: str, request: Request):
    """Load a model on demand from a request's ``model`` parameter.

    Intentionally never raises: if the switch can't happen (missing admin
    key, feature disabled, or bad path), a warning is logged and the
    request continues with whichever model is already loaded.

    Args:
        model_name: Name of the model directory (under ``model_dir``) to load.
        request: Incoming FastAPI request, used to check API key permissions.
    """

    # No-op if the requested model is already the loaded one
    if model.container and model.container.model_dir.name == model_name:
        return

    model_config = config.model_config()

    # Only admin keys may switch the loaded model
    if not get_key_permission(request) == "admin":
        logger.warning(
            f"Unable to switch model to {model_name} "
            "because an admin key isn't provided."
        )

        return

    # Inline loading must be explicitly enabled in config
    # NOTE: key name must match config.yml ("inline_model_loading")
    if not unwrap(model_config.get("inline_model_loading"), False):
        logger.warning(
            f"Unable to switch model to {model_name} because "
            '"inline_model_loading" is not True in config.yml.'
        )

        return

    model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
    model_path = model_path / model_name

    # Skip (don't error) if the model directory doesn't exist
    if not model_path.exists():
        logger.warning(
            f"Could not find model path {str(model_path)}. Skipping inline model load."
        )

        return

    # Load the model
    await model.load_model(model_path)
|
||||
|
||||
|
||||
async def generate_completion(
|
||||
data: CompletionRequest, request: Request, model_path: pathlib.Path
|
||||
):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue