API: Add more methods to semaphore
The semaphore/queue model for Tabby is as follows:

- Any load request goes through the semaphore by default
- Any load request can include the skip_queue parameter to bypass the semaphore
- Any unload request is executed immediately
- All completion requests are placed inside the semaphore by default

This model preserves the parallelism of single-user mode while adding convenience methods for queueing in multi-user mode. It also helps mitigate problems that were previously present in the concurrency stack.

Also change how the program's main loop runs so it exits when the API thread dies.

Signed-off-by: kingbri <bdashore3@proton.me>
Parent: c82697fef2
Commit: b0c295dd2f

5 changed files with 77 additions and 39 deletions
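For context, the endpoints in the diff below call two semaphore helpers, call_with_semaphore and generate_with_semaphore, whose definitions are not part of this commit. A minimal sketch of how such helpers could look, assuming a single asyncio.Semaphore guards generation (the semaphore count and module placement are assumptions):

# Sketch only; not code from this commit.
import asyncio

generate_semaphore = asyncio.Semaphore(1)  # one request generates at a time

async def call_with_semaphore(callback):
    """Await an async callback while holding the semaphore."""
    async with generate_semaphore:
        return await callback()

async def generate_with_semaphore(generator):
    """Drain an async generator function while holding the semaphore."""
    async with generate_semaphore:
        async for result in generator():
            yield result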
@@ -32,6 +32,7 @@ class LoraLoadRequest(BaseModel):
     """Represents a Lora load request."""
 
     loras: List[LoraLoadInfo]
+    skip_queue: bool = False
 
 
 class LoraLoadResponse(BaseModel):
@@ -93,6 +93,7 @@ class ModelLoadRequest(BaseModel):
     use_cfg: Optional[bool] = None
     fasttensors: Optional[bool] = False
     draft: Optional[DraftModelLoadRequest] = None
+    skip_queue: Optional[bool] = False
 
 
 class ModelLoadResponse(BaseModel):
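To illustrate the new field, a client could bypass the load queue like this. This example is not part of the commit: the model name is a placeholder, authentication headers are omitted, and the host/port reflect the defaults seen later in the diff.

# Hypothetical client call; "my-model" is a placeholder, auth headers omitted.
import requests

response = requests.post(
    "http://127.0.0.1:5000/v1/model/load",
    json={"name": "my-model", "skip_queue": True},
    stream=True,  # the endpoint streams load progress as server-sent events
)
for line in response.iter_lines():
    if line:
        print(line.decode())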
@@ -55,6 +55,7 @@ class ExllamaV2Container:
     autosplit_reserve: List[float] = [96 * 1024**2]
 
     # Load state
+    model_is_loading: bool = False
     model_loaded: bool = False
 
     def __init__(self, model_directory: pathlib.Path, quiet=False, **kwargs):
@@ -350,6 +351,9 @@ class ExllamaV2Container:
             def progress(loaded_modules: int, total_modules: int)
         """
 
+        # Notify that the model is being loaded
+        self.model_is_loading = True
+
         # Load tokenizer
         self.tokenizer = ExLlamaV2Tokenizer(self.config)
 
@@ -439,6 +443,7 @@ class ExllamaV2Container:
         torch.cuda.empty_cache()
 
         # Update model load state
+        self.model_is_loading = False
         self.model_loaded = True
         logger.info("Model successfully loaded.")
 
@@ -472,7 +477,7 @@ class ExllamaV2Container:
 
         # Update model load state
         self.model_loaded = False
-        logger.info("Model unloaded.")
+        logger.info("Loras unloaded." if loras_only else "Model unloaded.")
 
     def encode_tokens(self, text: str, **kwargs):
         """Wrapper to encode tokens from a text string"""
@@ -45,8 +45,10 @@ def handle_request_error(message: str):
     request_error = TabbyRequestError(error=error_message)
 
     # Log the error and provided message to the console
-    logger.error(error_message.trace)
-    logger.error(message)
+    if error_message.trace:
+        logger.error(error_message.trace)
+
+    logger.error(f"Sent to request: {message}")
 
     return request_error
 
main.py (101 changed lines)
@@ -1,9 +1,9 @@
 """The main tabbyAPI module. Contains the FastAPI server and endpoints."""
-import asyncio
 import os
 import pathlib
 import signal
 import sys
+import time
 import uvicorn
 import threading
 from asyncio import CancelledError
@@ -93,7 +93,9 @@ MODEL_CONTAINER: Optional[ExllamaV2Container] = None
 
 
 def _check_model_container():
-    if MODEL_CONTAINER is None or not MODEL_CONTAINER.model_loaded:
+    if MODEL_CONTAINER is None or not (
+        MODEL_CONTAINER.model_is_loading or MODEL_CONTAINER.model_loaded
+    ):
         error_message = handle_request_error(
             "No models are currently loaded."
         ).error.message
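With the new flag, the dependency passes while a model is still loading, so requests arriving mid-load queue behind the load in the semaphore instead of erroring out. A tiny truth-table illustration, not code from the commit:

# Not part of the commit; just enumerates the new check's outcomes.
for model_is_loading, model_loaded in [(False, False), (True, False), (False, True)]:
    passes = model_is_loading or model_loaded
    print(f"is_loading={model_is_loading}, loaded={model_loaded} -> container check passes: {passes}")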
@@ -183,24 +185,13 @@ def list_draft_models():
 
 # Load model endpoint
 @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
-def load_model(request: Request, data: ModelLoadRequest):
+async def load_model(request: Request, data: ModelLoadRequest):
     """Loads a model into the model container."""
-    global MODEL_CONTAINER
 
+    # Verify request parameters
     if not data.name:
         raise HTTPException(400, "A model name was not provided.")
 
-    # Unload the existing model
-    if MODEL_CONTAINER and MODEL_CONTAINER.model:
-        loaded_model_name = MODEL_CONTAINER.get_model_path().name
-
-        if loaded_model_name == data.name:
-            raise HTTPException(
-                400, f'Model "{loaded_model_name}"is already loaded! Aborting.'
-            )
-        else:
-            MODEL_CONTAINER.unload()
-
     model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models"))
     model_path = model_path / data.name
 
@@ -219,10 +210,24 @@ def load_model(request: Request, data: ModelLoadRequest):
     if not model_path.exists():
         raise HTTPException(400, "model_path does not exist. Check model_name?")
 
-    MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
+    # Check if the model is already loaded
+    if MODEL_CONTAINER and MODEL_CONTAINER.model:
+        loaded_model_name = MODEL_CONTAINER.get_model_path().name
+
+        if loaded_model_name == data.name:
+            raise HTTPException(
+                400, f'Model "{loaded_model_name}"is already loaded! Aborting.'
+            )
 
     async def generator():
         """Generator for the loading process."""
+        global MODEL_CONTAINER
+
+        # Unload the existing model
+        if MODEL_CONTAINER and MODEL_CONTAINER.model:
+            unload_model()
+
+        MODEL_CONTAINER = ExllamaV2Container(model_path.resolve(), False, **load_data)
 
         model_type = "draft" if MODEL_CONTAINER.draft_config else "model"
         load_status = MODEL_CONTAINER.load_gen(load_progress)
@@ -230,6 +235,7 @@ def load_model(request: Request, data: ModelLoadRequest):
         try:
             for module, modules in load_status:
                 if await request.is_disconnected():
+                    unload_model()
                     break
 
                 if module == 0:
@@ -269,7 +275,17 @@ def load_model(request: Request, data: ModelLoadRequest):
         except Exception as exc:
             yield get_generator_error(str(exc))
 
-    return StreamingResponse(generator(), media_type="text/event-stream")
+    # Determine whether to use or skip the queue
+    if data.skip_queue:
+        logger.warning(
+            "Model load request is skipping the completions queue. "
+            "Unexpected results may occur."
+        )
+        generator_callback = generator
+    else:
+        generator_callback = partial(generate_with_semaphore, generator)
+
+    return StreamingResponse(generator_callback(), media_type="text/event-stream")
 
 
 # Unload model endpoint
@@ -398,7 +414,7 @@ def get_active_loras():
     "/v1/lora/load",
     dependencies=[Depends(check_admin_key), Depends(_check_model_container)],
 )
-def load_lora(data: LoraLoadRequest):
+async def load_lora(data: LoraLoadRequest):
     """Loads a LoRA into the model container."""
     if not data.loras:
         raise HTTPException(400, "List of loras to load is not found.")
@@ -411,14 +427,27 @@ def load_lora(data: LoraLoadRequest):
     )
 
     # Clean-up existing loras if present
-    if len(MODEL_CONTAINER.active_loras) > 0:
-        MODEL_CONTAINER.unload(True)
+    def load_loras_internal():
+        if len(MODEL_CONTAINER.active_loras) > 0:
+            unload_loras()
 
-    result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
-    return LoraLoadResponse(
-        success=unwrap(result.get("success"), []),
-        failure=unwrap(result.get("failure"), []),
-    )
+        result = MODEL_CONTAINER.load_loras(lora_dir, **data.model_dump())
+        return LoraLoadResponse(
+            success=unwrap(result.get("success"), []),
+            failure=unwrap(result.get("failure"), []),
+        )
+
+    internal_callback = partial(run_in_threadpool, load_loras_internal)
+
+    # Determine whether to skip the queue
+    if data.skip_queue:
+        logger.warning(
+            "Lora load request is skipping the completions queue. "
+            "Unexpected results may occur."
+        )
+        return await internal_callback()
+    else:
+        return await call_with_semaphore(internal_callback)
 
 
 # Unload lora endpoint
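The lora path wraps the blocking load in Starlette's run_in_threadpool so the event loop stays responsive while the request sits in (or skips) the queue. A standalone sketch of the same pattern; blocking_work is a made-up stand-in for the real lora-loading call:

# Illustrative only; blocking_work stands in for MODEL_CONTAINER.load_loras(...).
import asyncio
from functools import partial
from starlette.concurrency import run_in_threadpool

def blocking_work():
    return {"success": ["example-lora"], "failure": []}

async def main():
    internal_callback = partial(run_in_threadpool, blocking_work)
    # Awaiting the partial runs blocking_work on a worker thread,
    # so other coroutines keep running in the meantime.
    print(await internal_callback())

asyncio.run(main())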
@@ -428,7 +457,7 @@ def load_lora(data: LoraLoadRequest):
 )
 def unload_loras():
     """Unloads the currently loaded loras."""
-    MODEL_CONTAINER.unload(True)
+    MODEL_CONTAINER.unload(loras_only=True)
 
 
 # Encode tokens endpoint
@@ -498,7 +527,8 @@ async def generate_completion(request: Request, data: CompletionRequest):
         )
 
         return StreamingResponse(
-            generate_with_semaphore(generator), media_type="text/event-stream"
+            generate_with_semaphore(generator),
+            media_type="text/event-stream",
         )
 
     try:
@@ -515,7 +545,8 @@ async def generate_completion(request: Request, data: CompletionRequest):
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Completion aborted. Please check the server console."
+            "Completion aborted. Maybe the model was unloaded? "
+            "Please check the server console."
         ).error.message
 
         # Server error if there's a generation exception
@@ -617,7 +648,8 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Chat completion aborted. Maybe the model was unloaded? "
+            "Chat completion aborted. Maybe the model was unloaded? "
+            "Please check the server console."
         ).error.message
 
         # Server error if there's a generation exception
@@ -636,7 +668,6 @@ def start_api(host: str, port: int):
         app,
         host=host,
         port=port,
-        log_level="debug",
     )
 
 
@@ -733,15 +764,13 @@ def entrypoint(args: Optional[dict] = None):
     host = unwrap(network_config.get("host"), "127.0.0.1")
     port = unwrap(network_config.get("port"), 5000)
 
-    # Start the API in a daemon thread
-    # This allows for command signals to be passed and properly shut down the program
-    # Otherwise the program will hang
     # TODO: Replace this with abortables, async via producer consumer, or something else
-    threading.Thread(target=partial(start_api, host, port), daemon=True).start()
-
+    api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)
+
+    api_thread.start()
     # Keep the program alive
-    loop = asyncio.get_event_loop()
-    loop.run_forever()
+    while api_thread.is_alive():
+        time.sleep(0.5)
 
 
 if __name__ == "__main__":
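The effect of the new keep-alive loop is that the process ends as soon as the daemon API thread stops, which is what the commit message refers to. A toy reproduction, not part of the commit; fake_api and its two-second sleep stand in for start_api(host, port) exiting:

# Toy illustration only.
import threading
import time

def fake_api():
    time.sleep(2)  # pretend the server runs for a while, then exits or crashes

api_thread = threading.Thread(target=fake_api, daemon=True)
api_thread.start()

# Polling the thread lets the main program exit once the API thread dies,
# instead of blocking forever on an asyncio event loop.
while api_thread.is_alive():
    time.sleep(0.5)

print("API thread exited; shutting down")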