tabbyAPI-ollama/common/optional_dependencies.py
kingbri 17f3dca6fc Packaging: Add agnostic method to check version of packages
Some packages, such as ExLlamaV2 and ExLlamaV3, require specific versions for
the latest features. Rather than writing a repetitive function for each one,
add a package-agnostic function that checks the installed version and tells
the user to upgrade when it is too old.

This error is also surfaced through load and unload requests, so keep the
message short.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
2025-05-17 01:04:24 -04:00

77 lines
2.1 KiB
Python

"""Construct a model of all optional dependencies"""
import importlib.util
from importlib.metadata import version as package_version
from loguru import logger
from packaging import version
from pydantic import BaseModel, computed_field
# Declare the exported parts of this module.
# Only `dependencies` is listed, so `from <module> import *` exposes just the
# DependenciesModel instance built at the bottom of this file; every other
# name (e.g. check_package_version) must be imported explicitly.
__all__ = ["dependencies"]
class DependenciesModel(BaseModel):
    """Flags recording which optional dependencies were found.

    Each boolean field is named after a top-level module; get_installed_deps
    fills it with the result of probing for that module. The computed fields
    below combine these flags into feature-level availability checks.
    """

    # Declaration order is preserved in ``model_fields``, which is what the
    # probing code iterates over; keep the list stable.
    torch: bool
    exllamav2: bool
    exllamav3: bool
    flash_attn: bool
    infinity_emb: bool
    sentence_transformers: bool

    @computed_field
    @property
    def extras(self) -> bool:
        """True only when both ``infinity_emb`` and ``sentence_transformers`` are present."""
        both_present = self.infinity_emb and self.sentence_transformers
        return both_present

    @computed_field
    @property
    def inference(self) -> bool:
        """True when torch is present together with at least one usable backend.

        A backend counts as usable if ``exllamav2`` is present, or if both
        ``exllamav3`` and ``flash_attn`` are present.
        """
        # ExLlamaV3 only counts when flash_attn is available alongside it.
        exl3_usable = self.exllamav3 and self.flash_attn
        has_backend = self.exllamav2 or exl3_usable
        return self.torch and has_backend
def is_installed(package_name: str) -> bool:
    """Return True if ``package_name`` can be located by the import system.

    Uses ``importlib.util.find_spec``, which locates a top-level module
    without importing it. For dotted names, ``find_spec`` first imports the
    parent package and raises ``ModuleNotFoundError`` (a subclass of
    ``ImportError``) if that parent is missing. Such failures are reported
    as "not installed" instead of propagating to the caller.

    Args:
        package_name: Importable module name, e.g. ``"torch"`` or ``"a.b"``.

    Returns:
        True if a module spec was found, otherwise False.
    """
    try:
        return importlib.util.find_spec(package_name) is not None
    except ImportError:
        # A missing or broken parent package means the submodule cannot be
        # used, which for an optional-dependency probe means "not installed".
        return False
def get_installed_deps() -> DependenciesModel:
    """Probe every declared optional dependency and bundle the results.

    The field names of ``DependenciesModel`` double as module names: each
    one is passed to ``is_installed`` and the resulting flag is stored
    under the same name in a new ``DependenciesModel``.
    """
    detected = {
        dep_name: is_installed(dep_name)
        for dep_name in DependenciesModel.model_fields
    }
    return DependenciesModel(**detected)
def check_package_version(package_name: str, required_version_str: str):
    """
    Fetch the installed version of a package and verify it is new enough.

    This assumes that the required package is installed; if it is not,
    ``importlib.metadata.version`` raises ``PackageNotFoundError``.

    Args:
        package_name: Distribution name to look up with
            ``importlib.metadata.version``. It is also used in the messages.
        required_version_str: Minimum acceptable version, parsed with
            ``packaging.version.parse``.

    Raises:
        RuntimeError: If the installed version is older than the requirement.
            The message is kept short because it can also be surfaced through
            load/unload requests.
    """
    required_version = version.parse(required_version_str)
    # Strip any local version label (the "+..." suffix) so only the public
    # version is compared and reported.
    current_version = version.parse(package_version(package_name).split("+")[0])

    if current_version < required_version:
        # Name the package that was actually checked; this helper is shared by
        # several packages, so a hard-coded name would be misleading.
        raise RuntimeError(
            f"ERROR: TabbyAPI requires {package_name} {required_version} "
            f"or greater. Your current version is {current_version}. "
            "Please update your dependencies."
        )

    logger.info(f"{package_name} version: {current_version}")
# Probe all optional dependencies once, when this module is first imported.
# The resulting DependenciesModel is the name listed in this module's __all__.
dependencies = get_installed_deps()