From 59d081fe83fc6acc23545220d6028da9cf8d7f60 Mon Sep 17 00:00:00 2001 From: kingbri <8082010+kingbri1@users.noreply.github.com> Date: Thu, 1 May 2025 22:39:32 -0400 Subject: [PATCH] Common: Add hardware file Removed from a commit as well. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com> --- common/hardware.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 common/hardware.py diff --git a/common/hardware.py b/common/hardware.py new file mode 100644 index 0000000..10723c5 --- /dev/null +++ b/common/hardware.py @@ -0,0 +1,20 @@ +import torch + + +def hardware_supports_flash_attn(gpu_device_list: list[int]): + """ + Check whether all GPUs in list support FA2 + + Compute capability < 8 is not supported by FA2 + AMD is also unsupported until ROCm updates its FA2 fork + """ + + min_compute_capability = min( + torch.cuda.get_device_capability(device=device_idx)[0] + for device_idx in gpu_device_list + ) + + if torch.version.hip or min_compute_capability < 8: + return False + else: + return True