Model: Fix paged and FA2 checks
If a user is using a manual GPU split, check compute capability only on those GPUs. Autosplit assumes that all GPUs will be used.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent 9fbbc5afca
commit 094c7b1734
1 changed file with 28 additions and 25 deletions
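The core of the change: gate FA2 and paged mode on the weakest GPU that will actually be used, rather than on every visible device. Below is a minimal standalone sketch of that check, assuming a hypothetical helper name (supports_paged_mode) and none of the repository's surrounding scaffolding:

import torch

def supports_paged_mode(gpu_device_list: list[int]) -> bool:
    """Hypothetical helper: True only if every active device can run FA2."""
    # ROCm builds are unsupported until the FA2 fork catches up
    if torch.version.hip:
        return False
    # get_device_capability returns (major, minor); FA2 needs major >= 8 (Ampere)
    min_compute_capability = min(
        torch.cuda.get_device_capability(device=device_idx)[0]
        for device_idx in gpu_device_list
    )
    return min_compute_capability >= 8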
@@ -128,6 +128,7 @@ class ExllamaV2Container:
         # Turn off GPU split if the user is using 1 GPU
         gpu_count = torch.cuda.device_count()
         gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
+        gpu_device_list = list(range(0, gpu_count))
 
         if gpu_count > 1 and gpu_split_auto:
             # Auto GPU split parameters
@@ -141,6 +142,12 @@ class ExllamaV2Container:
             # Manual GPU split
             self.gpu_split = kwargs.get("gpu_split")
             self.gpu_split_auto = False
+
+            gpu_device_list = [
+                device_idx
+                for device_idx, memory in enumerate(self.gpu_split)
+                if memory > 0
+            ]
         else:
             # One GPU setup
             self.gpu_split_auto = False
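As a quick sanity check of the new filtering, here is how the comprehension above treats a hypothetical manual split (the values are illustrative, not from the repository):

# Hypothetical manual split across three visible GPUs (values in GB)
gpu_split = [0, 20.5, 24.0]

# Same comprehension as the diff: keep only devices with memory assigned
gpu_device_list = [
    device_idx for device_idx, memory in enumerate(gpu_split) if memory > 0
]
assert gpu_device_list == [1, 2]  # GPU 0 never factors into the capability check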
@@ -185,6 +192,27 @@ class ExllamaV2Container:
         # Enable fasttensors loading if present
         self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
 
+        # Disable paged mode if the user's min GPU isn't supported (ampere and up)
+        min_compute_capability = min(
+            torch.cuda.get_device_capability(device=device_idx)[0]
+            for device_idx in gpu_device_list
+        )
+
+        # Compute capability < 8 is not supported by FA2
+        # AMD is also unsupported until ROCm updates its FA2 fork
+        if torch.version.hip or min_compute_capability < 8:
+            logger.warning(
+                "An unsupported GPU is found in this configuration. "
+                "Switching to compatibility mode. \n"
+                "This disables parallel batching "
+                "and features that rely on it (ex. CFG). \n"
+                "To disable compatibility mode, all GPUs must be ampere "
+                "(30 series) or newer. AMD GPUs are not supported."
+            )
+            self.config.no_flash_attn = True
+            self.paged = False
+            self.max_batch_size = 1
+
         # Try to set prompt template
         self.prompt_template = self.find_prompt_template(
             kwargs.get("prompt_template"), model_directory
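Note that torch.cuda.get_device_capability only queries device properties, so the check can run at config time, before any weights load; that is what lets it move out of the load path in the hunk below. A small sketch of what it reports (the printed output is illustrative):

import torch

# Capability is available per device index without loading a model
for device_idx in range(torch.cuda.device_count()):
    major, minor = torch.cuda.get_device_capability(device=device_idx)
    name = torch.cuda.get_device_name(device_idx)
    print(f"{device_idx}: {name} -> sm_{major}{minor}")  # e.g. "0: ... -> sm_86"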
@@ -397,31 +425,6 @@ class ExllamaV2Container:
         async for value in iterate_in_threadpool(model_load_generator):
             yield value
 
-        # Disable paged mode if the user's min GPU isn't supported (ampere and up)
-        device_list = {
-            module.device_idx
-            for module in self.model.modules
-            if module.device_idx >= 0
-        }
-        min_compute_capability = min(
-            torch.cuda.get_device_capability(device=device)[0]
-            for device in device_list
-        )
-
-        # Compute capability < 8 is not supported by FA2
-        # AMD is also unsupported until ROCm updates its FA2 fork
-        if torch.version.hip or min_compute_capability < 8:
-            logger.warning(
-                "An unsupported GPU is found in this configuration. "
-                "Switching to compatibility mode. \n"
-                "This disables parallel batching "
-                "and features that rely on it (ex. CFG). \n"
-                "To disable compatibility mode, all GPUs must be ampere "
-                "(30 series) or newer. AMD GPUs are not supported."
-            )
-            self.paged = False
-            self.max_batch_size = 1
-
         # Create async generator
         self.generator = ExLlamaV2DynamicGeneratorAsync(
             model=self.model,