From 85387d97adab7b9920095a4fc5fba4a4d37ad6eb Mon Sep 17 00:00:00 2001
From: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
Date: Wed, 12 Jun 2024 11:00:46 -0700
Subject: [PATCH] Fix disabling flash attention in exl2 config (#136)

* Model: Fix disabling flash attention in exl2 config

* Model: Pass no_flash_attn to draft config

* Model: Force torch flash SDP off in compatibility mode
---
 backends/exllamav2/model.py | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 22f9239..d4a65c2 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -210,8 +210,10 @@ class ExllamaV2Container:
                 "To disable compatability mode, all GPUs must be ampere "
                 "(30 series) or newer. AMD GPUs are not supported."
             )
+            self.config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
+            torch.backends.cuda.enable_flash_sdp(False)
         elif not supports_paged_attn():
             logger.warning(
                 "Flash attention version >=2.5.7 "
@@ -229,8 +231,10 @@ class ExllamaV2Container:
                 "pip install --upgrade .[cu118]\n\n"
                 "NOTE: Windows users must use CUDA 12.x to use flash-attn."
             )
+            self.config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
+            torch.backends.cuda.enable_flash_sdp(False)
 
         # Set k/v cache size
         # cache_size is only relevant when paged mode is enabled
@@ -331,6 +335,7 @@ class ExllamaV2Container:
 
         if enable_draft:
             self.draft_config = ExLlamaV2Config()
+            self.draft_config.no_flash_attn = self.config.no_flash_attn
             draft_model_path = pathlib.Path(
                 unwrap(draft_args.get("draft_model_dir"), "models")
             )
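
For context, the hunks above amount to the fallback sketched below: when flash
attention cannot be used, the flag is flipped on the ExLlamaV2 config, torch's
flash SDP kernel is disabled, and the same flag is mirrored onto the draft
config. This is an illustrative sketch, not code from the repository; the
helper name apply_flash_attn_fallback and the flash_available argument are
hypothetical, while config.no_flash_attn and
torch.backends.cuda.enable_flash_sdp come from the patch and PyTorch.

    # Illustrative sketch of the fallback applied in the patch above.
    # The helper name and the flash_available argument are hypothetical.
    import torch
    from exllamav2 import ExLlamaV2Config


    def apply_flash_attn_fallback(
        config: ExLlamaV2Config,
        draft_config: ExLlamaV2Config | None,
        flash_available: bool,
    ) -> None:
        """Disable flash attention everywhere when it cannot be used."""
        if flash_available:
            return

        # Tell exllamav2 not to use flash-attn kernels
        config.no_flash_attn = True

        # Keep torch's scaled_dot_product_attention from selecting its
        # flash kernel while running in compatibility mode
        torch.backends.cuda.enable_flash_sdp(False)

        # Propagate the setting to the draft (speculative decoding) model
        if draft_config is not None:
            draft_config.no_flash_attn = config.no_flash_attn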