diff --git a/tests/model_test.py b/tests/model_test.py index 6fc7825..ee5375d 100644 --- a/tests/model_test.py +++ b/tests/model_test.py @@ -9,7 +9,8 @@ for (module, modules) in loader: print(module, modules) generator = container.generate_gen("Once upon a tim", token_healing = True) -for g in generator: print(g, end = "") +for g in generator: + print(g, end = "") container.unload() del container diff --git a/tests/wheel_test.py b/tests/wheel_test.py index 9fc5a14..150d66e 100644 --- a/tests/wheel_test.py +++ b/tests/wheel_test.py @@ -1,47 +1,41 @@ -import traceback from importlib.metadata import version +from importlib.util import find_spec successful_packages = [] errored_packages = [] -try: - import flash_attn +if find_spec("flash_attn") is not None: print(f"Flash attention on version {version('flash_attn')} successfully imported") successful_packages.append("flash_attn") -except: - print("Flash attention could not be loaded because:") - print(traceback.format_exc()) +else: + print("Flash attention 2 is not found in your environment.") errored_packages.append("flash_attn") -try: - import exllamav2 +if find_spec("exllamav2") is not None: print(f"Exllamav2 on version {version('exllamav2')} successfully imported") successful_packages.append("exllamav2") -except: - print("Exllamav2 could not be loaded because:") - print(traceback.format_exc()) +else: + print("Exllamav2 is not found in your environment.") errored_packages.append("exllamav2") -try: - import torch +if find_spec("torch") is not None: print(f"Torch on version {version('torch')} successfully imported") successful_packages.append("torch") -except: - print("Torch could not be loaded because:") - print(traceback.format_exc()) +else: + print("Torch is not found in your environment.") errored_packages.append("torch") -try: - import fastchat +if find_spec("fastchat") is not None: print(f"Fastchat on version {version('fschat')} successfully imported") successful_packages.append("fastchat") -except: - print("Fastchat is only needed for chat completions with message arrays. Ignore this error if this isn't your usecase.") - print("Fastchat could not be loaded because:") - print(traceback.format_exc()) +else: + print("Fastchat is not found in your environment. It isn't needed unless you're using chat completions with message arrays.") errored_packages.append("fastchat") print( f"\nSuccessful imports: {', '.join(successful_packages)}", f"\nErrored imports: {''.join(errored_packages)}" ) + +if len(errored_packages) > 0: + print("\nIf packages are installed, but not found on this test, please check the wheel versions for the correct python version and CUDA version (if applicable).")