Model: Fix paged and FA2 checks

If a user sets a manual GPU split, check the compute capability of only
the GPUs that split actually uses. Autosplit assumes that all GPUs will be used.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date: 2024-05-26 11:29:31 -04:00
Parent: 9fbbc5afca
Commit: 094c7b1734
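
For context, here is a minimal standalone sketch of the device-selection logic this commit introduces. The helper name min_used_compute_capability and the list-of-numbers shape of gpu_split are illustrative assumptions, not part of tabbyAPI; torch.cuda.device_count and torch.cuda.get_device_capability are the real PyTorch calls the diff relies on.

import torch

# Hypothetical helper mirroring the commit's logic: with a manual
# gpu_split, only devices that were allocated memory are checked;
# with autosplit, every visible GPU is assumed to be in use.
def min_used_compute_capability(gpu_split=None):
    if gpu_split:
        device_list = [
            idx for idx, memory in enumerate(gpu_split) if memory > 0
        ]
    else:
        device_list = list(range(torch.cuda.device_count()))

    # get_device_capability returns (major, minor); FA2 needs major >= 8
    return min(
        torch.cuda.get_device_capability(device=idx)[0]
        for idx in device_list
    )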

@@ -128,6 +128,7 @@ class ExllamaV2Container:
         # Turn off GPU split if the user is using 1 GPU
         gpu_count = torch.cuda.device_count()
         gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
+        gpu_device_list = list(range(0, gpu_count))
 
         if gpu_count > 1 and gpu_split_auto:
             # Auto GPU split parameters
@@ -141,6 +142,12 @@ class ExllamaV2Container:
             # Manual GPU split
             self.gpu_split = kwargs.get("gpu_split")
             self.gpu_split_auto = False
+
+            gpu_device_list = [
+                device_idx
+                for device_idx, memory in enumerate(self.gpu_split)
+                if memory > 0
+            ]
         else:
             # One GPU setup
             self.gpu_split_auto = False
@@ -185,6 +192,27 @@ class ExllamaV2Container:
         # Enable fasttensors loading if present
         self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
 
+        # Disable paged mode if the user's min GPU isn't supported (Ampere and up)
+        min_compute_capability = min(
+            torch.cuda.get_device_capability(device=device_idx)[0]
+            for device_idx in gpu_device_list
+        )
+
+        # Compute capability < 8 is not supported by FA2
+        # AMD is also unsupported until ROCm updates its FA2 fork
+        if torch.version.hip or min_compute_capability < 8:
+            logger.warning(
+                "An unsupported GPU was found in this configuration. "
+                "Switching to compatibility mode.\n"
+                "This disables parallel batching "
+                "and features that rely on it (e.g. CFG).\n"
+                "To disable compatibility mode, all GPUs must be Ampere "
+                "(30 series) or newer. AMD GPUs are not supported."
+            )
+
+            self.config.no_flash_attn = True
+            self.paged = False
+            self.max_batch_size = 1
+
         # Try to set prompt template
         self.prompt_template = self.find_prompt_template(
             kwargs.get("prompt_template"), model_directory
@@ -397,31 +425,6 @@ class ExllamaV2Container:
         async for value in iterate_in_threadpool(model_load_generator):
             yield value
 
-        # Disable paged mode if the user's min GPU isn't supported (Ampere and up)
-        device_list = {
-            module.device_idx
-            for module in self.model.modules
-            if module.device_idx >= 0
-        }
-
-        min_compute_capability = min(
-            torch.cuda.get_device_capability(device=device)[0]
-            for device in device_list
-        )
-
-        # Compute capability < 8 is not supported by FA2
-        # AMD is also unsupported until ROCm updates its FA2 fork
-        if torch.version.hip or min_compute_capability < 8:
-            logger.warning(
-                "An unsupported GPU was found in this configuration. "
-                "Switching to compatibility mode.\n"
-                "This disables parallel batching "
-                "and features that rely on it (e.g. CFG).\n"
-                "To disable compatibility mode, all GPUs must be Ampere "
-                "(30 series) or newer. AMD GPUs are not supported."
-            )
-            self.paged = False
-            self.max_batch_size = 1
         # Create async generator
         self.generator = ExLlamaV2DynamicGeneratorAsync(
             model=self.model,
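
As a usage sketch (hypothetical split values, reusing the illustrative min_used_compute_capability helper sketched above), the new check gates compatibility mode roughly like this:

import torch

# Hypothetical manual split: GPUs 0 and 2 get memory, GPU 1 is skipped,
# so only devices 0 and 2 are checked.
min_cc = min_used_compute_capability(gpu_split=[20.0, 0, 24.0])

# torch.version.hip is a version string on ROCm builds and None on CUDA
# builds; FA2 requires an Ampere-or-newer NVIDIA GPU (major version >= 8).
if torch.version.hip or min_cc < 8:
    print("Compatibility mode: flash attention and paged batching disabled")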