API: Add more methods to semaphore

The semaphore/queue model for Tabby is as follows:
- All load requests go through the semaphore by default
- Any load request can include the skip_queue parameter to bypass
the semaphore (see the example after this list)
- All unload requests are executed immediately
- All completion requests are placed inside the semaphore by default
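
As an illustration, a client could opt out of the queue like this
(hypothetical example: the URL, admin-key header name, key, and model
name are placeholders, not part of this commit; only the endpoint path
and the skip_queue field come from the change itself):

    # Hypothetical client call: bypass the semaphore on a model load.
    import requests

    response = requests.post(
        "http://127.0.0.1:5000/v1/model/load",
        headers={"x-admin-key": "example-admin-key"},
        json={"name": "example-model", "skip_queue": True},
        stream=True,  # load progress is streamed back as SSE
    )

    for line in response.iter_lines():
        if line:
            print(line.decode())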

This model preserves the parallelism of single-user mode while adding
convenience methods for queueing in multi-user mode. It also helps
mitigate problems previously present in the concurrency stack.
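
For reference, the call_with_semaphore and generate_with_semaphore
helpers used in the diff below can be pictured roughly like this
(a minimal sketch assuming a single-slot asyncio.Semaphore, not the
repo's exact code):

    import asyncio

    generate_semaphore = asyncio.Semaphore(1)

    async def call_with_semaphore(callback):
        # Await a coroutine factory while holding the semaphore
        async with generate_semaphore:
            return await callback()

    async def generate_with_semaphore(generator):
        # Stream an async generator's output while holding the semaphore
        async with generate_semaphore:
            async for result in generator():
                yield result

Completions and queued load requests go through these wrappers; the
skip_queue path simply calls the underlying function or generator
directly.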

Also change how the program's keep-alive loop runs so that it exits
when the API thread dies.
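
In rough terms (sketch only; start_api below is a stand-in for the
real uvicorn entrypoint in main.py):

    import threading
    import time
    from functools import partial

    def start_api(host: str, port: int):
        # Stand-in for uvicorn.run(app, host=host, port=port)
        time.sleep(2)

    api_thread = threading.Thread(
        target=partial(start_api, "127.0.0.1", 5000), daemon=True
    )
    api_thread.start()

    # The process now ends as soon as the API thread dies
    while api_thread.is_alive():
        time.sleep(0.5)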

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri, 2024-03-03 01:22:34 -05:00 (committed by Brian Dashore)
parent c82697fef2
commit b0c295dd2f
5 changed files with 77 additions and 39 deletions


@@ -32,6 +32,7 @@ class LoraLoadRequest(BaseModel):
     """Represents a Lora load request."""

     loras: List[LoraLoadInfo]
+    skip_queue: bool = False


 class LoraLoadResponse(BaseModel):


@@ -93,6 +93,7 @@ class ModelLoadRequest(BaseModel):
     use_cfg: Optional[bool] = None
     fasttensors: Optional[bool] = False
     draft: Optional[DraftModelLoadRequest] = None
+    skip_queue: Optional[bool] = False


 class ModelLoadResponse(BaseModel):


@@ -55,6 +55,7 @@ class ExllamaV2Container:
     autosplit_reserve: List[float] = [96 * 1024**2]

     # Load state
+    model_is_loading: bool = False
     model_loaded: bool = False

     def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):

@@ -350,6 +351,9 @@ class ExllamaV2Container:
                 def progress(loaded_modules: int, total_modules: int)
         """

+        # Notify that the model is being loaded
+        self.model_is_loading = True
+
         # Load tokenizer
         self.tokenizer = ExLlamaV2Tokenizer(self.config)

@@ -439,6 +443,7 @@ class ExllamaV2Container:
         torch.cuda.empty_cache()

         # Update model load state
+        self.model_is_loading = False
         self.model_loaded = True

         logger.info("Model successfully loaded.")

@@ -472,7 +477,7 @@ class ExllamaV2Container:
         # Update model load state
         self.model_loaded = False

-        logger.info("Model unloaded.")
+        logger.info("Loras unloaded." if loras_only else "Model unloaded.")

     def encode_tokens(self, text: str, **kwargs):
         """Wrapper to encode tokens from a text string"""


@@ -45,8 +45,10 @@ def handle_request_error(message: str):
     request_error = TabbyRequestError(error=error_message)

     # Log the error and provided message to the console
-    logger.error(error_message.trace)
-    logger.error(message)
+    if error_message.trace:
+        logger.error(error_message.trace)
+
+    logger.error(f"Sent to request: {message}")

     return request_error

main.py

@@ -1,9 +1,9 @@
 """The main tabbyAPI module. Contains the FastAPI server and endpoints."""
-import asyncio
 import os
 import pathlib
 import signal
 import sys
+import time
 import uvicorn
 import threading
 from asyncio import CancelledError

@@ -93,7 +93,9 @@ MODEL_CONTAINER: Optional[ExllamaV2Container] = None

 def _check_model_container():
-    if MODEL_CONTAINER is None or not MODEL_CONTAINER.model_loaded:
+    if MODEL_CONTAINER is None or not (
+        MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
+    ):
         error_message = handle_request_error(
             "No models are currently loaded."
         ).error.message

@@ -183,24 +185,13 @@ def list_draft_models():

 # Load model endpoint
 @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
-def load_model(request: Request, data: ModelLoadRequest):
+async def load_model(request: Request, data: ModelLoadRequest):
     """Loads a model into the model container."""
-    global MODEL_CONTAINER
-
-    # Verify request parameters
     if not data.name:
         raise HTTPException(400, "A model name was not provided.")

-    # Unload the existing model
-    if MODEL_CONTAINER and MODEL_CONTAINER.model:
-        loaded_model_name = MODEL_CONTAINER.get_model_path().name
-
-        if loaded_model_name == data.name:
-            raise HTTPException(
-                400, f'Model "{loaded_model_name}"is already loaded! Aborting.'
-            )
-        else:
-            MODEL_CONTAINER.unload()
-
     model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models"))
     model_path = model_path / data.name

@@ -219,10 +210,24 @@ def load_model(request: Request, data: ModelLoadRequest):
     if not model_path.exists():
         raise HTTPException(400, "model_path does not exist. Check model_name?")

-    MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
+    # Check if the model is already loaded
+    if MODEL_CONTAINER and MODEL_CONTAINER.model:
+        loaded_model_name = MODEL_CONTAINER.get_model_path().name
+
+        if loaded_model_name == data.name:
+            raise HTTPException(
+                400, f'Model "{loaded_model_name}"is already loaded! Aborting.'
+            )

     async def generator():
         """Generator for the loading process."""
+        global MODEL_CONTAINER
+
+        # Unload the existing model
+        if MODEL_CONTAINER and MODEL_CONTAINER.model:
+            unload_model()
+
+        MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)

         model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
         load_status = MODEL_CONTAINER.load_gen(load_progress)

@@ -230,6 +235,7 @@ def load_model(request: Request, data: ModelLoadRequest):
         try:
             for module, modules in load_status:
                 if await request.is_disconnected():
+                    unload_model()
                     break

                 if module == 0:

@@ -269,7 +275,17 @@ def load_model(request: Request, data: ModelLoadRequest):
         except Exception as exc:
             yield get_generator_error(str(exc))

-    return StreamingResponse(generator(), media_type="text/event-stream")
+    # Determine whether to use or skip the queue
+    if data.skip_queue:
+        logger.warning(
+            "Model load request is skipping the completions queue. "
+            "Unexpected results may occur."
+        )
+
+        generator_callback = generator
+    else:
+        generator_callback = partial(generate_with_semaphore, generator)
+
+    return StreamingResponse(generator_callback(), media_type="text/event-stream")

 # Unload model endpoint

@@ -398,7 +414,7 @@ def get_active_loras():
     "/v1/lora/load",
     dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
 )
-def load_lora(data: LoraLoadRequest):
+async def load_lora(data: LoraLoadRequest):
     """Loads a LoRA into the model container."""
     if not data.loras:
         raise HTTPException(400, "List of loras to load is not found.")

@@ -411,14 +427,27 @@ def load_lora(data: LoraLoadRequest):
     )

-    # Clean-up existing loras if present
-    if len(MODEL_CONTAINER.active_loras) > 0:
-        MODEL_CONTAINER.unload(True)
-
-    result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
-    return LoraLoadResponse(
-        success=unwrap(result.get("success"), []),
-        failure=unwrap(result.get("failure"), []),
-    )
+    def load_loras_internal():
+        # Clean-up existing loras if present
+        if len(MODEL_CONTAINER.active_loras) > 0:
+            unload_loras()
+
+        result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
+        return LoraLoadResponse(
+            success=unwrap(result.get("success"), []),
+            failure=unwrap(result.get("failure"), []),
+        )
+
+    internal_callback = partial(run_in_threadpool, load_loras_internal)
+
+    # Determine whether to skip the queue
+    if data.skip_queue:
+        logger.warning(
+            "Lora load request is skipping the completions queue. "
+            "Unexpected results may occur."
+        )
+
+        return await internal_callback()
+    else:
+        return await call_with_semaphore(internal_callback)

 # Unload lora endpoint

@@ -428,7 +457,7 @@ def load_lora(data: LoraLoadRequest):
 )
 def unload_loras():
     """Unloads the currently loaded loras."""
-    MODEL_CONTAINER.unload(True)
+    MODEL_CONTAINER.unload(loras_only=True)

 # Encode tokens endpoint

@@ -498,7 +527,8 @@ async def generate_completion(request: Request, data: CompletionRequest):
         )

         return StreamingResponse(
-            generate_with_semaphore(generator), media_type="text/event-stream"
+            generate_with_semaphore(generator),
+            media_type="text/event-stream",
         )

     try:

@@ -515,7 +545,8 @@ async def generate_completion(request: Request, data: CompletionRequest):
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Completion aborted. Please check the server console."
+            "Completion aborted. Maybe the model was unloaded? "
+            "Please check the server console."
         ).error.message

         # Server error if there's a generation exception

@@ -617,7 +648,8 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Chat completion aborted. Please check the server console."
+            "Chat completion aborted. Maybe the model was unloaded? "
+            "Please check the server console."
        ).error.message

         # Server error if there's a generation exception

@@ -636,7 +668,6 @@ def start_api(host: str, port: int):
         app,
         host=host,
         port=port,
-        log_level="debug",
     )

@@ -733,15 +764,13 @@ def entrypoint(args: Optional[dict] = None):
     host = unwrap(network_config.get("host"), "127.0.0.1")
     port = unwrap(network_config.get("port"), 5000)

-    # Start the API in a daemon thread
-    # This allows for command signals to be passed and properly shut down the program
-    # Otherwise the program will hang
     # TODO: Replace this with abortables, async via producer consumer, or something else
-    threading.Thread(target=partial(start_api, host, port), daemon=True).start()
+    api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)
+    api_thread.start()

     # Keep the program alive
-    loop = asyncio.get_event_loop()
-    loop.run_forever()
+    while api_thread.is_alive():
+        time.sleep(0.5)

 if __name__ == "__main__":