API: Move to ModelManager

This is a shared module that manages the model container and provides
extra utility functions around it to help slim down the API.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-03-08 23:57:07 -05:00 committed by Brian Dashore
parent 8b46282aef
commit b373b25235
5 changed files with 178 additions and 143 deletions
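The new shared module itself is among the other changed files and is not shown in this excerpt. The sketch below is inferred from main.py's call sites (model.container, model.load_model, model.load_model_gen, model.unload_model) and is illustrative only, not the code introduced by this commit:

# common/model.py -- illustrative sketch, inferred from main.py's call sites
import pathlib
from typing import Optional

from backends.exllamav2.model import ExllamaV2Container

# Single shared container instance, replacing main.py's MODEL_CONTAINER global
container: Optional[ExllamaV2Container] = None


async def unload_model():
    """Unload the current model and clear the shared container."""
    global container

    if container:
        container.unload()
        container = None


async def load_model_gen(model_path: pathlib.Path, **kwargs):
    """Async generator that loads a model and yields (module, modules, model_type).

    The real implementation presumably also drives the loading progress bar
    that this commit removes from main.py.
    """
    global container

    # Unload any existing model before loading a new one
    if container and container.model:
        await unload_model()

    container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
    model_type = "draft" if container.draft_config else "model"

    # The exact load_gen signature (e.g. a progress callback) may differ
    for module, modules in container.load_gen():
        yield module, modules, model_type

        # Mirror the old inline loader: switch from the draft model to the
        # main model once the draft finishes loading
        if module == modules and model_type == "draft":
            model_type = "model"


async def load_model(model_path: pathlib.Path, **kwargs):
    """Non-streaming wrapper used at startup: drain the loading generator."""
    async for _ in load_model_gen(model_path, **kwargs):
        pass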

main.py (215 changes)

@@ -1,4 +1,5 @@
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""
import asyncio
import os
import pathlib
import signal
@@ -17,10 +18,10 @@ from fastapi.middleware.cors import CORSMiddleware
from functools import partial
from loguru import logger
from common.logger import UVICORN_LOG_CONFIG, setup_logger, get_loading_progress_bar
from common.logger import UVICORN_LOG_CONFIG, setup_logger
import common.gen_logging as gen_logging
from backends.exllamav2.model import ExllamaV2Container
from backends.exllamav2.utils import check_exllama_version
from common import model
from common.args import convert_args_to_dict, init_argparser
from common.auth import check_admin_key, check_api_key, load_auth_keys
from common.config import (
@@ -52,7 +53,6 @@ from common.templating import (
from common.utils import (
get_generator_error,
handle_request_error,
load_progress,
is_port_in_use,
unwrap,
)
@@ -90,24 +90,6 @@ app = FastAPI(
),
)
# Globally scoped variables. Undefined until initalized in main
MODEL_CONTAINER: Optional[ExllamaV2Container] = None
async def _check_model_container():
"""Checks if a model isn't loading or loaded."""
if MODEL_CONTAINER is None or not (
MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
):
error_message = handle_request_error(
"No models are currently loaded.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
# ALlow CORS requests
app.add_middleware(
CORSMiddleware,
@@ -118,6 +100,20 @@ app.add_middleware(
)
async def check_model_container():
"""FastAPI depends that checks if a model isn't loaded or currently loading."""
if model.container is None or not (
model.container.model_is_loading or model.container.model_loaded
):
error_message = handle_request_error(
"No models are currently loaded.",
exc_info=False,
).error.message
raise HTTPException(400, error_message)
# Model list endpoint
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
@@ -139,35 +135,33 @@ async def list_models():
# Currently loaded model endpoint
@app.get(
"/v1/model",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_current_model():
"""Returns the currently loaded model."""
model_name = MODEL_CONTAINER.get_model_path().name
prompt_template = MODEL_CONTAINER.prompt_template
model_params = model.container.get_model_parameters()
draft_model_params = model_params.pop("draft", {})
if draft_model_params:
model_params["draft"] = ModelCard(
id=unwrap(draft_model_params.get("name"), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
else:
draft_model_params = None
model_card = ModelCard(
id=model_name,
parameters=ModelCardParameters(
rope_scale=MODEL_CONTAINER.config.scale_pos_emb,
rope_alpha=MODEL_CONTAINER.config.scale_alpha_value,
max_seq_len=MODEL_CONTAINER.config.max_seq_len,
cache_mode=MODEL_CONTAINER.cache_mode,
prompt_template=prompt_template.name if prompt_template else None,
num_experts_per_token=MODEL_CONTAINER.config.num_experts_per_token,
use_cfg=MODEL_CONTAINER.use_cfg,
),
id=unwrap(model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(model_params),
logging=gen_logging.PREFERENCES,
)
if MODEL_CONTAINER.draft_config:
if draft_model_params:
draft_card = ModelCard(
id=MODEL_CONTAINER.get_model_path(True).name,
parameters=ModelCardParameters(
rope_scale=MODEL_CONTAINER.draft_config.scale_pos_emb,
rope_alpha=MODEL_CONTAINER.draft_config.scale_alpha_value,
max_seq_len=MODEL_CONTAINER.draft_config.max_seq_len,
),
id=unwrap(draft_model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
model_card.parameters.draft = draft_card
return model_card
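get_current_model now delegates parameter gathering to the container. get_model_parameters() is not part of this excerpt; judging from the inline field list it replaces and the name/draft pops above, it plausibly returns a flat dict along these lines (keys mirror ModelCardParameters, values here are invented):

# Plausible shape of model.container.get_model_parameters(); values are invented
model_params = {
    "name": "MyModel-7B-exl2",
    "rope_scale": 1.0,
    "rope_alpha": 1.0,
    "max_seq_len": 4096,
    "cache_mode": "FP16",
    "prompt_template": "chatml",
    "num_experts_per_token": None,
    "use_cfg": False,
    # Only present when a draft model is loaded
    "draft": {
        "name": "MyDraft-1B-exl2",
        "rope_scale": 1.0,
        "rope_alpha": 1.0,
        "max_seq_len": 4096,
    },
}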
@@ -211,35 +205,12 @@ async def load_model(request: Request, data: ModelLoadRequest):
if not model_path.exists():
raise HTTPException(400, "model_path does not exist. Check model_name?")
# Check if the model is already loaded
if MODEL_CONTAINER and MODEL_CONTAINER.model:
loaded_model_name = MODEL_CONTAINER.get_model_path().name
if loaded_model_name == data.name:
raise HTTPException(
400, f'Model "{loaded_model_name}"is already loaded! Aborting.'
)
async def generator():
"""Generator for the loading process."""
global MODEL_CONTAINER
# Unload the existing model
if MODEL_CONTAINER and MODEL_CONTAINER.model:
logger.info("Unloading existing model.")
await unload_model()
MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
load_status = MODEL_CONTAINER.load_gen(load_progress)
"""Request generation wrapper for the loading process."""
load_status = model.load_model_gen(model_path, **load_data)
try:
progress = get_loading_progress_bar()
progress.start()
for module, modules in load_status:
# Get out if the request gets disconnected
async for module, modules, model_type in load_status:
if await request.is_disconnected():
release_semaphore()
logger.error(
@@ -248,13 +219,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
)
return
if module == 0:
loading_task = progress.add_task(
f"[cyan]Loading {model_type} modules", total=modules
)
else:
progress.advance(loading_task)
if module != 0:
response = ModelLoadResponse(
model_type=model_type,
module=module,
@@ -273,13 +238,6 @@ async def load_model(request: Request, data: ModelLoadRequest):
)
yield response.model_dump_json()
# Switch to model progress if the draft model is loaded
if model_type == "draft":
model_type = "model"
else:
progress.stop()
except CancelledError:
logger.error(
"Model load cancelled by user. "
@@ -287,8 +245,6 @@ async def load_model(request: Request, data: ModelLoadRequest):
)
except Exception as exc:
yield get_generator_error(str(exc))
finally:
progress.stop()
# Determine whether to use or skip the queue
if data.skip_queue:
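With the progress-bar bookkeeping and the draft-to-model switch moved into the shared module, the endpoint's generator now only streams one ModelLoadResponse JSON document per loaded module and bails out on client disconnect. A hypothetical client could consume that stream roughly as below; the route, header name, port, and exact chunk framing are assumptions, not taken from this diff:

# Hypothetical client for the streaming load endpoint; route, header name,
# port, and payload fields beyond "name" are assumptions.
import httpx

with httpx.stream(
    "POST",
    "http://127.0.0.1:5000/v1/model/load",
    headers={"x-admin-key": "<admin key>"},
    json={"name": "MyModel-7B-exl2"},
    timeout=None,
) as response:
    for chunk in response.iter_text():
        # Roughly one ModelLoadResponse JSON document per yielded module
        print(chunk)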
@@ -306,14 +262,11 @@ async def load_model(request: Request, data: ModelLoadRequest):
# Unload model endpoint
@app.post(
"/v1/model/unload",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_model():
"""Unloads the currently loaded model."""
global MODEL_CONTAINER
MODEL_CONTAINER.unload()
MODEL_CONTAINER = None
await model.unload_model()
@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
@@ -326,7 +279,7 @@ async def get_templates():
@app.post(
"/v1/template/switch",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def switch_template(data: TemplateSwitchRequest):
"""Switch the currently loaded template"""
@@ -335,19 +288,19 @@ async def switch_template(data: TemplateSwitchRequest):
try:
template = get_template_from_file(data.name)
MODEL_CONTAINER.prompt_template = template
model.container.prompt_template = template
except FileNotFoundError as e:
raise HTTPException(400, "Template does not exist. Check the name?") from e
@app.post(
"/v1/template/unload",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_template():
"""Unloads the currently selected template"""
MODEL_CONTAINER.prompt_template = None
model.container.prompt_template = None
# Sampler override endpoints
@@ -405,7 +358,7 @@ async def get_all_loras():
# Currently loaded loras endpoint
@app.get(
"/v1/lora",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_active_loras():
"""Returns the currently loaded loras."""
@@ -416,7 +369,7 @@ async def get_active_loras():
id=pathlib.Path(lora.lora_path).parent.name,
scaling=lora.lora_scaling * lora.lora_r / lora.lora_alpha,
),
MODEL_CONTAINER.active_loras,
model.container.active_loras,
)
)
)
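As a worked example of the scaling figure reported here: a LoRA loaded with lora_scaling=1.0, lora_r=8, and lora_alpha=16 is listed with scaling = 1.0 * 8 / 16 = 0.5.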
@@ -427,7 +380,7 @@ async def get_active_loras():
# Load lora endpoint
@app.post(
"/v1/lora/load",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def load_lora(data: LoraLoadRequest):
"""Loads a LoRA into the model container."""
@@ -443,10 +396,10 @@ async def load_lora(data: LoraLoadRequest):
# Clean-up existing loras if present
def load_loras_internal():
if len(MODEL_CONTAINER.active_loras) > 0:
if len(model.container.active_loras) > 0:
unload_loras()
result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
result = model.container.load_loras(lora_dir, **data.model_dump())
return LoraLoadResponse(
success=unwrap(result.get("success"), []),
failure=unwrap(result.get("failure"), []),
@@ -468,21 +421,21 @@ async def load_lora(data: LoraLoadRequest):
# Unload lora endpoint
@app.post(
"/v1/lora/unload",
dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
async def unload_loras():
"""Unloads the currently loaded loras."""
MODEL_CONTAINER.unload(loras_only=True)
model.container.unload(loras_only=True)
# Encode tokens endpoint
@app.post(
"/v1/token/encode",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def encode_tokens(data: TokenEncodeRequest):
"""Encodes a string into tokens."""
raw_tokens = MODEL_CONTAINER.encode_tokens(data.text, **data.get_params())
raw_tokens = model.container.encode_tokens(data.text, **data.get_params())
tokens = unwrap(raw_tokens, [])
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
@@ -492,11 +445,11 @@ async def encode_tokens(data: TokenEncodeRequest):
# Decode tokens endpoint
@app.post(
"/v1/token/decode",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def decode_tokens(data: TokenDecodeRequest):
"""Decodes tokens into a string."""
message = MODEL_CONTAINER.decode_tokens(data.tokens, **data.get_params())
message = model.container.decode_tokens(data.tokens, **data.get_params())
response = TokenDecodeResponse(text=unwrap(message, ""))
return response
@@ -505,11 +458,11 @@
# Completions endpoint
@app.post(
"/v1/completions",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate_completion(request: Request, data: CompletionRequest):
"""Generates a completion from a prompt."""
model_path = MODEL_CONTAINER.get_model_path()
model_path = model.container.get_model_path()
if isinstance(data.prompt, list):
data.prompt = "\n".join(data.prompt)
@@ -522,7 +475,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
async def generator():
try:
new_generation = MODEL_CONTAINER.generate_gen(
new_generation = model.container.generate_gen(
data.prompt, **data.to_gen_params()
)
for generation in new_generation:
@@ -549,7 +502,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
generation = await call_with_semaphore(
partial(
run_in_threadpool,
MODEL_CONTAINER.generate,
model.container.generate,
data.prompt,
**data.to_gen_params(),
)
@@ -570,30 +523,31 @@ async def generate_completion(request: Request, data: CompletionRequest):
# Chat completions endpoint
@app.post(
"/v1/chat/completions",
dependencies=[Depends(check_api_key), Depends(_check_model_container)],
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
"""Generates a chat completion from a prompt."""
if MODEL_CONTAINER.prompt_template is None:
if model.container.prompt_template is None:
raise HTTPException(
422,
"This endpoint is disabled because a prompt template is not set.",
)
model_path = MODEL_CONTAINER.get_model_path()
model_path = model.container.get_model_path()
if isinstance(data.messages, str):
prompt = data.messages
else:
try:
special_tokens_dict = MODEL_CONTAINER.get_special_tokens(
special_tokens_dict = model.container.get_special_tokens(
unwrap(data.add_bos_token, True),
unwrap(data.ban_eos_token, False),
)
prompt = get_prompt_from_template(
data.messages,
MODEL_CONTAINER.prompt_template,
model.container.prompt_template,
data.add_generation_prompt,
special_tokens_dict,
)
@@ -601,7 +555,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
raise HTTPException(
400,
"Could not find a Conversation from prompt template "
f"'{MODEL_CONTAINER.prompt_template.name}'. "
f"'{model.container.prompt_template.name}'. "
"Check your spelling?",
) from exc
except TemplateError as exc:
@@ -620,7 +574,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
async def generator():
"""Generator for the generation process."""
try:
new_generation = MODEL_CONTAINER.generate_gen(
new_generation = model.container.generate_gen(
prompt, **data.to_gen_params()
)
for generation in new_generation:
@@ -653,7 +607,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
generation = await call_with_semaphore(
partial(
run_in_threadpool,
MODEL_CONTAINER.generate,
model.container.generate,
prompt,
**data.to_gen_params(),
)
@@ -692,11 +646,9 @@ def signal_handler(*_):
sys.exit(0)
def entrypoint(args: Optional[dict] = None):
async def entrypoint(args: Optional[dict] = None):
"""Entry function for program startup"""
global MODEL_CONTAINER
setup_logger()
# Set up signal aborting
@@ -782,34 +734,13 @@ def entrypoint(args: Optional[dict] = None):
model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
model_path = model_path / model_name
MODEL_CONTAINER = ExllamaV2Container(
model_path.resolve(), False, **model_config
)
load_status = MODEL_CONTAINER.load_gen(load_progress)
progress = get_loading_progress_bar()
progress.start()
model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
for module, modules in load_status:
if module == 0:
loading_task = progress.add_task(
f"[cyan]Loading {model_type} modules", total=modules
)
else:
progress.advance(loading_task, 1)
if module == modules:
if model_type == "draft":
model_type = "model"
else:
progress.stop()
await model.load_model(model_path.resolve(), **model_config)
# Load loras after loading the model
lora_config = get_lora_config()
if lora_config.get("loras"):
lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
MODEL_CONTAINER.load_loras(lora_dir.resolve(), **lora_config)
model.container.load_loras(lora_dir.resolve(), **lora_config)
# TODO: Replace this with abortables, async via producer consumer, or something else
api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)
@@ -821,4 +752,4 @@ def entrypoint(args: Optional[dict] = None):
if __name__ == "__main__":
entrypoint()
asyncio.run(entrypoint())
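Because entrypoint is now a coroutine, code that previously called it directly has to run it in an event loop, mirroring the __main__ guard above. A minimal sketch of a hypothetical wrapper script:

# Hypothetical wrapper script; runs the now-async entrypoint in an event loop
import asyncio

from main import entrypoint

if __name__ == "__main__":
    asyncio.run(entrypoint())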