API: Add more methods to semaphore
The semaphore/queue model for Tabby is as follows:

- Any load request goes through the semaphore by default.
- A load request can include the skip_queue parameter to bypass the semaphore.
- Unload requests are executed immediately.
- All completion requests are placed inside the semaphore by default.

This model preserves the parallelism of single-user mode while adding convenience methods for queueing in multi-user setups. It also helps mitigate problems that were previously present in the concurrency stack.

Also change how the program's main loop runs so that it exits when the API thread dies.

Signed-off-by: kingbri <bdashore3@proton.me>
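Note (not part of the diff): the endpoints below route work through call_with_semaphore and generate_with_semaphore, which are not defined in this commit. As a minimal sketch only, assuming a single shared asyncio.Semaphore, such wrappers are commonly written like this:

```python
import asyncio

# Assumed shared semaphore; the actual limit used by Tabby is not shown here.
generate_semaphore = asyncio.Semaphore(1)


async def call_with_semaphore(callback):
    """Await a zero-argument callback once a semaphore slot is free."""
    async with generate_semaphore:
        return await callback()


async def generate_with_semaphore(generator):
    """Stream results from an async generator while holding the semaphore."""
    async with generate_semaphore:
        async for item in generator():
            yield item
```

A request with skip_queue set calls its generator or callback directly instead of going through these wrappers, which is why the endpoints below log a warning when the queue is bypassed.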
parent c82697fef2
commit b0c295dd2f
5 changed files with 77 additions and 39 deletions
```diff
@@ -32,6 +32,7 @@ class LoraLoadRequest(BaseModel):
     """Represents a Lora load request."""

     loras: List[LoraLoadInfo]
+    skip_queue: bool = False


 class LoraLoadResponse(BaseModel):
```
```diff
@@ -93,6 +93,7 @@ class ModelLoadRequest(BaseModel):
     use_cfg: Optional[bool] = None
     fasttensors: Optional[bool] = False
     draft: Optional[DraftModelLoadRequest] = None
+    skip_queue: Optional[bool] = False


 class ModelLoadResponse(BaseModel):
```
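Note (not part of the diff): as a usage illustration, a client can set skip_queue on a load request. The host, port, endpoint path, and skip_queue field come from the diff; the admin-key header name and the model name are placeholders assumed for the example.

```python
import requests

response = requests.post(
    "http://127.0.0.1:5000/v1/model/load",
    headers={"x-admin-key": "example-admin-key"},  # assumed header name
    json={
        "name": "my-model",       # placeholder model folder name
        "skip_queue": True,       # bypass the completions semaphore
    },
    stream=True,                  # the endpoint streams load progress as SSE
)

for line in response.iter_lines():
    if line:
        print(line.decode())
```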
```diff
@@ -55,6 +55,7 @@ class ExllamaV2Container:
     autosplit_reserve: List[float] = [96 * 1024**2]

     # Load state
+    model_is_loading: bool = False
     model_loaded: bool = False

     def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
@@ -350,6 +351,9 @@ class ExllamaV2Container:
         def progress(loaded_modules: int, total_modules: int)
         """

+        # Notify that the model is being loaded
+        self.model_is_loading = True
+
         # Load tokenizer
         self.tokenizer = ExLlamaV2Tokenizer(self.config)

@@ -439,6 +443,7 @@ class ExllamaV2Container:
         torch.cuda.empty_cache()

         # Update model load state
+        self.model_is_loading = False
         self.model_loaded = True
         logger.info("Model successfully loaded.")

@@ -472,7 +477,7 @@ class ExllamaV2Container:

         # Update model load state
         self.model_loaded = False
-        logger.info("Model unloaded.")
+        logger.info("Loras unloaded." if loras_only else "Model unloaded.")

     def encode_tokens(self, text: str, **kwargs):
         """Wrapper to encode tokens from a text string"""
```
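Note (not part of the diff): a toy sketch of the intent behind the two flags, not the container's real code. The loader flips model_is_loading on before streaming progress and only marks model_loaded once everything is in place, so endpoint guards can accept requests that arrive while a load is still in flight.

```python
class LoadState:
    """Toy stand-in for the container's load-state flags."""

    def __init__(self):
        self.model_is_loading = False
        self.model_loaded = False

    def load_gen(self, total_modules: int = 3):
        self.model_is_loading = True
        for module in range(1, total_modules + 1):
            yield module, total_modules  # progress as (loaded, total)
        self.model_is_loading = False
        self.model_loaded = True


state = LoadState()
for loaded, total in state.load_gen():
    # A queued request arriving here sees model_is_loading == True
    print(f"{loaded}/{total} modules loaded")
assert state.model_loaded
```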
```diff
@@ -45,8 +45,10 @@ def handle_request_error(message: str):
     request_error = TabbyRequestError(error=error_message)

     # Log the error and provided message to the console
-    logger.error(error_message.trace)
-    logger.error(message)
+    if error_message.trace:
+        logger.error(error_message.trace)
+
+    logger.error(f"Sent to request: {message}")

     return request_error

```
main.py (101 changed lines)
```diff
@@ -1,9 +1,9 @@
 """The main tabbyAPI module. Contains the FastAPI server and endpoints."""
-import asyncio
 import os
 import pathlib
 import signal
 import sys
+import time
 import uvicorn
 import threading
 from asyncio import CancelledError
```
```diff
@@ -93,7 +93,9 @@ MODEL_CONTAINER: Optional[ExllamaV2Container] = None


 def _check_model_container():
-    if MODEL_CONTAINER is None or not MODEL_CONTAINER.model_loaded:
+    if MODEL_CONTAINER is None or not (
+        MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
+    ):
         error_message = handle_request_error(
             "No models are currently loaded."
         ).error.message
```
```diff
@@ -183,24 +185,13 @@ def list_draft_models():

 # Load model endpoint
 @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
-def load_model(request: Request, data: ModelLoadRequest):
+async def load_model(request: Request, data: ModelLoadRequest):
     """Loads a model into the model container."""
-    global MODEL_CONTAINER

+    # Verify request parameters
     if not data.name:
         raise HTTPException(400, "A model name was not provided.")

-    # Unload the existing model
-    if MODEL_CONTAINER and MODEL_CONTAINER.model:
-        loaded_model_name = MODEL_CONTAINER.get_model_path().name
-
-        if loaded_model_name == data.name:
-            raise HTTPException(
-                400, f'Model "{loaded_model_name}"is already loaded! Aborting.'
-            )
-        else:
-            MODEL_CONTAINER.unload()
-
     model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models"))
     model_path = model_path / data.name

```
```diff
@@ -219,10 +210,24 @@ def load_model(request: Request, data: ModelLoadRequest):
     if not model_path.exists():
         raise HTTPException(400, "model_path does not exist. Check model_name?")

-    MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
+    # Check if the model is already loaded
+    if MODEL_CONTAINER and MODEL_CONTAINER.model:
+        loaded_model_name = MODEL_CONTAINER.get_model_path().name
+
+        if loaded_model_name == data.name:
+            raise HTTPException(
+                400, f'Model "{loaded_model_name}"is already loaded! Aborting.'
+            )

     async def generator():
         """Generator for the loading process."""
+        global MODEL_CONTAINER
+
+        # Unload the existing model
+        if MODEL_CONTAINER and MODEL_CONTAINER.model:
+            unload_model()
+
+        MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)

         model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
         load_status = MODEL_CONTAINER.load_gen(load_progress)
```
```diff
@@ -230,6 +235,7 @@ def load_model(request: Request, data: ModelLoadRequest):
         try:
             for module, modules in load_status:
                 if await request.is_disconnected():
+                    unload_model()
                     break

                 if module == 0:
```
```diff
@@ -269,7 +275,17 @@ def load_model(request: Request, data: ModelLoadRequest):
         except Exception as exc:
             yield get_generator_error(str(exc))

-    return StreamingResponse(generator(), media_type="text/event-stream")
+    # Determine whether to use or skip the queue
+    if data.skip_queue:
+        logger.warning(
+            "Model load request is skipping the completions queue. "
+            "Unexpected results may occur."
+        )
+        generator_callback = generator
+    else:
+        generator_callback = partial(generate_with_semaphore, generator)
+
+    return StreamingResponse(generator_callback(), media_type="text/event-stream")


 # Unload model endpoint
```
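Note (not part of the diff): the branch above only selects which async generator factory is handed to StreamingResponse; calling the partial is equivalent to wrapping the generator directly. A small standalone sketch with stand-in names:

```python
import inspect
from functools import partial


async def generator():
    """Stand-in for the load-progress generator used above."""
    yield "data: loading\n\n"


async def generate_with_semaphore(gen):
    """Stand-in for the semaphore wrapper sketched near the top."""
    async for item in gen():
        yield item


# Both branches produce a zero-argument callable that StreamingResponse can call:
queued = partial(generate_with_semaphore, generator)  # normal, queued path
unqueued = generator                                  # skip_queue path

assert inspect.isasyncgen(queued()) and inspect.isasyncgen(unqueued())
```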
```diff
@@ -398,7 +414,7 @@ def get_active_loras():
     "/v1/lora/load",
     dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
 )
-def load_lora(data: LoraLoadRequest):
+async def load_lora(data: LoraLoadRequest):
     """Loads a LoRA into the model container."""
     if not data.loras:
         raise HTTPException(400, "List of loras to load is not found.")
```
```diff
@@ -411,14 +427,27 @@ def load_lora(data: LoraLoadRequest):
     )

     # Clean-up existing loras if present
-    if len(MODEL_CONTAINER.active_loras) > 0:
-        MODEL_CONTAINER.unload(True)
+    def load_loras_internal():
+        if len(MODEL_CONTAINER.active_loras) > 0:
+            unload_loras()

-    result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
-    return LoraLoadResponse(
-        success=unwrap(result.get("success"), []),
-        failure=unwrap(result.get("failure"), []),
-    )
+        result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
+        return LoraLoadResponse(
+            success=unwrap(result.get("success"), []),
+            failure=unwrap(result.get("failure"), []),
+        )
+
+    internal_callback = partial(run_in_threadpool, load_loras_internal)
+
+    # Determine whether to skip the queue
+    if data.skip_queue:
+        logger.warning(
+            "Lora load request is skipping the completions queue. "
+            "Unexpected results may occur."
+        )
+        return await internal_callback()
+    else:
+        return await call_with_semaphore(internal_callback)


 # Unload lora endpoint
```
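Note (not part of the diff): run_in_threadpool comes from Starlette (also re-exported by FastAPI) and turns a blocking call into an awaitable by running it on a worker thread, which is what lets the synchronous LoRA load sit behind the async semaphore. A minimal sketch of the pattern with placeholder work:

```python
import asyncio
from functools import partial

from starlette.concurrency import run_in_threadpool


def load_loras_blocking():
    """Placeholder for the blocking LoRA load work."""
    return {"success": ["example-lora"], "failure": []}


async def main():
    # partial() packages the sync call as a zero-argument awaitable factory,
    # so it can be awaited directly or handed to call_with_semaphore.
    internal_callback = partial(run_in_threadpool, load_loras_blocking)
    result = await internal_callback()
    print(result)


asyncio.run(main())
```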
```diff
@@ -428,7 +457,7 @@ def load_lora(data: LoraLoadRequest):
 )
 def unload_loras():
     """Unloads the currently loaded loras."""
-    MODEL_CONTAINER.unload(True)
+    MODEL_CONTAINER.unload(loras_only=True)


 # Encode tokens endpoint
```
```diff
@@ -498,7 +527,8 @@ async def generate_completion(request: Request, data: CompletionRequest):
         )

         return StreamingResponse(
-            generate_with_semaphore(generator), media_type="text/event-stream"
+            generate_with_semaphore(generator),
+            media_type="text/event-stream",
         )

     try:
```
```diff
@@ -515,7 +545,8 @@ async def generate_completion(request: Request, data: CompletionRequest):
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Completion aborted. Please check the server console."
+            "Completion aborted. Maybe the model was unloaded? "
+            "Please check the server console."
         ).error.message

         # Server error if there's a generation exception
```
```diff
@@ -617,7 +648,8 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Chat completion aborted. Please check the server console."
+            "Chat completion aborted. Maybe the model was unloaded? "
+            "Please check the server console."
         ).error.message

         # Server error if there's a generation exception
```
```diff
@@ -636,7 +668,6 @@ def start_api(host: str, port: int):
         app,
         host=host,
         port=port,
-        log_level="debug",
     )


```
```diff
@@ -733,15 +764,13 @@ def entrypoint(args: Optional[dict] = None):
     host = unwrap(network_config.get("host"), "127.0.0.1")
     port = unwrap(network_config.get("port"), 5000)

     # Start the API in a daemon thread
-    # This allows for command signals to be passed and properly shut down the program
-    # Otherwise the program will hang
-    # TODO: Replace this with abortables, async via producer consumer, or something else
-    threading.Thread(target=partial(start_api, host, port), daemon=True).start()
+    api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)

+    api_thread.start()
     # Keep the program alive
-    loop = asyncio.get_event_loop()
-    loop.run_forever()
+    while api_thread.is_alive():
+        time.sleep(0.5)


 if __name__ == "__main__":
```
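Note (not part of the diff): the new keep-alive loop simply polls the daemon thread, so the process exits when the API thread dies or when a signal interrupts the sleep. The pattern in isolation, with a dummy worker in place of the real start_api:

```python
import threading
import time
from functools import partial


def start_api(host: str, port: int):
    """Dummy stand-in for the real uvicorn server thread."""
    time.sleep(2)  # pretend to serve, then die


host, port = "127.0.0.1", 5000
api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)
api_thread.start()

# Keep the main thread alive only as long as the API thread is running;
# a KeyboardInterrupt also breaks out because sleep() is interruptible.
while api_thread.is_alive():
    time.sleep(0.5)

print("API thread exited, shutting down.")
```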