Model: Clarify warning and device check on load
FA2 v2.5.7 and up is not supported below ampere and on AMD GPUs. Clarify the error message and explain what happens as a result. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
47582c2440
commit
a46ee62d03
1 changed files with 15 additions and 9 deletions
|
|
@ -400,21 +400,27 @@ class ExllamaV2Container:
|
|||
async for value in iterate_in_threadpool(model_load_generator):
|
||||
yield value
|
||||
|
||||
# Disable paged mode if the user's min GPU is supported (ampere and above)
|
||||
# Disable paged mode if the user's min GPU isn't supported (ampere and up)
|
||||
device_list = {
|
||||
module.device_idx
|
||||
for module in self.model.modules
|
||||
if module.device_idx >= 0
|
||||
}
|
||||
min_compute_capability = min(
|
||||
set(
|
||||
[
|
||||
torch.cuda.get_device_capability(device=module.device_idx)[0]
|
||||
for module in self.model.modules
|
||||
if module.device_idx >= 0
|
||||
]
|
||||
)
|
||||
torch.cuda.get_device_capability(device=device)[0]
|
||||
for device in device_list
|
||||
)
|
||||
|
||||
# Compute capability < 8 is not supported by FA2
|
||||
# AMD is also unsupported until ROCm updates its FA2 fork
|
||||
if torch.version.hip or min_compute_capability < 8:
|
||||
logger.warning(
|
||||
"An unsupported GPU is found in this configuration. "
|
||||
"Switching to compatibility mode. This disables parallel batching."
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
"To disable compatability mode, all GPUs must be ampere "
|
||||
"(30 series) or newer. AMD GPUs are not supported."
|
||||
)
|
||||
self.paged = False
|
||||
self.max_batch_size = 1
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue