Start: Add gpu_lib argument
Argument to override the selected GPU library. Useful for daemoniztion when running for the first time. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
d759a15559
commit
de41e9f7e9
1 changed files with 48 additions and 39 deletions
87
start.py
87
start.py
|
|
@ -10,7 +10,7 @@ import sys
|
|||
from common.args import convert_args_to_dict, init_argparser
|
||||
|
||||
|
||||
def get_user_choice(question, options_dict):
|
||||
def get_user_choice(question: str, options_dict: dict):
|
||||
"""
|
||||
Gets user input in a commandline script.
|
||||
|
||||
|
|
@ -34,51 +34,55 @@ def get_user_choice(question, options_dict):
|
|||
return choice
|
||||
|
||||
|
||||
def get_install_features():
|
||||
def get_install_features(lib_name: str = None):
|
||||
"""Fetches the appropriate requirements file depending on the GPU"""
|
||||
install_features = None
|
||||
possible_features = ["cu121", "cu118", "amd"]
|
||||
|
||||
# Try getting the GPU lib from a file
|
||||
saved_lib_path = pathlib.Path("gpu_lib.txt")
|
||||
if saved_lib_path.exists():
|
||||
with open(saved_lib_path.resolve(), "r") as f:
|
||||
lib = f.readline().strip()
|
||||
|
||||
# Assume default if the file is invalid
|
||||
if lib not in possible_features:
|
||||
print(
|
||||
f"WARN: GPU library {lib} not found. "
|
||||
"Skipping GPU-specific dependencies.\n"
|
||||
"WARN: Please delete gpu_lib.txt and restart "
|
||||
"if you want to change your selection."
|
||||
)
|
||||
return
|
||||
|
||||
print(f"Using {lib} dependencies from your preferences.")
|
||||
install_features = lib
|
||||
if lib_name:
|
||||
print("Overriding GPU lib name from args.")
|
||||
else:
|
||||
# Ask the user for the GPU lib
|
||||
gpu_lib_choices = {
|
||||
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
|
||||
"B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"},
|
||||
"C": {"pretty": "AMD", "internal": "amd"},
|
||||
}
|
||||
user_input = get_user_choice(
|
||||
"Select your GPU. If you don't know, select Cuda 12.x (A)",
|
||||
gpu_lib_choices,
|
||||
)
|
||||
|
||||
install_features = gpu_lib_choices.get(user_input, {}).get("internal")
|
||||
|
||||
# Write to a file for subsequent runs
|
||||
with open(saved_lib_path.resolve(), "w") as f:
|
||||
f.write(install_features)
|
||||
print(
|
||||
"Saving your choice to gpu_lib.txt. "
|
||||
"Delete this file and restart if you want to change your selection."
|
||||
# Try getting the GPU lib from file
|
||||
if saved_lib_path.exists():
|
||||
print(saved_lib_path)
|
||||
with open(saved_lib_path.resolve(), "r") as f:
|
||||
lib = f.readline().strip()
|
||||
else:
|
||||
# Ask the user for the GPU lib
|
||||
gpu_lib_choices = {
|
||||
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
|
||||
"B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"},
|
||||
"C": {"pretty": "AMD", "internal": "amd"},
|
||||
}
|
||||
user_input = get_user_choice(
|
||||
"Select your GPU. If you don't know, select Cuda 12.x (A)",
|
||||
gpu_lib_choices,
|
||||
)
|
||||
|
||||
lib_name = gpu_lib_choices.get(user_input, {}).get("internal")
|
||||
|
||||
# Write to a file for subsequent runs
|
||||
with open(saved_lib_path.resolve(), "w") as f:
|
||||
f.write(lib_name)
|
||||
print(
|
||||
"Saving your choice to gpu_lib.txt. "
|
||||
"Delete this file and restart if you want to change your selection."
|
||||
)
|
||||
|
||||
# Assume default if the file is invalid
|
||||
if lib_name and lib_name in possible_features:
|
||||
print(f"Using {lib_name} dependencies from your preferences.")
|
||||
install_features = lib_name
|
||||
else:
|
||||
print(
|
||||
f"WARN: GPU library {lib} not found. "
|
||||
"Skipping GPU-specific dependencies.\n"
|
||||
"WARN: Please delete gpu_lib.txt and restart "
|
||||
"if you want to change your selection."
|
||||
)
|
||||
return
|
||||
|
||||
if install_features == "amd":
|
||||
# Exit if using AMD and Windows
|
||||
if platform.system() == "Windows":
|
||||
|
|
@ -111,6 +115,11 @@ def add_start_args(parser: argparse.ArgumentParser):
|
|||
action="store_true",
|
||||
help="Don't upgrade wheel dependencies (exllamav2, torch)",
|
||||
)
|
||||
start_group.add_argument(
|
||||
"--gpu-lib",
|
||||
type=str,
|
||||
help="Select GPU library. Options: cu121, cu118, amd",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -124,7 +133,7 @@ if __name__ == "__main__":
|
|||
if args.ignore_upgrade:
|
||||
print("Ignoring pip dependency upgrade due to user request.")
|
||||
else:
|
||||
install_features = None if args.nowheel else get_install_features()
|
||||
install_features = None if args.nowheel else get_install_features(args.gpu_lib)
|
||||
features = f"[{install_features}]" if install_features else ""
|
||||
|
||||
# pip install .[features]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue