diff --git a/start.py b/start.py index 764249e..a0a5245 100644 --- a/start.py +++ b/start.py @@ -10,7 +10,7 @@ import sys from common.args import convert_args_to_dict, init_argparser -def get_user_choice(question, options_dict): +def get_user_choice(question: str, options_dict: dict): """ Gets user input in a commandline script. @@ -34,51 +34,55 @@ def get_user_choice(question, options_dict): return choice -def get_install_features(): +def get_install_features(lib_name: str = None): """Fetches the appropriate requirements file depending on the GPU""" install_features = None possible_features = ["cu121", "cu118", "amd"] - - # Try getting the GPU lib from a file saved_lib_path = pathlib.Path("gpu_lib.txt") - if saved_lib_path.exists(): - with open(saved_lib_path.resolve(), "r") as f: - lib = f.readline().strip() - # Assume default if the file is invalid - if lib not in possible_features: - print( - f"WARN: GPU library {lib} not found. " - "Skipping GPU-specific dependencies.\n" - "WARN: Please delete gpu_lib.txt and restart " - "if you want to change your selection." - ) - return - - print(f"Using {lib} dependencies from your preferences.") - install_features = lib + if lib_name: + print("Overriding GPU lib name from args.") else: - # Ask the user for the GPU lib - gpu_lib_choices = { - "A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"}, - "B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"}, - "C": {"pretty": "AMD", "internal": "amd"}, - } - user_input = get_user_choice( - "Select your GPU. If you don't know, select Cuda 12.x (A)", - gpu_lib_choices, - ) - - install_features = gpu_lib_choices.get(user_input, {}).get("internal") - - # Write to a file for subsequent runs - with open(saved_lib_path.resolve(), "w") as f: - f.write(install_features) - print( - "Saving your choice to gpu_lib.txt. " - "Delete this file and restart if you want to change your selection." + # Try getting the GPU lib from file + if saved_lib_path.exists(): + print(saved_lib_path) + with open(saved_lib_path.resolve(), "r") as f: + lib = f.readline().strip() + else: + # Ask the user for the GPU lib + gpu_lib_choices = { + "A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"}, + "B": {"pretty": "NVIDIA Cuda 11.8", "internal": "cu118"}, + "C": {"pretty": "AMD", "internal": "amd"}, + } + user_input = get_user_choice( + "Select your GPU. If you don't know, select Cuda 12.x (A)", + gpu_lib_choices, ) + lib_name = gpu_lib_choices.get(user_input, {}).get("internal") + + # Write to a file for subsequent runs + with open(saved_lib_path.resolve(), "w") as f: + f.write(lib_name) + print( + "Saving your choice to gpu_lib.txt. " + "Delete this file and restart if you want to change your selection." + ) + + # Assume default if the file is invalid + if lib_name and lib_name in possible_features: + print(f"Using {lib_name} dependencies from your preferences.") + install_features = lib_name + else: + print( + f"WARN: GPU library {lib} not found. " + "Skipping GPU-specific dependencies.\n" + "WARN: Please delete gpu_lib.txt and restart " + "if you want to change your selection." + ) + return + if install_features == "amd": # Exit if using AMD and Windows if platform.system() == "Windows": @@ -111,6 +115,11 @@ def add_start_args(parser: argparse.ArgumentParser): action="store_true", help="Don't upgrade wheel dependencies (exllamav2, torch)", ) + start_group.add_argument( + "--gpu-lib", + type=str, + help="Select GPU library. Options: cu121, cu118, amd", + ) if __name__ == "__main__": @@ -124,7 +133,7 @@ if __name__ == "__main__": if args.ignore_upgrade: print("Ignoring pip dependency upgrade due to user request.") else: - install_features = None if args.nowheel else get_install_features() + install_features = None if args.nowheel else get_install_features(args.gpu_lib) features = f"[{install_features}]" if install_features else "" # pip install .[features]