diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 89fa005..e86f48a 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -44,6 +44,7 @@ from common.gen_logging import ( from common.hardware import hardware_supports_flash_attn from common.health import HealthManager from common.multimodal import MultimodalEmbeddingWrapper +from common.optional_dependencies import check_package_version from common.sampling import BaseSamplerRequest from common.templating import PromptTemplate, find_prompt_template from common.transformers_utils import HFModel @@ -111,6 +112,9 @@ class ExllamaV2Container(BaseModelContainer): # Create a new instance as a "fake self" self = cls() + # Make sure ExllamaV2 is up to date + check_package_version("exllamav2", "0.3.0") + # Initialize config self.config = ExLlamaV2Config() self.model_dir = model_directory diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index d80ee56..60cf5d0 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -32,6 +32,7 @@ from common.gen_logging import ( from common.hardware import hardware_supports_flash_attn from common.health import HealthManager from common.multimodal import MultimodalEmbeddingWrapper +from common.optional_dependencies import check_package_version from common.sampling import BaseSamplerRequest from common.templating import PromptTemplate, find_prompt_template from common.transformers_utils import HFModel @@ -96,6 +97,9 @@ class ExllamaV3Container(BaseModelContainer): self = cls() + # Make sure ExllamaV3 is up to date + check_package_version("exllamav3", "0.0.2") + logger.warning( "ExllamaV3 is currently in an alpha state. " "Please note that all config options may not work." diff --git a/common/optional_dependencies.py b/common/optional_dependencies.py index b449c2e..207c8f0 100644 --- a/common/optional_dependencies.py +++ b/common/optional_dependencies.py @@ -1,6 +1,9 @@ """Construct a model of all optional dependencies""" import importlib.util +from importlib.metadata import version as package_version +from loguru import logger +from packaging import version from pydantic import BaseModel, computed_field @@ -49,4 +52,26 @@ def get_installed_deps() -> DependenciesModel: return DependenciesModel(**installed_deps) +def check_package_version(package_name: str, required_version_str: str): + """ + Fetches and verifies a given package version. + + This assumes that the required package is installed. + """ + + required_version = version.parse(required_version_str) + current_version = version.parse(package_version(package_name).split("+")[0]) + + unsupported_message = ( + f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} " + f"or greater. Your current version is {current_version}. " + "Please update your dependencies." + ) + + if current_version < required_version: + raise RuntimeError(unsupported_message) + else: + logger.info(f"{package_name} version: {current_version}") + + dependencies = get_installed_deps()