tabbyAPI-ollama/endpoints/OAI/router.py
kingbri dfb4c51d5f OAI: Fix function idioms
Make functions mean the same thing to avoid confusion.

Signed-off-by: kingbri <bdashore3@proton.me>
2024-07-11 14:22:50 -04:00

543 lines
17 KiB
Python

import asyncio
import pathlib
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
from common import config, model, gen_logging, sampling
from common.auth import check_admin_key, check_api_key, get_key_permission
from common.downloader import hf_repo_download
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import PromptTemplate, get_all_templates
from common.utils import unwrap
from endpoints.OAI.types.auth import AuthPermissionResponse
from endpoints.OAI.types.completion import CompletionRequest, CompletionResponse
from endpoints.OAI.types.chat_completion import (
ChatCompletionRequest,
ChatCompletionResponse,
)
from endpoints.OAI.types.download import DownloadRequest, DownloadResponse
from endpoints.OAI.types.lora import (
LoraCard,
LoraList,
LoraLoadRequest,
LoraLoadResponse,
)
from endpoints.OAI.types.model import (
ModelCard,
ModelList,
ModelLoadRequest,
ModelCardParameters,
ModelLoadResponse,
)
from endpoints.OAI.types.sampler_overrides import (
SamplerOverrideListResponse,
SamplerOverrideSwitchRequest,
)
from endpoints.OAI.types.template import TemplateList, TemplateSwitchRequest
from endpoints.OAI.types.token import (
TokenEncodeRequest,
TokenEncodeResponse,
TokenDecodeRequest,
TokenDecodeResponse,
)
from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template,
generate_chat_completion,
stream_generate_chat_completion,
)
from endpoints.OAI.utils.completion import (
generate_completion,
stream_generate_completion,
)
from endpoints.OAI.utils.model import get_model_list, stream_model_load
from endpoints.OAI.utils.lora import get_lora_list
# Router that every endpoint in this module registers on via its decorators
router = APIRouter()
async def check_model_container():
    """
    FastAPI dependency that rejects the request when no model is usable.

    A model counts as usable while it is loading or once it has loaded.

    Raises:
        HTTPException: 400 if there is no container, or the container is
            neither loading nor loaded.
    """

    container = model.container
    has_model = container is not None and (
        container.model_is_loading or container.model_loaded
    )

    if not has_model:
        error_message = handle_request_error(
            "No models are currently loaded.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)
# Completions endpoint
@router.post(
    "/v1/completions",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def completion_request(
    request: Request, data: CompletionRequest
) -> CompletionResponse:
    """
    Generate a text completion for the given prompt.

    A list prompt is joined with newlines into a single string. If the request
    asks for streaming and the developer config's ``disable_request_streaming``
    flag is not set, an SSE stream is returned. Otherwise the full completion
    is generated in a task that is raced against client disconnection.
    """

    model_path = model.container.get_model_path()

    if isinstance(data.prompt, list):
        data.prompt = "\n".join(data.prompt)

    streaming_disabled = unwrap(
        config.developer_config().get("disable_request_streaming"), False
    )

    # A JSON response format is requested via a generic object schema
    if data.response_format.type == "json":
        data.json_schema = {"type": "object"}

    if not data.stream or streaming_disabled:
        generate_task = asyncio.create_task(generate_completion(data, model_path))
        return await run_with_request_disconnect(
            request,
            generate_task,
            disconnect_message="Completion generation cancelled by user.",
        )

    # ping interval of sys.maxsize seconds, presumably so keep-alive pings
    # never fire in practice
    return EventSourceResponse(
        stream_generate_completion(data, request, model_path),
        ping=maxsize,
    )
# Chat completions endpoint
@router.post(
    "/v1/chat/completions",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def chat_completion_request(
    request: Request, data: ChatCompletionRequest
) -> ChatCompletionResponse:
    """
    Generate a chat completion from a message list or a raw prompt string.

    A string ``messages`` value is used verbatim as the prompt. A message list
    is rendered through the loaded prompt template. Streaming follows the same
    rules as the plain completions endpoint.

    Raises:
        HTTPException: 422 if the loaded model has no prompt template.
    """

    container = model.container
    if container.prompt_template is None:
        error_message = handle_request_error(
            "Chat completions are disabled because a prompt template is not set.",
            exc_info=False,
        ).error.message

        raise HTTPException(422, error_message)

    model_path = container.get_model_path()

    prompt = (
        data.messages
        if isinstance(data.messages, str)
        else format_prompt_with_template(data)
    )

    # A JSON response format is requested via a generic object schema
    if data.response_format.type == "json":
        data.json_schema = {"type": "object"}

    streaming_disabled = unwrap(
        config.developer_config().get("disable_request_streaming"), False
    )

    if not data.stream or streaming_disabled:
        generate_task = asyncio.create_task(
            generate_chat_completion(prompt, data, model_path)
        )
        return await run_with_request_disconnect(
            request,
            generate_task,
            disconnect_message="Chat completion generation cancelled by user.",
        )

    # ping interval of sys.maxsize seconds, presumably so keep-alive pings
    # never fire in practice
    return EventSourceResponse(
        stream_generate_chat_completion(prompt, data, request, model_path),
        ping=maxsize,
    )
# Model list endpoint
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models() -> ModelList:
    """
    List the models found in the configured model directory.

    ``model_dir`` falls back to "models". The configured draft model directory
    is passed along to the lister. When ``use_dummy_models`` is enabled, a
    placeholder "gpt-3.5-turbo" card is prepended to the list.
    """

    model_config = config.model_config()
    models_root = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
    draft_dir = config.draft_model_config().get("draft_model_dir")

    model_list = get_model_list(models_root.resolve(), draft_dir)

    if unwrap(model_config.get("use_dummy_models"), False):
        model_list.data.insert(0, ModelCard(id="gpt-3.5-turbo"))

    return model_list
# Currently loaded model endpoint
@router.get(
    "/v1/model",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def current_model() -> ModelCard:
    """
    Return a card describing the currently loaded model.

    The card's parameters come from the container's model parameters, and its
    ``logging`` field from ``gen_logging.PREFERENCES``. If draft model
    parameters are present, they are described by their own card nested in
    ``parameters.draft``. A missing model or draft name falls back to
    "unknown".
    """

    model_params = model.container.get_model_parameters()

    # Build the draft card exactly once and nest it in the parent parameters
    # so it is validated as ``parameters.draft`` together with them.
    draft_model_params = model_params.pop("draft", None)
    if draft_model_params:
        model_params["draft"] = ModelCard(
            id=unwrap(draft_model_params.pop("name", None), "unknown"),
            parameters=ModelCardParameters.model_validate(draft_model_params),
        )

    return ModelCard(
        id=unwrap(model_params.pop("name", None), "unknown"),
        parameters=ModelCardParameters.model_validate(model_params),
        logging=gen_logging.PREFERENCES,
    )
@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models() -> ModelList:
    """
    List the models found in the configured draft model directory.

    ``draft_model_dir`` falls back to "models" when it is not configured.
    """

    configured_dir = config.draft_model_config().get("draft_model_dir")
    draft_root = pathlib.Path(unwrap(configured_dir, "models")).resolve()
    return get_model_list(draft_root)
# Load model endpoint
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(data: ModelLoadRequest) -> ModelLoadResponse:
    """
    Load a model into the model container, streaming progress over SSE.

    The model is resolved as ``<model_dir>/<name>``, where ``model_dir`` falls
    back to "models". When a draft section is supplied, it must name a draft
    model. The configured ``draft_model_dir`` (default "models") is then
    passed to the loader.

    Raises:
        HTTPException: 400 if the model name is missing, the draft model name
            is missing, or the resolved model path does not exist (checked in
            that order).
    """

    if not data.name:
        error_message = handle_request_error(
            "A model name was not provided for load.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    models_root = unwrap(config.model_config().get("model_dir"), "models")
    model_path = pathlib.Path(models_root) / data.name

    draft_model_path = None
    if data.draft:
        if not data.draft.draft_model_name:
            error_message = handle_request_error(
                "Could not find the draft model name for model load.",
                exc_info=False,
            ).error.message

            raise HTTPException(400, error_message)

        draft_model_path = unwrap(
            config.draft_model_config().get("draft_model_dir"), "models"
        )

    if not model_path.exists():
        error_message = handle_request_error(
            "Could not find the model path for load. Check model name or config.yml?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    load_stream = stream_model_load(data, model_path, draft_model_path)
    return EventSourceResponse(load_stream, ping=maxsize)
# Unload model endpoint
@router.post(
    "/v1/model/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_model():
    """Unload the active model, calling the unloader with skip_wait=True."""

    await model.unload_model(skip_wait=True)
@router.post("/v1/download", dependencies=[Depends(check_admin_key)])
async def download_model(request: Request, data: DownloadRequest) -> DownloadResponse:
    """
    Download a model repository from HuggingFace.

    The request fields are forwarded 1:1 as keyword arguments to the
    downloader. The download task is raced against client disconnection.

    Returns:
        DownloadResponse holding the path the files were downloaded to.

    Raises:
        HTTPException: any HTTPException raised while downloading (presumably
            from the disconnect helper) is re-raised unchanged; any other
            exception is reported as a 400.
    """

    try:
        # For now, the downloader and request data are 1:1
        download_task = asyncio.create_task(hf_repo_download(**data.model_dump()))
        download_path = await run_with_request_disconnect(
            request,
            download_task,
            disconnect_message=(
                "Download request cancelled by user. "
                "Files have been cleaned up."
            ),
        )

        return DownloadResponse(download_path=str(download_path))
    except HTTPException:
        # Already a well-formed HTTP error; don't re-wrap it as a generic 400
        raise
    except Exception as exc:
        error_message = handle_request_error(str(exc)).error.message
        raise HTTPException(400, error_message) from exc
# Lora list endpoint
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def list_all_loras() -> LoraList:
    """
    List the LoRAs found in the configured LoRA directory.

    ``lora_dir`` falls back to "loras" when it is not configured.
    """

    lora_root = unwrap(config.lora_config().get("lora_dir"), "loras")
    return get_lora_list(pathlib.Path(lora_root).resolve())
# Currently loaded loras endpoint
@router.get(
    "/v1/lora",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def active_loras() -> LoraList:
    """
    Return the LoRAs currently attached to the loaded model.

    Each card is named after the parent directory of the LoRA's ``lora_path``.
    Its scaling is ``lora_scaling`` rescaled by ``lora_r / lora_alpha``
    (presumably undoing an alpha/r factor applied by the backend).
    """

    cards = []
    for lora in model.container.get_loras():
        lora_name = pathlib.Path(lora.lora_path).parent.name
        reported_scaling = lora.lora_scaling * lora.lora_r / lora.lora_alpha
        cards.append(LoraCard(id=lora_name, scaling=reported_scaling))

    return LoraList(data=cards)
# Load lora endpoint
@router.post(
    "/v1/lora/load",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def load_lora(data: LoraLoadRequest) -> LoraLoadResponse:
    """
    Load the requested LoRAs into the model container.

    LoRAs are loaded relative to the configured ``lora_dir`` (default "loras").
    The request's ``skip_queue`` is forwarded as ``skip_wait``.

    Returns:
        LoraLoadResponse built from the loader's ``success`` and ``failure``
        entries, each falling back to [] via ``unwrap``.

    Raises:
        HTTPException: 400 if no LoRAs are listed or the LoRA directory does
            not exist.
    """

    if not data.loras:
        error_message = handle_request_error(
            "List of loras to load is not found.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    lora_dir = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
    if not lora_dir.exists():
        error_message = handle_request_error(
            "A parent lora directory does not exist for load. Check your config.yml?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    load_kwargs = data.model_dump()
    load_result = await model.load_loras(
        lora_dir, **load_kwargs, skip_wait=data.skip_queue
    )

    succeeded = unwrap(load_result.get("success"), [])
    failed = unwrap(load_result.get("failure"), [])
    return LoraLoadResponse(success=succeeded, failure=failed)
# Unload lora endpoint
@router.post(
    "/v1/lora/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_loras():
    """Remove every LoRA from the loaded model (requires the admin key)."""

    await model.unload_loras()
# Encode tokens endpoint
@router.post(
    "/v1/token/encode",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
    """
    Encode a string, or chat messages rendered with the prompt template, into
    tokens.

    A string ``text`` is encoded as-is. Otherwise ``text`` is treated as chat
    messages. These are rendered through the loaded prompt template without a
    generation prompt, with the model's special tokens available to the
    template.

    Raises:
        HTTPException: 422 if chat messages are given but no prompt template
            is loaded.
    """

    container = model.container

    if isinstance(data.text, str):
        text = data.text
    else:
        # Rendering messages needs a template; fail with a clear client error
        # instead of an AttributeError on None (matches chat completions).
        prompt_template = container.prompt_template
        if prompt_template is None:
            error_message = handle_request_error(
                "Cannot encode chat messages because a prompt template is not set.",
                exc_info=False,
            ).error.message

            raise HTTPException(422, error_message)

        special_tokens_dict = container.get_special_tokens(
            unwrap(data.add_bos_token, True)
        )

        template_vars = {
            "messages": data.text,
            "add_generation_prompt": False,
            **special_tokens_dict,
        }

        text, _ = prompt_template.render(template_vars)

    raw_tokens = container.encode_tokens(text, **data.get_params())
    tokens = unwrap(raw_tokens, [])
    return TokenEncodeResponse(tokens=tokens, length=len(tokens))
# Decode tokens endpoint
@router.post(
    "/v1/token/decode",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def decode_tokens(data: TokenDecodeRequest) -> TokenDecodeResponse:
    """
    Decode the request's tokens back into text.

    The decoded value falls back to "" via ``unwrap``.
    """

    decoded = model.container.decode_tokens(data.tokens, **data.get_params())
    return TokenDecodeResponse(text=unwrap(decoded, ""))
@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
async def key_permission(request: Request) -> AuthPermissionResponse:
    """
    Report the access level of the key supplied in the request headers.

    Headers are considered in this priority order:
    - Authorization
    - X-admin-key
    - X-api-key

    Raises:
        HTTPException: 400 if resolving the permission or building the
            response raises ValueError.
    """

    try:
        return AuthPermissionResponse(permission=get_key_permission(request))
    except ValueError as exc:
        error_message = handle_request_error(str(exc)).error.message
        raise HTTPException(400, error_message) from exc
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def list_templates() -> TemplateList:
    """List the names (file stems) of all available prompt templates."""

    return TemplateList(data=[path.stem for path in get_all_templates()])
@router.post(
    "/v1/template/switch",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
    """
    Replace the loaded model's prompt template with the named template file.

    Raises:
        HTTPException: 400 if no name is given, or loading the template raises
            FileNotFoundError.
    """

    template_name = data.name

    if not template_name:
        error_message = handle_request_error(
            "New template name not found.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    try:
        new_template = PromptTemplate.from_file(template_name)
    except FileNotFoundError as exc:
        error_message = handle_request_error(
            f"The template name {template_name} doesn't exist. Check the spelling?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message) from exc

    model.container.prompt_template = new_template
@router.post(
    "/v1/template/unload",
    dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_template():
    """Clear the loaded model's prompt template by setting it to None."""

    model.container.prompt_template = None
# Sampler override endpoints
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides() -> SamplerOverrideListResponse:
    """
    Report the available override presets and the current override state.

    The current state is the dumped fields of ``sampling.overrides_container``.
    """

    presets = sampling.get_all_presets()
    current_state = sampling.overrides_container.model_dump()
    return SamplerOverrideListResponse(presets=presets, **current_state)
@router.post(
    "/v1/sampling/override/switch",
    dependencies=[Depends(check_admin_key)],
)
async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
    """
    Apply sampler overrides from a named preset or from an inline dictionary.

    A preset name takes precedence over inline overrides when both are given.

    Raises:
        HTTPException: 400 if neither a preset nor overrides are provided, or
            the preset file cannot be found.
    """

    if not data.preset and not data.overrides:
        error_message = handle_request_error(
            "A sampler override preset or dictionary wasn't provided.",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message)

    if not data.preset:
        sampling.overrides_from_dict(data.overrides)
        return

    try:
        sampling.overrides_from_file(data.preset)
    except FileNotFoundError as exc:
        error_message = handle_request_error(
            f"Sampler override preset with name {data.preset} does not exist. "
            + "Check the spelling?",
            exc_info=False,
        ).error.message

        raise HTTPException(400, error_message) from exc
@router.post(
    "/v1/sampling/override/unload",
    dependencies=[Depends(check_admin_key)],
)
async def unload_sampler_override():
    """Reset sampler overrides by applying an empty override dictionary."""

    sampling.overrides_from_dict({})