Start: Give the user a hint when a module can't be imported

If an ImportError or ModuleNotFoundError is raised, tell the user
to run the update scripts.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-08-03 21:56:53 -04:00
parent 1aa934664c
commit b6d2676f1c
3 changed files with 23 additions and 10 deletions

View file

@ -1,7 +1,8 @@
import platform
import torch
from packaging import version
from importlib.metadata import PackageNotFoundError, version as package_version
from loguru import logger
import torch
def check_exllama_version():
@ -13,8 +14,9 @@ def check_exllama_version():
unsupported_message = (
f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
f"or greater. Your current version is {current_version}.\n"
"Please upgrade your environment by running a start script "
"(start.bat or start.sh)\n\n"
"Please update your environment by running an update script "
"(update_scripts/"
f"update_deps.{'bat' if platform.system == 'Windows' else 'sh'})\n\n"
"Or you can manually run a requirements update "
"using the following command:\n\n"
"For CUDA 12.1:\n"
@ -71,8 +73,9 @@ def supports_paged_attn():
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"Please upgrade your environment by running a start script "
"(start.bat or start.sh)\n\n"
"Please upgrade your environment by running an update script "
"(update_scripts/"
f"update_deps.{'bat' if platform.system == 'Windows' else 'sh'})\n\n"
"Or you can manually run a requirements update "
"using the following command:\n\n"
"For CUDA 12.1:\n"

View file

@ -51,7 +51,7 @@ class PromptTemplate:
raise ImportError(
"Parsing these chat completion messages requires jinja2 3.0.0 "
f"or greater. Current version: {package_version('jinja2')}\n"
"Please upgrade jinja by running the following command: "
"Please update jinja by running the following command: "
"pip install --upgrade jinja2"
)

View file

@ -8,6 +8,7 @@ import platform
import subprocess
import sys
from shutil import copyfile
import traceback
from common.args import convert_args_to_dict, init_argparser
@ -231,9 +232,18 @@ if __name__ == "__main__":
)
# Import entrypoint after installing all requirements
from main import entrypoint
try:
from main import entrypoint
converted_args = convert_args_to_dict(args, parser)
converted_args = convert_args_to_dict(args, parser)
print("Starting TabbyAPI...")
entrypoint(converted_args)
print("Starting TabbyAPI...")
entrypoint(converted_args)
except (ModuleNotFoundError, ImportError):
traceback.print_exc()
print(
"\n"
"This error was raised because a package was not found.\n"
"Update your dependencies by running update_scripts/"
f"update_deps.{'bat' if platform.system == 'Windows' else 'sh'}\n\n"
)