Start: Update to use pyproject
Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
b1ca435695
commit
72b08624a3
1 changed files with 11 additions and 10 deletions
21
start.py
21
start.py
|
|
@ -8,15 +8,15 @@ import subprocess
|
|||
from common.args import convert_args_to_dict, init_argparser
|
||||
|
||||
|
||||
def get_requirements_file():
|
||||
def get_install_features():
|
||||
"""Fetches the appropriate requirements file depending on the GPU"""
|
||||
requirements_name = "requirements-nowheel"
|
||||
install_features = None
|
||||
ROCM_PATH = os.environ.get("ROCM_PATH")
|
||||
CUDA_PATH = os.environ.get("CUDA_PATH")
|
||||
|
||||
# TODO: Check if the user has an AMD gpu on windows
|
||||
if ROCM_PATH:
|
||||
requirements_name = "requirements-amd"
|
||||
install_features = "amd"
|
||||
|
||||
# Also override env vars for ROCm support on non-supported GPUs
|
||||
os.environ["ROCM_PATH"] = "/opt/rocm"
|
||||
|
|
@ -25,11 +25,11 @@ def get_requirements_file():
|
|||
elif CUDA_PATH:
|
||||
cuda_version = pathlib.Path(CUDA_PATH).name
|
||||
if "12" in cuda_version:
|
||||
requirements_name = "requirements"
|
||||
install_features = "cu121"
|
||||
elif "11" in cuda_version:
|
||||
requirements_name = "requirements-cu118"
|
||||
install_features = "cu118"
|
||||
|
||||
return requirements_name
|
||||
return install_features
|
||||
|
||||
|
||||
def add_start_args(parser: argparse.ArgumentParser):
|
||||
|
|
@ -60,10 +60,11 @@ if __name__ == "__main__":
|
|||
if args.ignore_upgrade:
|
||||
print("Ignoring pip dependency upgrade due to user request.")
|
||||
else:
|
||||
requirements_file = (
|
||||
"requirements-nowheel" if args.nowheel else get_requirements_file()
|
||||
)
|
||||
subprocess.run(["pip", "install", "-U", "-r", f"{requirements_file}.txt"])
|
||||
install_features = None if args.nowheel else get_install_features()
|
||||
features = f"[{install_features}]" if install_features else ""
|
||||
|
||||
# pip install .[features]
|
||||
subprocess.run(["pip", "install", "-U", f".{features}"])
|
||||
|
||||
# Import entrypoint after installing all requirements
|
||||
from main import entrypoint
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue