From 48ea1737cfa3b9b092fa49d0335fcf810c06cd41 Mon Sep 17 00:00:00 2001 From: kingbri <8082010+kingbri1@users.noreply.github.com> Date: Fri, 9 May 2025 21:59:00 -0400 Subject: [PATCH] Startup: Check agnostically for inference deps If an inference dep isn't present, force exit the application. This occurs after all subcommands have been appropriately processed. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com> --- common/optional_dependencies.py | 2 +- main.py | 18 ++++++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/common/optional_dependencies.py b/common/optional_dependencies.py index 0c1e7ff..b449c2e 100644 --- a/common/optional_dependencies.py +++ b/common/optional_dependencies.py @@ -26,7 +26,7 @@ class DependenciesModel(BaseModel): @computed_field @property def inference(self) -> bool: - return self.torch and (self.exllamav2 or self.exllamav3) and self.flash_attn + return self.torch and (self.exllamav2 or (self.exllamav3 and self.flash_attn)) def is_installed(package_name: str) -> bool: diff --git a/main.py b/main.py index df4e472..0115421 100644 --- a/main.py +++ b/main.py @@ -15,6 +15,7 @@ from common.auth import load_auth_keys from common.actions import run_subcommand from common.logger import setup_logger from common.networking import is_port_in_use +from common.optional_dependencies import dependencies from common.signals import signal_handler from common.tabby_config import config from endpoints.server import start_api @@ -139,8 +140,21 @@ def entrypoint( "UNSAFE: Skipping ExllamaV2 version check.\n" "If you aren't a developer, please keep this off!" ) - else: - check_exllama_version() + elif not dependencies.inference: + install_message = ( + f"ERROR: Inference dependencies for TabbyAPI are not installed.\n" + "Please update your environment by running an update script " + "(update_scripts/" + f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n" + "Or you can manually run a requirements update " + "using the following command:\n\n" + "For CUDA 12.1:\n" + "pip install --upgrade .[cu121]\n\n" + "For ROCm:\n" + "pip install --upgrade .[amd]\n\n" + ) + + raise SystemExit(install_message) # Enable CUDA malloc backend if config.developer.cuda_malloc_backend: