Fix disabling flash attention in exl2 config (#136)
* Model: Fix disabling flash attention in exl2 config
* Model: Pass no_flash_attn to draft config
* Model: Force torch flash SDP off in compatibility mode
parent 156b74f3f0
commit 85387d97ad
1 changed file with 5 additions and 0 deletions
@@ -210,8 +210,10 @@ class ExllamaV2Container:
                 "To disable compatability mode, all GPUs must be ampere "
                 "(30 series) or newer. AMD GPUs are not supported."
             )
+            self.config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
+            torch.backends.cuda.enable_flash_sdp(False)
         elif not supports_paged_attn():
             logger.warning(
                 "Flash attention version >=2.5.7 "
@@ -229,8 +231,10 @@ class ExllamaV2Container:
                 "pip install --upgrade .[cu118]\n\n"
                 "NOTE: Windows users must use CUDA 12.x to use flash-attn."
             )
+            self.config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
+            torch.backends.cuda.enable_flash_sdp(False)
 
         # Set k/v cache size
         # cache_size is only relevant when paged mode is enabled
@@ -331,6 +335,7 @@ class ExllamaV2Container:
 
         if enable_draft:
             self.draft_config = ExLlamaV2Config()
+            self.draft_config.no_flash_attn = self.config.no_flash_attn
             draft_model_path = pathlib.Path(
                 unwrap(draft_args.get("draft_model_dir"), "models")
             )
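For reference, below is a minimal standalone sketch of what the three additions do together: force flash attention off on the main exl2 config, mirror that setting onto the draft config, and stop torch's scaled-dot-product attention from selecting its flash kernel. This is not code from the repository; apply_compatibility_mode and the SimpleNamespace stand-ins for ExLlamaV2Config are illustrative names, and only torch is assumed to be installed.

# Minimal sketch (illustrative, not the repository's code): what the commit's
# three additions accomplish when the compatibility fallback is hit.
import torch


def apply_compatibility_mode(config, draft_config=None):
    """Disable flash attention on the exl2 config(s) and in torch's SDP backends."""
    # Equivalent of `self.config.no_flash_attn = True`: tell exllamav2 not to
    # use flash-attn kernels for this model.
    config.no_flash_attn = True

    # Equivalent of passing no_flash_attn to the draft config: the speculative
    # decoding (draft) model must match the main model's setting.
    if draft_config is not None:
        draft_config.no_flash_attn = config.no_flash_attn

    # Equivalent of `torch.backends.cuda.enable_flash_sdp(False)`: keep
    # torch.nn.functional.scaled_dot_product_attention from picking the flash
    # kernel, so it falls back to the memory-efficient / math implementations.
    torch.backends.cuda.enable_flash_sdp(False)


if __name__ == "__main__":
    # Stand-ins for ExLlamaV2Config objects, just to show the attribute flips.
    from types import SimpleNamespace

    main_cfg = SimpleNamespace(no_flash_attn=False)
    draft_cfg = SimpleNamespace(no_flash_attn=False)
    apply_compatibility_mode(main_cfg, draft_cfg)
    print(main_cfg.no_flash_attn, draft_cfg.no_flash_attn)  # True True

Copying the flag onto the draft config matters because the draft model is built from its own ExLlamaV2Config and would not otherwise inherit the compatibility fallback applied to the main model.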