Common: Add hardware file

Removed from a commit as well.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-05-01 22:39:32 -04:00
parent eca403a0e4
commit 59d081fe83

20
common/hardware.py Normal file
View file

@ -0,0 +1,20 @@
import torch
def hardware_supports_flash_attn(gpu_device_list: list[int]):
"""
Check whether all GPUs in list support FA2
Compute capability < 8 is not supported by FA2
AMD is also unsupported until ROCm updates its FA2 fork
"""
min_compute_capability = min(
torch.cuda.get_device_capability(device=device_idx)[0]
for device_idx in gpu_device_list
)
if torch.version.hip or min_compute_capability < 8:
return False
else:
return True