Launch: Make exllamav2 requirement more friendly
Add the ability to use an unsafe config flag if needed and migrate the exl2 check to a different file within the exl2 backend code. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
b827bcbb44
commit
1919bf7705
5 changed files with 72 additions and 22 deletions
31
backends/exllamav2/utils.py
Normal file
31
backends/exllamav2/utils.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from packaging import version
|
||||
from importlib.metadata import version as package_version
|
||||
|
||||
from common.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def check_exllama_version():
    """Verify that the installed ExLlamaV2 satisfies the minimum version.

    The installed version string is read from the package metadata, and
    anything after a ``+`` (the local version label) is dropped before it
    is compared against the required version.

    On success the detected version is logged at info level.

    Raises:
        SystemExit: If no installed ExLlamaV2 package metadata can be found,
            or if the installed version is older than the required version.
            The exit message lists the commands for upgrading dependencies.
    """

    # Local import: only needed here, keeps the module-level imports untouched
    from importlib.metadata import PackageNotFoundError

    required_version = "0.0.12"

    # Shared by both failure paths below
    upgrade_instructions = (
        "Please upgrade your environment by running a start script "
        "(start.bat or start.sh)\n\n"
        "Or you can manually run a requirements update "
        "using the following command:\n\n"
        "For CUDA 12.1:\n"
        "pip install --upgrade -r requirements.txt\n\n"
        "For CUDA 11.8:\n"
        "pip install --upgrade -r requirements-cu118.txt\n\n"
        "For ROCm:\n"
        "pip install --upgrade -r requirements-amd.txt\n\n"
    )

    try:
        current_version = package_version("exllamav2").split("+")[0]
    except PackageNotFoundError:
        # Without this, a missing distribution surfaces as a raw traceback
        # instead of the descriptive message this check is meant to give.
        raise SystemExit(
            f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
            "or greater, but no installed ExLlamaV2 package was found.\n"
            + upgrade_instructions
        ) from None

    if version.parse(current_version) < version.parse(required_version):
        raise SystemExit(
            f"ERROR: TabbyAPI requires ExLlamaV2 {required_version} "
            f"or greater. Your current version is {current_version}.\n"
            + upgrade_instructions
        )

    logger.info(f"ExllamaV2 version: {current_version}")
|
||||
|
|
@ -125,3 +125,12 @@ def add_logging_args(parser: argparse.ArgumentParser):
|
|||
type=str_to_bool,
|
||||
help="Enable generation parameter logging",
|
||||
)
|
||||
|
||||
|
||||
def add_developer_args(parser: argparse.ArgumentParser):
    """Register the developer-specific command line options on ``parser``.

    The options are placed in their own "developer" argument group.
    """

    dev_group = parser.add_argument_group("developer")

    # Flag that lets a launch bypass the Exllamav2 version check
    dev_group.add_argument(
        "--unsafe-launch",
        type=str_to_bool,
        help="Skip Exllamav2 version check",
    )
|
||||
|
|
|
|||
|
|
@ -55,6 +55,11 @@ def override_config_from_args(args: dict):
|
|||
**{k.replace("log_", ""): logging_override[k] for k in logging_override},
|
||||
}
|
||||
|
||||
developer_override = args.get("developer")
|
||||
if developer_override:
|
||||
developer_config = get_developer_config()
|
||||
GLOBAL_CONFIG["developer"] = {**developer_config, **developer_override}
|
||||
|
||||
|
||||
def get_sampling_config():
|
||||
"""Returns the sampling parameter config from the global config"""
|
||||
|
|
@ -86,3 +91,8 @@ def get_network_config():
|
|||
def get_gen_logging_config():
    """Fetch the "logging" section of the global config.

    The section is passed through ``unwrap`` with ``{}`` as the fallback.
    """
    logging_section = GLOBAL_CONFIG.get("logging")
    return unwrap(logging_section, {})
|
||||
|
||||
|
||||
def get_developer_config():
    """Fetch the "developer" section of the global config.

    The section is passed through ``unwrap`` with ``{}`` as the fallback.
    """
    developer_section = GLOBAL_CONFIG.get("developer")
    return unwrap(developer_section, {})
|
||||
|
|
|
|||
|
|
@ -35,6 +35,13 @@ sampling:
|
|||
# WARNING: Using this can result in a generation speed penalty
|
||||
#override_preset:
|
||||
|
||||
# Options for development
|
||||
developer:
|
||||
# Skips exllamav2 version check (default: False)
|
||||
# It's highly recommended to update your dependencies rather than enabling this flag
|
||||
# WARNING: Don't set this unless you know what you're doing!
|
||||
#unsafe_launch: False
|
||||
|
||||
# Options for model overrides and loading
|
||||
model:
|
||||
# Overrides the directory to look for models (default: models)
|
||||
|
|
|
|||
37
main.py
37
main.py
|
|
@ -9,15 +9,15 @@ from fastapi import FastAPI, Depends, HTTPException, Request
|
|||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import StreamingResponse
|
||||
from functools import partial
|
||||
from packaging import version
|
||||
from importlib.metadata import version as package_version
|
||||
from progress.bar import IncrementalBar
|
||||
|
||||
import common.gen_logging as gen_logging
|
||||
from backends.exllamav2.model import ExllamaV2Container
|
||||
from backends.exllamav2.utils import check_exllama_version
|
||||
from common.args import convert_args_to_dict, init_argparser
|
||||
from common.auth import check_admin_key, check_api_key, load_auth_keys
|
||||
from common.config import (
|
||||
get_developer_config,
|
||||
get_sampling_config,
|
||||
override_config_from_args,
|
||||
read_config_from_file,
|
||||
|
|
@ -580,26 +580,6 @@ def entrypoint(args: Optional[dict] = None):
|
|||
"""Entry function for program startup"""
|
||||
global MODEL_CONTAINER
|
||||
|
||||
# Check exllamav2 version and give a descriptive error if it's too old
|
||||
required_exl_version = "0.0.12"
|
||||
current_exl_version = package_version("exllamav2").split("+")[0]
|
||||
|
||||
if version.parse(current_exl_version) < version.parse(required_exl_version):
|
||||
raise SystemExit(
|
||||
f"TabbyAPI requires ExLlamaV2 {required_exl_version} "
|
||||
f"or greater. Your current version is {current_exl_version}.\n"
|
||||
"Please upgrade your environment by running a start script "
|
||||
"(start.bat or start.sh)\n\n"
|
||||
"Or you can manually run a requirements update "
|
||||
"using the following command:\n\n"
|
||||
"For CUDA 12.1:\n"
|
||||
"pip install --upgrade -r requirements.txt\n\n"
|
||||
"For CUDA 11.8:\n"
|
||||
"pip install --upgrade -r requirements-cu118.txt\n\n"
|
||||
"For ROCm:\n"
|
||||
"pip install --upgrade -r requirements-amd.txt\n\n"
|
||||
)
|
||||
|
||||
# Load from YAML config
|
||||
read_config_from_file(pathlib.Path("config.yml"))
|
||||
|
||||
|
|
@ -610,6 +590,19 @@ def entrypoint(args: Optional[dict] = None):
|
|||
|
||||
override_config_from_args(args)
|
||||
|
||||
developer_config = get_developer_config()
|
||||
|
||||
# Check exllamav2 version and give a descriptive error if it's too old
|
||||
# Skip if launching unsafely
|
||||
|
||||
if unwrap(developer_config.get("unsafe_launch"), False):
|
||||
logger.warning(
|
||||
"UNSAFE: Skipping ExllamaV2 version check.\n"
|
||||
"If you aren't a developer, please keep this off!"
|
||||
)
|
||||
else:
|
||||
check_exllama_version()
|
||||
|
||||
network_config = get_network_config()
|
||||
|
||||
# Initialize auth keys
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue