diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index b28cfd8..51a5e8e 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -5,7 +5,6 @@ import gc
 import math
 import pathlib
 import traceback
-from backends.exllamav2.utils import hardware_supports_flash_attn, supports_paged_attn
 import torch
 import uuid
 from exllamav2 import (
@@ -28,6 +27,11 @@ from loguru import logger
 from typing import List, Optional, Union
 
 from backends.exllamav2.grammar import ExLlamaV2Grammar
+from backends.exllamav2.utils import (
+    exllama_disabled_flash_attn,
+    hardware_supports_flash_attn,
+    supports_paged_attn,
+)
 from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
     log_generation_params,
@@ -173,6 +177,9 @@ class ExllamaV2Container:
 
         self.config.prepare()
 
+        # Check if the model arch is compatible with various exl2 features
+        self.config.arch_compat_overrides()
+
         # Then override the base_seq_len if present
         override_base_seq_len = kwargs.get("override_base_seq_len")
         if override_base_seq_len:
@@ -200,13 +207,13 @@ class ExllamaV2Container:
         # Enable fasttensors loading if present
         self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
 
-        # Check whether the user's configuration supports paged attention
-        if not hardware_supports_flash_attn(gpu_device_list):
-            self.config.no_flash_attn = True
-            self.paged = False
-            self.max_batch_size = 1
-            torch.backends.cuda.enable_flash_sdp(False)
-        elif not supports_paged_attn():
+        # Check whether the user's configuration supports flash/paged attention
+        # Also check if exl2 has disabled flash attention
+        if (
+            exllama_disabled_flash_attn(self.config.no_flash_attn)
+            or not hardware_supports_flash_attn(gpu_device_list)
+            or not supports_paged_attn()
+        ):
             self.config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py
index 1a89fe2..5b1d567 100644
--- a/backends/exllamav2/utils.py
+++ b/backends/exllamav2/utils.py
@@ -94,3 +94,18 @@ def supports_paged_attn():
         return False
     else:
         return True
+
+
+def exllama_disabled_flash_attn(no_flash_attn: bool):
+    unsupported_message = (
+        "ExllamaV2 has disabled Flash Attention. \n"
+        "Please see the above logs for warnings/errors. \n"
+        "Switching to compatibility mode. \n"
+        "This disables parallel batching "
+        "and features that rely on it (ex. CFG). \n"
+    )
+
+    if no_flash_attn:
+        logger.warning(unsupported_message)
+
+    return no_flash_attn
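
For reference, a minimal standalone sketch of how the consolidated check behaves. The `config` dict, the stubbed probe callables, and `apply_compat_mode` are illustrative stand-ins, not the real `ExLlamaV2Config` or the helpers in `backends/exllamav2/utils.py`:

```python
# Sketch only: stand-ins for the real config object and utils helpers.
import logging

logger = logging.getLogger("compat-sketch")
logging.basicConfig(level=logging.WARNING)


def exllama_disabled_flash_attn(no_flash_attn: bool) -> bool:
    """Warn if ExLlamaV2 already turned flash attention off, then report
    that state so the caller can fall back to compatibility mode."""
    if no_flash_attn:
        logger.warning(
            "ExllamaV2 has disabled Flash Attention. "
            "Switching to compatibility mode (no parallel batching/CFG)."
        )
    return no_flash_attn


def apply_compat_mode(config, gpu_device_list,
                      hardware_supports_flash_attn, supports_paged_attn):
    """Mirror of the single if-block in model.py: any one failing
    condition forces single-batch, non-paged generation."""
    if (
        exllama_disabled_flash_attn(config["no_flash_attn"])
        or not hardware_supports_flash_attn(gpu_device_list)
        or not supports_paged_attn()
    ):
        config["no_flash_attn"] = True
        return {"paged": False, "max_batch_size": 1}
    return {"paged": True, "max_batch_size": None}


# Example: exl2 already disabled flash attention, so compat mode kicks in
# even though the (stubbed) hardware and paged-attention probes pass.
print(apply_compat_mode(
    {"no_flash_attn": True},
    gpu_device_list=[0],
    hardware_supports_flash_attn=lambda gpus: True,
    supports_paged_attn=lambda: True,
))  # -> {'paged': False, 'max_batch_size': 1}
```

Because the conditions are joined with a short-circuiting `or`, the `exllama_disabled_flash_attn` warning fires first and the hardware/paged-attention probes are only consulted when exl2 itself has not already disabled flash attention.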