tabbyAPI-ollama/endpoints/OAI/router.py
kingbri 43cd7f57e8 API + Model: Add blocks and checks for various load requests
Add a sequential lock and wait until jobs are completed before executing
any loading requests that directly alter the model. However, we also
need to block any new requests that come in until the load is finished,
so add a condition that triggers once the lock is free.

Signed-off-by: kingbri <bdashore3@proton.me>
2024-05-25 21:16:14 -04:00
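For illustration, a minimal sketch of the gating pattern the commit message describes, using plain asyncio primitives. The names here (load_lock, load_condition, run_load_request, wait_until_load_finished) are hypothetical and not taken from the tabbyAPI codebase; the actual implementation lives in the model/container modules.

import asyncio

# Hypothetical module-level primitives (illustrative only)
load_lock = asyncio.Lock()            # serializes requests that alter the model
load_condition = asyncio.Condition()  # lets other requests wait out an active load


async def run_load_request(load_coro):
    """Run a model-altering operation exclusively, then wake blocked requests."""
    async with load_lock:
        # ... wait for in-flight generation jobs to finish, then load ...
        await load_coro()

    # The load is done and the lock is free: release anyone who was waiting
    async with load_condition:
        load_condition.notify_all()


async def wait_until_load_finished():
    """Called by incoming requests: block while a load holds the lock."""
    if load_lock.locked():
        async with load_condition:
            await load_condition.wait_for(lambda: not load_lock.locked())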


import asyncio
import pathlib
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
from typing import Optional

from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key, validate_key_permission
from common.downloader import hf_repo_download
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
from common.utils import coalesce, unwrap
from endpoints.OAI.types.auth import AuthPermissionResponse
from endpoints.OAI.types.completion import CompletionRequest
from endpoints.OAI.types.chat_completion import ChatCompletionRequest
from endpoints.OAI.types.download import DownloadRequest, DownloadResponse
from endpoints.OAI.types.lora import (
    LoraCard,
    LoraList,
    LoraLoadRequest,
    LoraLoadResponse,
)
from endpoints.OAI.types.model import (
    ModelCard,
    ModelLoadRequest,
    ModelCardParameters,
)
from endpoints.OAI.types.sampler_overrides import (
    SamplerOverrideListResponse,
    SamplerOverrideSwitchRequest,
)
from endpoints.OAI.types.template import TemplateList, TemplateSwitchRequest
from endpoints.OAI.types.token import (
    TokenEncodeRequest,
    TokenEncodeResponse,
    TokenDecodeRequest,
    TokenDecodeResponse,
)
from endpoints.OAI.utils.chat_completion import (
    format_prompt_with_template,
    generate_chat_completion,
    stream_generate_chat_completion,
)
from endpoints.OAI.utils.completion import (
    generate_completion,
    stream_generate_completion,
)
from endpoints.OAI.utils.model import get_model_list, stream_model_load
from endpoints.OAI.utils.lora import get_lora_list


router = APIRouter()


async def check_model_container():
    """FastAPI dependency that errors if a model isn't loaded or currently loading."""

    if model.container is None or not (
        model.container.model_is_loading or model.container.model_loaded
    ):
        error_message = handle_request_error(
            "No models are currently loaded.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)


# Model list endpoint
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
    """Lists all models in the model directory."""

    model_config = config.model_config()
    model_dir = unwrap(model_config.get("model_dir"), "models")
    model_path = pathlib.Path(model_dir)

    draft_model_dir = config.draft_model_config().get("draft_model_dir")

    models = get_model_list(model_path.resolve(), draft_model_dir)
    if unwrap(model_config.get("use_dummy_models"), False):
        models.data.insert(0, ModelCard(id="gpt-3.5-turbo"))

    return models


# Currently loaded model endpoint
@router.get(
    "/v1/model",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_current_model():
    """Returns the currently loaded model."""

    model_params = model.container.get_model_parameters()
    draft_model_params = model_params.pop("draft", {})

    if draft_model_params:
        model_params["draft"] = ModelCard(
            id=unwrap(draft_model_params.get("name"), "unknown"),
            parameters=ModelCardParameters.model_validate(draft_model_params),
        )
    else:
        draft_model_params = None

    model_card = ModelCard(
        id=unwrap(model_params.pop("name", None), "unknown"),
        parameters=ModelCardParameters.model_validate(model_params),
        logging=gen_logging.PREFERENCES,
    )

    if draft_model_params:
        draft_card = ModelCard(
            id=unwrap(draft_model_params.pop("name", None), "unknown"),
            parameters=ModelCardParameters.model_validate(draft_model_params),
        )

        model_card.parameters.draft = draft_card

    return model_card


@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models():
    """Lists all draft models in the model directory."""

    draft_model_dir = unwrap(
        config.draft_model_config().get("draft_model_dir"), "models"
    )
    draft_model_path = pathlib.Path(draft_model_dir)

    models = get_model_list(draft_model_path.resolve())

    return models


# Load model endpoint
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(data: ModelLoadRequest):
    """Loads a model into the model container."""

    # Verify request parameters
    if not data.name:
        error_message = handle_request_error(
            "A model name was not provided for load.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    model_path = pathlib.Path(unwrap(config.model_config().get("model_dir"), "models"))
    model_path = model_path / data.name

    draft_model_path = None
    if data.draft:
        if not data.draft.draft_model_name:
            error_message = handle_request_error(
                "Could not find the draft model name for model load.",
                exc_info=False,
            ).error.message

            raise HTTPException(400, error_message)

        draft_model_path = unwrap(
            config.draft_model_config().get("draft_model_dir"), "models"
        )

    if not model_path.exists():
        error_message = handle_request_error(
            "Could not find the model path for load. Check model name or config.yml?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    return EventSourceResponse(
        stream_model_load(data, model_path, draft_model_path), ping=maxsize
    )


# Unload model endpoint
@router.post(
    "/v1/model/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_model():
    """Unloads the currently loaded model."""

    await model.unload_model(skip_wait=True)
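

# Template list endpoint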
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def get_templates():
templates = get_all_templates()
template_strings = list(map(lambda template: template.stem, templates))
return TemplateList(data=template_strings)
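

# Switch template endpoint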
@router.post(
    "/v1/template/switch",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
    """Switch the currently loaded template"""

    if not data.name:
        error_message = handle_request_error(
            "New template name not found.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    try:
        model.container.prompt_template = PromptTemplate.from_file(data.name)
    except FileNotFoundError as e:
        error_message = handle_request_error(
            f"The template name {data.name} doesn't exist. Check the spelling?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message) from e
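

# Unload template endpoint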
@router.post(
    "/v1/template/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_template():
    """Unloads the currently selected template"""

    model.container.prompt_template = None


# Sampler override endpoints
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides():
    """API wrapper to list all currently applied sampler overrides"""

    return SamplerOverrideListResponse(
        presets=sampling.get_all_presets(), **sampling.overrides_container.model_dump()
    )


@router.post(
    "/v1/sampling/override/switch",
    dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
    """Switch the currently loaded override preset"""

    if data.preset:
        try:
            sampling.overrides_from_file(data.preset)
        except FileNotFoundError as e:
            error_message = handle_request_error(
                f"Sampler override preset with name {data.preset} does not exist. "
                + "Check the spelling?",
                exc_info=False,
            ).error.message

            raise HTTPException(400, error_message) from e
    elif data.overrides:
        sampling.overrides_from_dict(data.overrides)
    else:
        error_message = handle_request_error(
            "A sampler override preset or dictionary wasn't provided.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)


@router.post(
    "/v1/sampling/override/unload",
    dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
    """Unloads the currently selected override preset"""

    sampling.overrides_from_dict({})
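

# Model download endpoint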
@router.post("/v1/download", dependencies=[Depends(check_admin_key)])
async def download_model(request: Request, data: DownloadRequest):
"""Downloads a model from HuggingFace."""
try:
download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
# For now, the downloader and request data are 1:1
download_path = await run_with_request_disconnect(
request,
download_task,
"Download request cancelled by user. Files have been cleaned up.",
)
return DownloadResponse(download_path=str(download_path))
except Exception as exc:
error_message = handle_request_error(str(exc)).error.message
raise HTTPException(400, error_message) from exc


# Lora list endpoint
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def get_all_loras():
    """Lists all LoRAs in the lora directory."""

    lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
    loras = get_lora_list(lora_path.resolve())

    return loras


# Currently loaded loras endpoint
@router.get(
    "/v1/lora",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_active_loras():
    """Returns the currently loaded loras."""

    active_loras = LoraList(
        data=[
            LoraCard(
                id=pathlib.Path(lora.lora_path).parent.name,
                scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
            )
            for lora in model.container.get_loras()
        ]
    )

    return active_loras


# Load lora endpoint
@router.post(
    "/v1/lora/load",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def load_lora(data: LoraLoadRequest):
    """Loads a LoRA into the model container."""

    if not data.loras:
        error_message = handle_request_error(
            "A list of loras to load was not provided.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
    if not lora_dir.exists():
        error_message = handle_request_error(
            "A parent lora directory does not exist for load. Check your config.yml?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    load_result = await model.load_loras(
        lora_dir, **data.model_dump(), skip_wait=data.skip_queue
    )

    return LoraLoadResponse(
        success=unwrap(load_result.get("success"), []),
        failure=unwrap(load_result.get("failure"), []),
    )


# Unload lora endpoint
@router.post(
    "/v1/lora/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_loras():
    """Unloads the currently loaded loras."""

    await model.unload_loras()


# Encode tokens endpoint
@router.post(
    "/v1/token/encode",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def encode_tokens(data: TokenEncodeRequest):
    """Encodes a string or chat completion messages into tokens."""

    if isinstance(data.text, str):
        text = data.text
    else:
        special_tokens_dict = model.container.get_special_tokens(
            unwrap(data.add_bos_token, True)
        )

        template_vars = {
            "messages": data.text,
            "add_generation_prompt": False,
            **special_tokens_dict,
        }

        text, _ = model.container.prompt_template.render(template_vars)

    raw_tokens = model.container.encode_tokens(text, **data.get_params())
    tokens = unwrap(raw_tokens, [])
    response = TokenEncodeResponse(tokens=tokens, length=len(tokens))

    return response


# Decode tokens endpoint
@router.post(
    "/v1/token/decode",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def decode_tokens(data: TokenDecodeRequest):
    """Decodes tokens into a string."""

    message = model.container.decode_tokens(data.tokens, **data.get_params())
    response = TokenDecodeResponse(text=unwrap(message, ""))

    return response
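

# Key permission endpoint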
@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
async def get_key_permission(
x_admin_key: Optional[str] = Header(None),
x_api_key: Optional[str] = Header(None),
authorization: Optional[str] = Header(None),
):
"""
Gets the access level/permission of a provided key in headers.
Priority:
- X-api-key
- X-admin-key
- Authorization
"""
test_key = coalesce(x_admin_key, x_api_key, authorization)
try:
permission = await validate_key_permission(test_key)
return AuthPermissionResponse(permission=permission)
except ValueError as exc:
error_message = handle_request_error(str(exc)).error.message
raise HTTPException(400, error_message) from exc


# Completions endpoint
@router.post(
    "/v1/completions",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def completion_request(request: Request, data: CompletionRequest):
    """Generates a completion from a prompt."""

    model_path = model.container.get_model_path()

    if isinstance(data.prompt, list):
        data.prompt = "\n".join(data.prompt)

    disable_request_streaming = unwrap(
        config.developer_config().get("disable_request_streaming"), False
    )

    # Set an empty JSON schema if the request wants a JSON response
    if data.response_format.type == "json":
        data.json_schema = {"type": "object"}

    if data.stream and not disable_request_streaming:
        return EventSourceResponse(
            stream_generate_completion(data, model_path),
            ping=maxsize,
        )
    else:
        generate_task = asyncio.create_task(generate_completion(data, model_path))

        response = await run_with_request_disconnect(
            request,
            generate_task,
            disconnect_message="Completion generation cancelled by user.",
        )
        return response


# Chat completions endpoint
@router.post(
    "/v1/chat/completions",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def chat_completion_request(request: Request, data: ChatCompletionRequest):
    """Generates a chat completion from a prompt."""

    if model.container.prompt_template is None:
        error_message = handle_request_error(
            "Chat completions are disabled because a prompt template is not set.",
            exc_info=False,
        ).error.message

        raise HTTPException(422, error_message)

    model_path = model.container.get_model_path()

    if isinstance(data.messages, str):
        prompt = data.messages
    else:
        prompt = format_prompt_with_template(data)

    # Set an empty JSON schema if the request wants a JSON response
    if data.response_format.type == "json":
        data.json_schema = {"type": "object"}

    disable_request_streaming = unwrap(
        config.developer_config().get("disable_request_streaming"), False
    )

    if data.stream and not disable_request_streaming:
        return EventSourceResponse(
            stream_generate_chat_completion(prompt, data, model_path),
            ping=maxsize,
        )
    else:
        generate_task = asyncio.create_task(
            generate_chat_completion(prompt, data, model_path)
        )

        response = await run_with_request_disconnect(
            request,
            generate_task,
            disconnect_message="Chat completion generation cancelled by user.",
        )
        return response