Model: Fix flash-attn checks
If flash attention is already turned off by exllamaV2 itself, don't try creating a paged generator. Also condense all the redundant logic into one if statement. Also check arch_compat_overrides to see if flash attention should be disabled for a model arch (ex. Gemma 2) Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
27d2d5f3d2
commit
773639ea89
2 changed files with 30 additions and 8 deletions
|
|
@ -5,7 +5,6 @@ import gc
|
|||
import math
|
||||
import pathlib
|
||||
import traceback
|
||||
from backends.exllamav2.utils import hardware_supports_flash_attn, supports_paged_attn
|
||||
import torch
|
||||
import uuid
|
||||
from exllamav2 import (
|
||||
|
|
@ -28,6 +27,11 @@ from loguru import logger
|
|||
from typing import List, Optional, Union
|
||||
|
||||
from backends.exllamav2.grammar import ExLlamaV2Grammar
|
||||
from backends.exllamav2.utils import (
|
||||
exllama_disabled_flash_attn,
|
||||
hardware_supports_flash_attn,
|
||||
supports_paged_attn,
|
||||
)
|
||||
from common.concurrency import iterate_in_threadpool
|
||||
from common.gen_logging import (
|
||||
log_generation_params,
|
||||
|
|
@ -173,6 +177,9 @@ class ExllamaV2Container:
|
|||
|
||||
self.config.prepare()
|
||||
|
||||
# Check if the model arch is compatible with various exl2 features
|
||||
self.config.arch_compat_overrides()
|
||||
|
||||
# Then override the base_seq_len if present
|
||||
override_base_seq_len = kwargs.get("override_base_seq_len")
|
||||
if override_base_seq_len:
|
||||
|
|
@ -200,13 +207,13 @@ class ExllamaV2Container:
|
|||
# Enable fasttensors loading if present
|
||||
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
|
||||
|
||||
# Check whether the user's configuration supports paged attention
|
||||
if not hardware_supports_flash_attn(gpu_device_list):
|
||||
self.config.no_flash_attn = True
|
||||
self.paged = False
|
||||
self.max_batch_size = 1
|
||||
torch.backends.cuda.enable_flash_sdp(False)
|
||||
elif not supports_paged_attn():
|
||||
# Check whether the user's configuration supports flash/paged attention
|
||||
# Also check if exl2 has disabled flash attention
|
||||
if (
|
||||
exllama_disabled_flash_attn(self.config.no_flash_attn)
|
||||
or not hardware_supports_flash_attn(gpu_device_list)
|
||||
or not supports_paged_attn()
|
||||
):
|
||||
self.config.no_flash_attn = True
|
||||
self.paged = False
|
||||
self.max_batch_size = 1
|
||||
|
|
|
|||
|
|
@ -94,3 +94,18 @@ def supports_paged_attn():
|
|||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def exllama_disabled_flash_attn(no_flash_attn: bool):
|
||||
unsupported_message = (
|
||||
"ExllamaV2 has disabled Flash Attention. \n"
|
||||
"Please see the above logs for warnings/errors. \n"
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
)
|
||||
|
||||
if no_flash_attn:
|
||||
logger.warning(unsupported_message)
|
||||
|
||||
return no_flash_attn
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue