tabbyAPI-ollama/main.py
Commit 30a3cd75cf by kingbri: "Start: Migrate options from cu121/118 to cu12"
This encapsulates more cuda versions and makes install easier for
new users.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
2025-08-19 22:56:58 -04:00
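
For context: the commit collapses the per-CUDA-version extras into a single
cu12 extra. A sketch of the install commands implied by the commit title (the
old extra names are taken from the title itself; the new one matches the
install hint in the error message later in this file):

    # Before: one extra per CUDA version
    pip install --upgrade .[cu121]
    pip install --upgrade .[cu118]

    # After: a single extra covering CUDA 12.x
    pip install --upgrade .[cu12]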

"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""
# Set this env var for cuda malloc async before torch is initalized
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "backend:cudaMallocAsync"

import argparse
import asyncio
import pathlib
import platform
import signal
from loguru import logger
from typing import Optional

from common import gen_logging, sampling, model
from common.args import convert_args_to_dict, init_argparser
from common.auth import load_auth_keys
from common.actions import run_subcommand
from common.logger import setup_logger
from common.networking import is_port_in_use
from common.optional_dependencies import dependencies
from common.signals import signal_handler
from common.tabby_config import config
from endpoints.server import start_api


async def entrypoint_async():
    """Async entry function for program startup"""

    host = config.network.host
    port = config.network.port

    # Check if the port is available and attempt to bind a fallback
    if is_port_in_use(port):
        fallback_port = port + 1
        if is_port_in_use(fallback_port):
            logger.error(
                f"Ports {port} and {fallback_port} are in use by different services.\n"
                "Please free up those ports or specify a different one.\n"
                "Exiting."
            )
            return
        else:
            logger.warning(
                f"Port {port} is currently in use. Switching to {fallback_port}."
            )
            port = fallback_port

    # If an initial model name is specified, create a container
    # and load the model
    model_name = config.model.model_name
    if model_name:
        model_path = pathlib.Path(config.model.model_dir)
        model_path = model_path / model_name

        # TODO: remove model_dump()
        await model.load_model(
            model_path.resolve(),
            **config.model.model_dump(exclude_none=True),
            draft_model=config.draft_model.model_dump(exclude_none=True),
        )

        # Load loras after loading the model
        if config.lora.loras:
            lora_dir = pathlib.Path(config.lora.lora_dir)

            # TODO: remove model_dump()
            await model.container.load_loras(
                lora_dir.resolve(), **config.lora.model_dump()
            )

    # If an initial embedding model name is specified, create a separate container
    # and load the model
    embedding_model_name = config.embeddings.embedding_model_name
    if embedding_model_name:
        embedding_model_path = pathlib.Path(config.embeddings.embedding_model_dir)
        embedding_model_path = embedding_model_path / embedding_model_name

        try:
            # TODO: remove model_dump()
            await model.load_embedding_model(
                embedding_model_path, **config.embeddings.model_dump()
            )
        except ImportError as ex:
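            # Embedding support is an optional dependency set, so a missing
            # package is logged here rather than aborting startup.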
            logger.error(ex.msg)

    # Initialize auth keys
    await load_auth_keys(config.network.disable_auth)

    gen_logging.broadcast_status()

    # Set sampler parameter overrides if provided
    sampling_override_preset = config.sampling.override_preset
    if sampling_override_preset:
        try:
            await sampling.overrides_from_file(sampling_override_preset)
        except FileNotFoundError as e:
            logger.warning(str(e))

    await start_api(host, port)


def entrypoint(
    args: Optional[argparse.Namespace] = None,
    parser: Optional[argparse.ArgumentParser] = None,
):
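    """Sync entry function. Sets up logging, signals, and the event loop
    policy, then parses args, loads config, and runs entrypoint_async()."""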

    setup_logger()

    # Set up signal aborting
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    if platform.system() == "Windows":
        from winloop import install
    else:
        from uvloop import install
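
    # winloop is a Windows-compatible port of uvloop; both swap in a faster,
    # libuv-based implementation of the asyncio event loop.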
    # Set loop event policy
    install()

    # Parse and override config from args
    if args is None:
        parser = init_argparser()
        args = parser.parse_args()

    dict_args = convert_args_to_dict(args, parser)

    # Load config
    config.load(dict_args)

    # Branch to default paths if required
    if run_subcommand(args):
        return

    # Check exllamav2 version and give a descriptive error if it's too old
    # Skip if launching unsafely
    if config.developer.unsafe_launch:
        logger.warning(
            "UNSAFE: Skipping ExllamaV2 version check.\n"
            "If you aren't a developer, please keep this off!"
        )
    elif not dependencies.inference:
        install_message = (
            "ERROR: Inference dependencies for TabbyAPI are not installed.\n"
            "Please update your environment by running an update script "
            "(update_scripts/"
            f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
            "Or you can manually run a requirements update "
            "using the following command:\n\n"
            "For CUDA 12:\n"
            "pip install --upgrade .[cu12]\n\n"
            "For ROCm:\n"
            "pip install --upgrade .[amd]\n\n"
        )

        raise SystemExit(install_message)

    # Set the process priority
    if config.developer.realtime_process_priority:
        import psutil

        current_process = psutil.Process(os.getpid())
        if platform.system() == "Windows":
            current_process.nice(psutil.REALTIME_PRIORITY_CLASS)
        else:
            current_process.nice(psutil.IOPRIO_CLASS_RT)
        logger.warning(
            "EXPERIMENTAL: Process priority set to Realtime.\n"
            "If you're not running on administrator/sudo, the priority is set to high."
        )

    # Enter into the async event loop
    asyncio.run(entrypoint_async())


if __name__ == "__main__":
    entrypoint()
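
# Typical launch, per the __main__ guard above: `python main.py` from the
# repo root. With no CLI arguments, the values loaded by config.load()
# (host, port, and any initial model names) drive startup.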