Model: Fix flash-attn checks

If flash attention is already turned off by ExllamaV2 itself, don't
try to create a paged generator. Also condense all the redundant
logic into one if statement.

Also check arch_compat_overrides to see if flash attention should
be disabled for a model arch (e.g. Gemma 2).

Signed-off-by: kingbri <bdashore3@proton.me>
commit 773639ea89
parent 27d2d5f3d2
Author: kingbri <bdashore3@proton.me>
Date: 2024-07-06 20:57:10 -04:00

2 changed files with 30 additions and 8 deletions
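
The two model.py hunks below work together: arch_compat_overrides() can flip config.no_flash_attn for architectures that exl2 cannot run with flash attention (e.g. Gemma 2), and the condensed check later in the load path picks that flag up. A minimal sketch of the intended flow, assuming exllamav2's ExLlamaV2Config API as it appears in the diffs (the model path is a placeholder):

from exllamav2 import ExLlamaV2Config

config = ExLlamaV2Config()
config.model_dir = "/path/to/model"  # placeholder path
config.prepare()

# May set config.no_flash_attn for incompatible archs (e.g. Gemma 2)
config.arch_compat_overrides()

if config.no_flash_attn:
    # Compatibility mode: no paged generator, single-batch only
    paged = False
    max_batch_size = 1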

backends/exllamav2/model.py

@@ -5,7 +5,6 @@ import gc
 import math
 import pathlib
 import traceback
-from backends.exllamav2.utils import hardware_supports_flash_attn, supports_paged_attn
 import torch
 import uuid
 from exllamav2 import (
@@ -28,6 +27,11 @@ from loguru import logger
 from typing import List, Optional, Union
 from backends.exllamav2.grammar import ExLlamaV2Grammar
+from backends.exllamav2.utils import (
+    exllama_disabled_flash_attn,
+    hardware_supports_flash_attn,
+    supports_paged_attn,
+)
 from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
     log_generation_params,
@@ -173,6 +177,9 @@ class ExllamaV2Container:
         self.config.prepare()
 
+        # Check if the model arch is compatible with various exl2 features
+        self.config.arch_compat_overrides()
+
         # Then override the base_seq_len if present
         override_base_seq_len = kwargs.get("override_base_seq_len")
         if override_base_seq_len:
@@ -200,13 +207,13 @@
         # Enable fasttensors loading if present
         self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
 
-        # Check whether the user's configuration supports paged attention
-        if not hardware_supports_flash_attn(gpu_device_list):
-            self.config.no_flash_attn = True
-            self.paged = False
-            self.max_batch_size = 1
-            torch.backends.cuda.enable_flash_sdp(False)
-        elif not supports_paged_attn():
+        # Check whether the user's configuration supports flash/paged attention
+        # Also check if exl2 has disabled flash attention
+        if (
+            exllama_disabled_flash_attn(self.config.no_flash_attn)
+            or not hardware_supports_flash_attn(gpu_device_list)
+            or not supports_paged_attn()
+        ):
             self.config.no_flash_attn = True
             self.paged = False
             self.max_batch_size = 1
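
A note on the condition order above: exllama_disabled_flash_attn only reads a flag that exl2 has already set, so placing it first short-circuits past the costlier hardware and library probes. The collapsed logic is equivalent to the sketch below, where should_fall_back is a hypothetical wrapper name:

def should_fall_back(config, gpu_device_list) -> bool:
    # Cheap flag read: exl2 itself already disabled flash attention
    if exllama_disabled_flash_attn(config.no_flash_attn):
        return True
    # Probe the GPUs: every visible device must support flash attention
    if not hardware_supports_flash_attn(gpu_device_list):
        return True
    # Probe the installed flash-attn package for paged attention support
    return not supports_paged_attn()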

backends/exllamav2/utils.py

@@ -94,3 +94,18 @@ def supports_paged_attn():
         return False
     else:
         return True
+
+
+def exllama_disabled_flash_attn(no_flash_attn: bool):
+    unsupported_message = (
+        "ExllamaV2 has disabled Flash Attention. \n"
+        "Please see the above logs for warnings/errors. \n"
+        "Switching to compatibility mode. \n"
+        "This disables parallel batching "
+        "and features that rely on it (ex. CFG). \n"
+    )
+
+    if no_flash_attn:
+        logger.warning(unsupported_message)
+
+    return no_flash_attn
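
Because the helper returns the flag it was handed, it drops straight into a boolean expression while emitting the compatibility-mode warning as a side effect, which is exactly how the condensed check in model.py uses it. A minimal usage sketch (the config object here is hypothetical):

from backends.exllamav2.utils import exllama_disabled_flash_attn

# no_flash_attn may be set by exl2 itself or by arch_compat_overrides()
if exllama_disabled_flash_attn(config.no_flash_attn):
    paged = False       # don't create a paged generator
    max_batch_size = 1  # parallel batching is unavailable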