"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""

import asyncio
import os
import pathlib
import signal
import sys
import threading
import time
from functools import partial
from loguru import logger
from typing import Optional

from backends.exllamav2.utils import check_exllama_version
from common import config, gen_logging, sampling, model
from common.args import convert_args_to_dict, init_argparser
from common.auth import load_auth_keys
from common.logger import setup_logger
from common.utils import is_port_in_use, unwrap
from endpoints.OAI.app import start_api

def signal_handler(*_):
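    # Registered for SIGINT/SIGTERM in entrypoint(). sys.exit() raises
    # SystemExit in the main thread; the API server runs in a daemon thread,
    # so it is torn down when the main thread exits.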
    logger.warning("Shutdown signal called. Exiting gracefully.")
    sys.exit(0)


async def entrypoint(args: Optional[dict] = None):
    """Entry function for program startup"""

    setup_logger()

    # Set up signal aborting
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # Load from YAML config
    config.from_file(pathlib.Path("config.yml"))

    # Parse and override config from args
    if args is None:
        parser = init_argparser()
        args = convert_args_to_dict(parser.parse_args(), parser)

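    # A caller can also pass a ready-made args dict to entrypoint() directly
    # (bypassing argparse); either way, the values are applied on top of the
    # config.yml settings loaded above.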
    config.from_args(args)

    developer_config = config.developer_config()

    # Check exllamav2 version and give a descriptive error if it's too old
    # Skip if launching unsafely
    if unwrap(developer_config.get("unsafe_launch"), False):
        logger.warning(
            "UNSAFE: Skipping ExllamaV2 version check.\n"
            "If you aren't a developer, please keep this off!"
        )
    else:
        check_exllama_version()

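    # PYTORCH_CUDA_ALLOC_CONF is read when torch first initializes CUDA, so
    # the env var below must be set before any model or GPU work happens.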
    # Enable CUDA malloc backend
    if unwrap(developer_config.get("cuda_malloc_backend"), False):
        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"
        logger.warning("Enabled the experimental CUDA malloc backend.")

    network_config = config.network_config()

    host = unwrap(network_config.get("host"), "127.0.0.1")
    port = unwrap(network_config.get("port"), 5000)

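    # The lookups above (and disable_auth further down) map to a config.yml
    # section shaped roughly like this, a sketch based only on the keys read
    # here, not the full schema:
    #
    #   network:
    #     host: 127.0.0.1
    #     port: 5000
    #     disable_auth: false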
    # Check if the port is available and attempt to bind a fallback
    if is_port_in_use(port):
        fallback_port = port + 1

        if is_port_in_use(fallback_port):
            logger.error(
                f"Ports {port} and {fallback_port} are in use by different services.\n"
                "Please free up those ports or specify a different one.\n"
                "Exiting."
            )

            return
        else:
            logger.warning(
                f"Port {port} is currently in use. Switching to {fallback_port}."
            )

            port = fallback_port

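    # load_auth_keys loads (or generates) the API keys used to authenticate
    # requests; passing disable_auth=True skips key checks entirely, which is
    # only sensible for trusted local deployments.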
    # Initialize auth keys
    load_auth_keys(unwrap(network_config.get("disable_auth"), False))

    # Override the generation log options if given
    log_config = config.gen_logging_config()
    if log_config:
        gen_logging.update_from_dict(log_config)

    gen_logging.broadcast_status()

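    # The sampler override preset below is optional: if the named file is
    # missing, startup logs a warning and continues with default samplers.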
    # Set sampler parameter overrides if provided
    sampling_config = config.sampling_config()
    sampling_override_preset = sampling_config.get("override_preset")
    if sampling_override_preset:
        try:
            sampling.overrides_from_file(sampling_override_preset)
        except FileNotFoundError as e:
            logger.warning(str(e))

    # If an initial model name is specified, create a container
    # and load the model
    model_config = config.model_config()
    model_name = model_config.get("model_name")
    if model_name:
        model_path = pathlib.Path(unwrap(model_config.get("model_dir"), "models"))
        model_path = model_path / model_name

        await model.load_model(model_path.resolve(), **model_config)

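        # Lora loading needs the container that load_model just created, so it
        # stays nested under the model_name check.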
        # Load loras after loading the model
        lora_config = config.lora_config()
        if lora_config.get("loras"):
            lora_dir = pathlib.Path(unwrap(lora_config.get("lora_dir"), "loras"))
            model.container.load_loras(lora_dir.resolve(), **lora_config)

    # TODO: Replace this with abortables, async via producer consumer, or something else
    api_thread = threading.Thread(target=partial(start_api, host, port), daemon=True)
    api_thread.start()

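    # Polling with a short sleep (rather than api_thread.join()) keeps the
    # main thread responsive to shutdown signals, notably on Windows where a
    # blocking join can't be interrupted by Ctrl+C.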
    # Keep the program alive
    while api_thread.is_alive():
        time.sleep(0.5)


if __name__ == "__main__":
    asyncio.run(entrypoint())