diff --git a/common/optional_dependencies.py b/common/optional_dependencies.py
index 0c1e7ff..b449c2e 100644
--- a/common/optional_dependencies.py
+++ b/common/optional_dependencies.py
@@ -26,7 +26,7 @@ class DependenciesModel(BaseModel):
     @computed_field
     @property
     def inference(self) -> bool:
-        return self.torch and (self.exllamav2 or self.exllamav3) and self.flash_attn
+        return self.torch and (self.exllamav2 or (self.exllamav3 and self.flash_attn))
 
 
 def is_installed(package_name: str) -> bool:
diff --git a/main.py b/main.py
index df4e472..0115421 100644
--- a/main.py
+++ b/main.py
@@ -15,6 +15,7 @@ from common.auth import load_auth_keys
 from common.actions import run_subcommand
 from common.logger import setup_logger
 from common.networking import is_port_in_use
+from common.optional_dependencies import dependencies
 from common.signals import signal_handler
 from common.tabby_config import config
 from endpoints.server import start_api
@@ -139,8 +140,21 @@ def entrypoint(
             "UNSAFE: Skipping ExllamaV2 version check.\n"
             "If you aren't a developer, please keep this off!"
         )
-    else:
-        check_exllama_version()
+    elif not dependencies.inference:
+        install_message = (
+            f"ERROR: Inference dependencies for TabbyAPI are not installed.\n"
+            "Please update your environment by running an update script "
+            "(update_scripts/"
+            f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
+            "Or you can manually run a requirements update "
+            "using the following command:\n\n"
+            "For CUDA 12.1:\n"
+            "pip install --upgrade .[cu121]\n\n"
+            "For ROCm:\n"
+            "pip install --upgrade .[amd]\n\n"
+        )
+
+        raise SystemExit(install_message)
 
     # Enable CUDA malloc backend
     if config.developer.cuda_malloc_backend:
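
For context on the first hunk: the old predicate `self.torch and (self.exllamav2 or self.exllamav3) and self.flash_attn` required flash_attn unconditionally, so a torch + exllamav2 environment without flash_attn was wrongly reported as missing inference dependencies; the reparenthesized version scopes the flash_attn requirement to exllamav3 only. Below is a minimal, self-contained sketch of the changed behavior. The body of is_installed() is not shown in this diff, so the find_spec() implementation is an assumption, as is constructing DependenciesModel directly from booleans; this is an illustration, not TabbyAPI's actual module.

# Minimal sketch of the dependency-detection pattern the diff modifies.
# Assumed: is_installed() checks importability (its body is not in the diff);
# the field names mirror those used in the diff's inference property.
import importlib.util

from pydantic import BaseModel, computed_field


def is_installed(package_name: str) -> bool:
    """Return True if the package can be imported (assumed implementation)."""
    return importlib.util.find_spec(package_name) is not None


class DependenciesModel(BaseModel):
    torch: bool
    exllamav2: bool
    exllamav3: bool
    flash_attn: bool

    @computed_field
    @property
    def inference(self) -> bool:
        # After the fix: flash_attn matters only alongside exllamav3;
        # torch + exllamav2 alone is a usable inference environment.
        return self.torch and (self.exllamav2 or (self.exllamav3 and self.flash_attn))


# Example: an exllamav2 install without flash_attn now counts as usable.
deps = DependenciesModel(torch=True, exllamav2=True, exllamav3=False, flash_attn=False)
print(deps.inference)  # True (was False before the parenthesization fix)

The second hunk consumes this flag at startup: raising SystemExit with the install instructions aborts the launch cleanly with an actionable message instead of letting a missing backend surface later as an ImportError.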