Start: Migrate options from cu121/118 to cu12
This encapsulates more cuda versions and makes install easier for new users. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
1344726936
commit
30a3cd75cf
6 changed files with 35 additions and 41 deletions
60
start.py
60
start.py
|
|
@ -41,14 +41,13 @@ def get_user_choice(question: str, options_dict: dict):
|
|||
def get_install_features(lib_name: str = None):
|
||||
"""Fetches the appropriate requirements file depending on the GPU"""
|
||||
install_features = None
|
||||
possible_features = ["cu121", "cu118", "amd"]
|
||||
possible_features = ["cu12", "amd"]
|
||||
|
||||
if not lib_name:
|
||||
# Ask the user for the GPU lib
|
||||
gpu_lib_choices = {
|
||||
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu121"},
|
||||
"B": {"pretty": "NVIDIA Cuda 11.8 (Unsupported)", "internal": "cu118"},
|
||||
"C": {"pretty": "AMD", "internal": "amd"},
|
||||
"A": {"pretty": "NVIDIA Cuda 12.x", "internal": "cu12"},
|
||||
"B": {"pretty": "AMD", "internal": "amd"},
|
||||
}
|
||||
user_input = get_user_choice(
|
||||
"Select your GPU. If you don't know, select Cuda 12.x (A)",
|
||||
|
|
@ -79,7 +78,7 @@ def get_install_features(lib_name: str = None):
|
|||
if platform.system() == "Windows":
|
||||
print(
|
||||
"ERROR: TabbyAPI does not support AMD and Windows. "
|
||||
"Please use Linux and ROCm 6.0. Exiting."
|
||||
"Please use Linux and ROCm 6.4. Exiting."
|
||||
)
|
||||
sys.exit(0)
|
||||
|
||||
|
|
@ -139,24 +138,17 @@ def add_start_args(parser: argparse.ArgumentParser):
|
|||
)
|
||||
|
||||
|
||||
def migrate_gpu_lib():
|
||||
gpu_lib_path = pathlib.Path("gpu_lib.txt")
|
||||
def migrate_start_options(start_options: dict):
|
||||
migrated = False
|
||||
|
||||
if not gpu_lib_path.exists():
|
||||
return
|
||||
# Migrate gpu_lib key
|
||||
gpu_lib = start_options.get("gpu_lib")
|
||||
if (gpu_lib == "cu121" or gpu_lib == "cu118"):
|
||||
print("GPU lib key is legacy, migrating to cu12")
|
||||
start_options["gpu_lib"] = "cu12"
|
||||
migrated = True
|
||||
|
||||
print("Migrating gpu_lib.txt to the new start_options.json")
|
||||
with open("gpu_lib.txt", "r") as gpu_lib_file:
|
||||
start_options["gpu_lib"] = gpu_lib_file.readline().strip()
|
||||
start_options["first_run_done"] = True
|
||||
|
||||
# Remove the old file
|
||||
gpu_lib_path.unlink()
|
||||
|
||||
print(
|
||||
"Successfully migrated gpu lib options to start_options. "
|
||||
"The old file has been deleted."
|
||||
)
|
||||
return migrated
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
@ -183,6 +175,7 @@ if __name__ == "__main__":
|
|||
add_start_args(parser)
|
||||
args, _ = parser.parse_known_args()
|
||||
script_ext = "bat" if platform.system() == "Windows" else "sh"
|
||||
do_start_options_write = False
|
||||
|
||||
start_options_path = pathlib.Path("start_options.json")
|
||||
if start_options_path.exists():
|
||||
|
|
@ -190,6 +183,7 @@ if __name__ == "__main__":
|
|||
start_options = json.load(start_options_file)
|
||||
print("Loaded your saved preferences from `start_options.json`")
|
||||
|
||||
do_start_options_write = migrate_start_options(start_options)
|
||||
if start_options.get("first_run_done"):
|
||||
first_run = False
|
||||
else:
|
||||
|
|
@ -198,9 +192,6 @@ if __name__ == "__main__":
|
|||
"Getting things ready..."
|
||||
)
|
||||
|
||||
# Migrate from old setting storage
|
||||
migrate_gpu_lib()
|
||||
|
||||
# Set variables that rely on start options
|
||||
first_run = not start_options.get("first_run_done")
|
||||
|
||||
|
|
@ -240,15 +231,7 @@ if __name__ == "__main__":
|
|||
start_options["first_run_done"] = True
|
||||
|
||||
# Save start options on first run
|
||||
with open("start_options.json", "w") as start_file:
|
||||
start_file.write(json.dumps(start_options))
|
||||
|
||||
print(
|
||||
"Successfully wrote your start script options to "
|
||||
"`start_options.json`. \n"
|
||||
"If something goes wrong, editing or deleting the file "
|
||||
"will reinstall TabbyAPI as a first-time user."
|
||||
)
|
||||
do_start_options_write = True
|
||||
|
||||
if args.update_deps:
|
||||
print(
|
||||
|
|
@ -262,6 +245,17 @@ if __name__ == "__main__":
|
|||
"inside the `update_scripts` folder."
|
||||
)
|
||||
|
||||
if do_start_options_write:
|
||||
with open("start_options.json", "w") as start_file:
|
||||
start_file.write(json.dumps(start_options))
|
||||
|
||||
print(
|
||||
"Successfully wrote your start script options to "
|
||||
"`start_options.json`. \n"
|
||||
"If something goes wrong, editing or deleting the file "
|
||||
"will reinstall TabbyAPI as a first-time user."
|
||||
)
|
||||
|
||||
# Expand the parser if it's not fully created
|
||||
if not has_full_parser:
|
||||
from common.args import init_argparser
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue