API: Add more methods to semaphore

The semaphore/queue model for Tabby is as follows:
- All load requests go through the semaphore by default
- Any load request can include the skip_queue parameter to bypass
the semaphore
- All unload requests are executed immediately
- All completion requests are placed inside the semaphore by default

This model preserves the parallelism of single-user mode while adding
convenience methods for queueing in multi-user use. It also helps
mitigate problems that were previously present in the concurrency
stack (the semaphore wrappers are sketched below).
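
The call_with_semaphore and generate_with_semaphore wrappers this model
relies on are referenced in the diff but not defined by it. A minimal
sketch of what such wrappers could look like, assuming a single-permit
asyncio semaphore (the permit count is an assumption):

    import asyncio

    # Assumed: one permit, so queued requests execute strictly one at a time
    generate_semaphore = asyncio.Semaphore(1)

    async def call_with_semaphore(callback):
        # Await a coroutine factory while holding the semaphore
        async with generate_semaphore:
            return await callback()

    async def generate_with_semaphore(generator):
        # Drain an async generator while holding the semaphore
        async with generate_semaphore:
            async for item in generator():
                yield item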

Also change how the program's loop runs so it exits when the API thread
dies.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri, 2024-03-03 01:22:34 -05:00 (committed by Brian Dashore)
parent c82697fef2
commit b0c295dd2f
5 changed files with 77 additions and 39 deletions

main.py (101 changed lines)
@@ -1,9 +1,9 @@
 """The main tabbyAPI module. Contains the FastAPI server and endpoints."""
-import asyncio
 import os
 import pathlib
 import signal
 import sys
+import time
 import uvicorn
 import threading
 from asyncio import CancelledError
@@ -93,7 +93,9 @@ MODEL_CONTAINER: Optional[ExllamaV2Container] = None
 def _check_model_container():
-    if MODEL_CONTAINER is None or not MODEL_CONTAINER.model_loaded:
+    if MODEL_CONTAINER is None or not (
+        MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
+    ):
         error_message = handle_request_error(
             "No models are currently loaded."
         ).error.message
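
The widened check means endpoints guarded by this dependency stop
rejecting requests while a model is still loading. A hypothetical,
self-contained sketch of the dependency pattern (FakeContainer stands
in for ExllamaV2Container and its state flags):

    from fastapi import Depends, FastAPI, HTTPException

    app = FastAPI()

    class FakeContainer:
        # Simplified stand-in for ExllamaV2Container's state flags
        model_is_loading = True
        model_loaded = False

    MODEL_CONTAINER = FakeContainer()

    def _check_model_container():
        # Pass if a model is loaded *or* currently loading
        if MODEL_CONTAINER is None or not (
            MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
        ):
            raise HTTPException(400, "No models are currently loaded.")

    @app.get("/v1/demo", dependencies=[Depends(_check_model_container)])
    def demo():
        return {"ok": True}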
@@ -183,24 +185,13 @@ def list_draft_models():
 # Load model endpoint
 @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
-def load_model(request: Request, data: ModelLoadRequest):
+async def load_model(request: Request, data: ModelLoadRequest):
     """Loads a model into the model container."""
     global MODEL_CONTAINER

     # Verify request parameters
     if not data.name:
         raise HTTPException(400, "A model name was not provided.")

-    # Unload the existing model
-    if MODEL_CONTAINER and MODEL_CONTAINER.model:
-        loaded_model_name = MODEL_CONTAINER.get_model_path().name
-        if loaded_model_name == data.name:
-            raise HTTPException(
-                400, f'Model "{loaded_model_name}" is already loaded! Aborting.'
-            )
-        else:
-            MODEL_CONTAINER.unload()
-
     model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models"))
     model_path = model_path / data.name
@@ -219,10 +210,24 @@ def load_model(request: Request, data: ModelLoadRequest):
     if not model_path.exists():
         raise HTTPException(400, "model_path does not exist. Check model_name?")

-    MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
+    # Check if the model is already loaded
+    if MODEL_CONTAINER and MODEL_CONTAINER.model:
+        loaded_model_name = MODEL_CONTAINER.get_model_path().name
+        if loaded_model_name == data.name:
+            raise HTTPException(
+                400, f'Model "{loaded_model_name}" is already loaded! Aborting.'
+            )

     async def generator():
         """Generator for the loading process."""
+        global MODEL_CONTAINER
+
+        # Unload the existing model
+        if MODEL_CONTAINER and MODEL_CONTAINER.model:
+            unload_model()
+
+        MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
+
         model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
         load_status = MODEL_CONTAINER.load_gen(load_progress)
@@ -230,6 +235,7 @@ def load_model(request: Request, data: ModelLoadRequest):
         try:
             for module, modules in load_status:
                 if await request.is_disconnected():
+                    unload_model()
                     break

                 if module == 0:
@@ -269,7 +275,17 @@ def load_model(request: Request, data: ModelLoadRequest):
         except Exception as exc:
             yield get_generator_error(str(exc))

-    return StreamingResponse(generator(), media_type="text/event-stream")
+    # Determine whether to use or skip the queue
+    if data.skip_queue:
+        logger.warning(
+            "Model load request is skipping the completions queue. "
+            "Unexpected results may occur."
+        )
+        generator_callback = generator
+    else:
+        generator_callback = partial(generate_with_semaphore, generator)
+
+    return StreamingResponse(generator_callback(), media_type="text/event-stream")


 # Unload model endpoint
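
As a usage illustration (not part of the commit): a hypothetical client
call that loads a model while bypassing the queue. The endpoint and the
skip_queue field come from this diff; the host, port, model name, and
the x-admin-key header are placeholders/assumptions:

    import requests

    # Placeholder host, port, key, and model name; adjust for your setup
    response = requests.post(
        "http://127.0.0.1:5000/v1/model/load",
        headers={"x-admin-key": "your-admin-key"},
        json={"name": "my-model", "skip_queue": True},
        stream=True,  # the endpoint streams load progress as SSE
    )

    for line in response.iter_lines():
        if line:
            print(line.decode())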
@@ -398,7 +414,7 @@ def get_active_loras():
     "/v1/lora/load",
     dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
 )
-def load_lora(data: LoraLoadRequest):
+async def load_lora(data: LoraLoadRequest):
     """Loads a LoRA into the model container."""
     if not data.loras:
         raise HTTPException(400, "List of loras to load is not found.")
@@ -411,14 +427,27 @@ def load_lora(data: LoraLoadRequest):
     )

-    # Clean-up existing loras if present
-    if len(MODEL_CONTAINER.active_loras) > 0:
-        MODEL_CONTAINER.unload(True)
+    def load_loras_internal():
+        if len(MODEL_CONTAINER.active_loras) > 0:
+            unload_loras()

-    result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
-    return LoraLoadResponse(
-        success=unwrap(result.get("success"), []),
-        failure=unwrap(result.get("failure"), []),
-    )
+        result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
+        return LoraLoadResponse(
+            success=unwrap(result.get("success"), []),
+            failure=unwrap(result.get("failure"), []),
+        )
+
+    internal_callback = partial(run_in_threadpool, load_loras_internal)
+
+    # Determine whether to skip the queue
+    if data.skip_queue:
+        logger.warning(
+            "Lora load request is skipping the completions queue. "
+            "Unexpected results may occur."
+        )
+        return await internal_callback()
+    else:
+        return await call_with_semaphore(internal_callback)


 # Unload lora endpoint
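
The load_lora change combines two patterns: the blocking LoRA load is
offloaded to a worker thread with Starlette's run_in_threadpool, and,
unless skip_queue is set, the whole call waits on the same semaphore as
completions. A standalone sketch of that combination (the semaphore
helper mirrors the assumed wrapper sketched earlier):

    import asyncio
    from functools import partial
    from starlette.concurrency import run_in_threadpool

    semaphore = asyncio.Semaphore(1)

    async def call_with_semaphore(callback):
        # Wait for a permit, then await the wrapped callback
        async with semaphore:
            return await callback()

    def load_loras_blocking():
        # Stand-in for the blocking MODEL_CONTAINER.load_loras(...) call
        return {"success": ["my-lora"], "failure": []}

    async def load_lora(skip_queue: bool = False):
        internal_callback = partial(run_in_threadpool, load_loras_blocking)
        if skip_queue:
            return await internal_callback()  # bypass the queue
        return await call_with_semaphore(internal_callback)

    print(asyncio.run(load_lora()))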
@@ -428,7 +457,7 @@ def load_lora(data: LoraLoadRequest):
 )
 def unload_loras():
     """Unloads the currently loaded loras."""
-    MODEL_CONTAINER.unload(True)
+    MODEL_CONTAINER.unload(loras_only=True)


 # Encode tokens endpoint
@@ -498,7 +527,8 @@ async def generate_completion(request: Request, data: CompletionRequest):
         )

         return StreamingResponse(
-            generate_with_semaphore(generator), media_type="text/event-stream"
+            generate_with_semaphore(generator),
+            media_type="text/event-stream",
         )

     try:
@@ -515,7 +545,8 @@ async def generate_completion(request: Request, data: CompletionRequest):
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Completion aborted. Please check the server console."
+            "Completion aborted. Maybe the model was unloaded? "
+            "Please check the server console."
         ).error.message

         # Server error if there's a generation exception
@@ -617,7 +648,8 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Chat completion aborted. Please check the server console."
+            "Chat completion aborted. Maybe the model was unloaded? "
+            "Please check the server console."
         ).error.message

         # Server error if there's a generation exception
@@ -636,7 +668,6 @@ def start_api(host: str, port: int):
         app,
         host=host,
         port=port,
-        log_level="debug",
     )
@@ -733,15 +764,13 @@ def entrypoint(args: Optional[dict] = None):
     host = unwrap(network_config.get("host"), "127.0.0.1")
     port = unwrap(network_config.get("port"), 5000)

     # Start the API in a daemon thread
     # This allows for command signals to be passed and properly shut down the program
     # Otherwise the program will hang
     # TODO: Replace this with abortables, async via producer consumer, or something else
-    threading.Thread(target=partial(start_api, host, port), daemon=True).start()
+    api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)
+    api_thread.start()

     # Keep the program alive
-    loop = asyncio.get_event_loop()
-    loop.run_forever()
+    while api_thread.is_alive():
+        time.sleep(0.5)


 if __name__ == "__main__":
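
The revised keep-alive at the end of entrypoint replaces an idle asyncio
event loop with a simple poll of the daemon API thread, so the process
exits as soon as the API thread dies. A minimal standalone sketch of the
pattern:

    import threading
    import time

    def serve():
        # Stand-in for start_api/uvicorn.run; when it returns, the thread dies
        time.sleep(2)

    api_thread = threading.Thread(target=serve, daemon=True)
    api_thread.start()

    # Keep the main thread alive only while the API thread is running
    while api_thread.is_alive():
        time.sleep(0.5)

    print("API thread exited; shutting down")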