Some packages such as ExllamaV2 and V3 require specific versions for the latest features. Rather than creating repetitive functions, create an agnostic function to check the installed package and then report to the user to upgrade. This is also sent to requests for loading and unloading, so keep the error short. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
77 lines
2.1 KiB
Python
77 lines
2.1 KiB
Python
"""Construct a model of all optional dependencies"""
|
|
|
|
import importlib.util
|
|
from importlib.metadata import version as package_version
|
|
from loguru import logger
|
|
from packaging import version
|
|
from pydantic import BaseModel, computed_field
|
|
|
|
|
|
# Declare the exported parts of this module
|
|
__all__ = ["dependencies"]
|
|
|
|
|
|
class DependenciesModel(BaseModel):
    """Model of which optional dependencies are installed."""

    # One flag per optional package. get_installed_deps() passes each field
    # name to is_installed(), so the names double as module names.
    torch: bool
    exllamav2: bool
    exllamav3: bool
    flash_attn: bool
    infinity_emb: bool
    sentence_transformers: bool

    @computed_field
    @property
    def extras(self) -> bool:
        # The "extras" set is only complete when both embedding-related
        # packages are present.
        has_infinity = self.infinity_emb
        has_sentence_transformers = self.sentence_transformers
        return has_infinity and has_sentence_transformers

    @computed_field
    @property
    def inference(self) -> bool:
        # Torch is always required. On top of it, either ExLlamaV2 on its
        # own or ExLlamaV3 together with flash_attn must be available.
        exl3_ready = self.exllamav3 and self.flash_attn
        backend_ready = self.exllamav2 or exl3_ready
        return self.torch and backend_ready
|
|
|
|
|
|
def is_installed(package_name: str) -> bool:
    """Report whether a module or package can be located for import.

    Args:
        package_name: Absolute module name to look up, e.g. ``"torch"``.
            Dotted names such as ``"pkg.sub"`` are accepted.

    Returns:
        True if the import system finds a spec for the name, False otherwise.
    """

    try:
        spec = importlib.util.find_spec(package_name)
    except ModuleNotFoundError:
        # For a dotted name, find_spec imports the parent package first and
        # raises if that parent is missing; that simply means "not installed".
        return False

    return spec is not None
|
|
|
|
|
|
def get_installed_deps() -> DependenciesModel:
    """Build a DependenciesModel by probing every declared field name."""

    # Each model field name is handed to is_installed as the module to find.
    availability = {
        name: is_installed(name) for name in DependenciesModel.model_fields
    }

    return DependenciesModel(**availability)
|
|
|
|
|
|
def check_package_version(package_name: str, required_version_str: str) -> None:
    """
    Fetches and verifies a given package version.

    This assumes that the required package is installed.

    Args:
        package_name: Name of the installed package whose metadata is queried.
        required_version_str: Minimum acceptable version, e.g. "0.2.3".

    Raises:
        RuntimeError: If the installed version is older than the required one.
    """

    required_version = version.parse(required_version_str)

    # Drop anything after "+" (the local version label) so that only the
    # public part of the installed version takes part in the comparison.
    installed_version_str = package_version(package_name).split("+")[0]
    current_version = version.parse(installed_version_str)

    if current_version < required_version:
        # Name the package actually being checked: this helper is shared by
        # several packages, so a fixed package name would be misleading.
        raise RuntimeError(
            f"ERROR: TabbyAPI requires {package_name} {required_version} "
            f"or greater. Your current version is {current_version}. "
            "Please update your dependencies."
        )

    logger.info(f"{package_name} version: {current_version}")
|
|
|
|
|
|
# Module-level snapshot of the optional dependencies, evaluated once when
# this module is imported; this is the name exported through __all__.
dependencies = get_installed_deps()
|