fix issues with optional dependencies (#204)
* fix issues with optional dependencies * format document * Tree: Format and comment
This commit is contained in:
parent
75af974c88
commit
3aeddc5255
9 changed files with 104 additions and 53 deletions
4
.github/workflows/pages.yml
vendored
4
.github/workflows/pages.yml
vendored
|
|
@ -48,8 +48,8 @@ jobs:
|
|||
npm install @redocly/cli -g
|
||||
- name: Export OpenAPI docs
|
||||
run: |
|
||||
EXPORT_OPENAPI=1 python main.py --openapi-export-path "openapi-oai.json" --api-servers OAI
|
||||
EXPORT_OPENAPI=1 python main.py --openapi-export-path "openapi-kobold.json" --api-servers kobold
|
||||
python main.py --export-openapi true --openapi-export-path "openapi-kobold.json" --api-servers kobold
|
||||
python main.py --export-openapi true --openapi-export-path "openapi-oai.json" --api-servers OAI
|
||||
- name: Build and store Redocly site
|
||||
run: |
|
||||
mkdir static
|
||||
|
|
|
|||
|
|
@ -5,34 +5,6 @@ from importlib.metadata import PackageNotFoundError, version as package_version
|
|||
from loguru import logger
|
||||
|
||||
|
||||
def check_exllama_version():
|
||||
"""Verifies the exllama version"""
|
||||
|
||||
required_version = version.parse("0.2.2")
|
||||
current_version = version.parse(package_version("exllamav2").split("+")[0])
|
||||
|
||||
unsupported_message = (
|
||||
f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
|
||||
f"or greater. Your current version is {current_version}.\n"
|
||||
"Please update your environment by running an update script "
|
||||
"(update_scripts/"
|
||||
f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
|
||||
"Or you can manually run a requirements update "
|
||||
"using the following command:\n\n"
|
||||
"For CUDA 12.1:\n"
|
||||
"pip install --upgrade .[cu121]\n\n"
|
||||
"For CUDA 11.8:\n"
|
||||
"pip install --upgrade .[cu118]\n\n"
|
||||
"For ROCm:\n"
|
||||
"pip install --upgrade .[amd]\n\n"
|
||||
)
|
||||
|
||||
if current_version < required_version:
|
||||
raise SystemExit(unsupported_message)
|
||||
else:
|
||||
logger.info(f"ExllamaV2 version: {current_version}")
|
||||
|
||||
|
||||
def hardware_supports_flash_attn(gpu_device_list: list[int]):
|
||||
"""
|
||||
Check whether all GPUs in list support FA2
|
||||
|
|
|
|||
39
backends/exllamav2/version.py
Normal file
39
backends/exllamav2/version.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
import platform
|
||||
from packaging import version
|
||||
from importlib.metadata import version as package_version
|
||||
from loguru import logger
|
||||
from common.optional_dependencies import dependencies
|
||||
|
||||
|
||||
def check_exllama_version():
|
||||
"""Verifies the exllama version"""
|
||||
|
||||
install_message = (
|
||||
"Please update your environment by running an update script "
|
||||
"(update_scripts/"
|
||||
f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
|
||||
"Or you can manually run a requirements update "
|
||||
"using the following command:\n\n"
|
||||
"For CUDA 12.1:\n"
|
||||
"pip install --upgrade .[cu121]\n\n"
|
||||
"For CUDA 11.8:\n"
|
||||
"pip install --upgrade .[cu118]\n\n"
|
||||
"For ROCm:\n"
|
||||
"pip install --upgrade .[amd]\n\n"
|
||||
)
|
||||
|
||||
if not dependencies.exl2:
|
||||
raise SystemExit(("Exllamav2 is not installed.\n" + install_message))
|
||||
|
||||
required_version = version.parse("0.2.2")
|
||||
current_version = version.parse(package_version("exllamav2").split("+")[0])
|
||||
|
||||
unsupported_message = (
|
||||
f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
|
||||
f"or greater. Your current version is {current_version}.\n" + install_message
|
||||
)
|
||||
|
||||
if current_version < required_version:
|
||||
raise SystemExit(unsupported_message)
|
||||
else:
|
||||
logger.info(f"ExllamaV2 version: {current_version}")
|
||||
|
|
@ -5,16 +5,12 @@ from loguru import logger
|
|||
from typing import List, Optional
|
||||
|
||||
from common.utils import unwrap
|
||||
from common.optional_dependencies import dependencies
|
||||
|
||||
# Conditionally import infinity to sidestep its logger
|
||||
has_infinity_emb: bool = False
|
||||
try:
|
||||
if dependencies.extras:
|
||||
from infinity_emb import EngineArgs, AsyncEmbeddingEngine
|
||||
|
||||
has_infinity_emb = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class InfinityContainer:
|
||||
model_dir: pathlib.Path
|
||||
|
|
@ -23,7 +19,7 @@ class InfinityContainer:
|
|||
|
||||
# Conditionally set the type hint based on importablity
|
||||
# TODO: Clean this up
|
||||
if has_infinity_emb:
|
||||
if dependencies.extras:
|
||||
engine: Optional[AsyncEmbeddingEngine] = None
|
||||
else:
|
||||
engine = None
|
||||
|
|
|
|||
|
|
@ -3,13 +3,12 @@ from loguru import logger
|
|||
|
||||
from common.tabby_config import config, generate_config_file
|
||||
from endpoints.server import export_openapi
|
||||
from endpoints.utils import do_export_openapi
|
||||
|
||||
|
||||
def branch_to_actions() -> bool:
|
||||
"""Checks if a optional action needs to be run."""
|
||||
|
||||
if config.actions.export_openapi or do_export_openapi:
|
||||
if config.actions.export_openapi:
|
||||
openapi_json = export_openapi()
|
||||
|
||||
with open(config.actions.openapi_export_path, "w") as f:
|
||||
|
|
|
|||
|
|
@ -13,22 +13,20 @@ from typing import Optional
|
|||
from common.logger import get_loading_progress_bar
|
||||
from common.networking import handle_request_error
|
||||
from common.tabby_config import config
|
||||
from endpoints.utils import do_export_openapi
|
||||
from common.optional_dependencies import dependencies
|
||||
|
||||
if not do_export_openapi:
|
||||
if dependencies.exl2:
|
||||
from backends.exllamav2.model import ExllamaV2Container
|
||||
|
||||
# Global model container
|
||||
container: Optional[ExllamaV2Container] = None
|
||||
embeddings_container = None
|
||||
|
||||
# Type hint the infinity emb container if it exists
|
||||
from backends.infinity.model import has_infinity_emb
|
||||
|
||||
if has_infinity_emb:
|
||||
from backends.infinity.model import InfinityContainer
|
||||
if dependencies.extras:
|
||||
from backends.infinity.model import InfinityContainer
|
||||
|
||||
embeddings_container: Optional[InfinityContainer] = None
|
||||
embeddings_container: Optional[InfinityContainer] = None
|
||||
|
||||
|
||||
class ModelType(Enum):
|
||||
|
|
@ -121,7 +119,7 @@ async def load_embedding_model(model_path: pathlib.Path, **kwargs):
|
|||
global embeddings_container
|
||||
|
||||
# Break out if infinity isn't installed
|
||||
if not has_infinity_emb:
|
||||
if not dependencies.extras:
|
||||
raise ImportError(
|
||||
"Skipping embeddings because infinity-emb is not installed.\n"
|
||||
"Please run the following command in your environment "
|
||||
|
|
|
|||
52
common/optional_dependencies.py
Normal file
52
common/optional_dependencies.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
"""Construct a model of all optional dependencies"""
|
||||
|
||||
import importlib.util
|
||||
from pydantic import BaseModel, computed_field
|
||||
|
||||
|
||||
# Declare the exported parts of this module
|
||||
__all__ = ["dependencies"]
|
||||
|
||||
|
||||
class DependenciesModel(BaseModel):
|
||||
"""Model of which optional dependencies are installed."""
|
||||
|
||||
torch: bool
|
||||
exllamav2: bool
|
||||
flash_attn: bool
|
||||
outlines: bool
|
||||
infinity_emb: bool
|
||||
sentence_transformers: bool
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def extras(self) -> bool:
|
||||
return self.outlines and self.infinity_emb and self.sentence_transformers
|
||||
|
||||
@computed_field
|
||||
@property
|
||||
def exl2(self) -> bool:
|
||||
return self.torch and self.exllamav2 and self.flash_attn
|
||||
|
||||
|
||||
def is_installed(package_name: str) -> bool:
|
||||
"""Utility function to check if a package is installed."""
|
||||
|
||||
spec = importlib.util.find_spec(package_name)
|
||||
return spec is not None
|
||||
|
||||
|
||||
def get_installed_deps() -> DependenciesModel:
|
||||
"""Check if optional dependencies are installed by looping over the fields."""
|
||||
|
||||
fields = DependenciesModel.model_fields
|
||||
|
||||
installed_deps = {}
|
||||
|
||||
for field_name in fields.keys():
|
||||
installed_deps[field_name] = is_installed(field_name)
|
||||
|
||||
return DependenciesModel(**installed_deps)
|
||||
|
||||
|
||||
dependencies = get_installed_deps()
|
||||
|
|
@ -1,3 +0,0 @@
|
|||
import os
|
||||
|
||||
do_export_openapi = os.getenv("EXPORT_OPENAPI", "").lower() in ("true", "1")
|
||||
4
main.py
4
main.py
|
|
@ -17,10 +17,8 @@ from common.networking import is_port_in_use
|
|||
from common.signals import signal_handler
|
||||
from common.tabby_config import config
|
||||
from endpoints.server import start_api
|
||||
from endpoints.utils import do_export_openapi
|
||||
|
||||
if not do_export_openapi:
|
||||
from backends.exllamav2.utils import check_exllama_version
|
||||
from backends.exllamav2.version import check_exllama_version
|
||||
|
||||
|
||||
async def entrypoint_async():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue