Revision to paged attention checks (#133)
* Model: Clean up paged attention checks
* Model: Move cache_size checks after paged attn checks. Cache size is only relevant in paged mode.
* Model: Fix no_flash_attention
* Model: Remove no_flash_attention. The ability to use flash attention is auto-detected, so this flag is unneeded. Uninstall flash attention to disable it on supported hardware.
This commit is contained in:
parent 55d979b7a5
commit 156b74f3f0
3 changed files with 99 additions and 94 deletions
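In short, the checks now run in this order: verify that every visible GPU can run FA2, verify that the installed flash-attn is new enough for paged attention, and only then size the k/v cache. Below is a condensed sketch of that flow, not the actual method body: resolve_paged_and_cache and requested_cache_size are made-up names for illustration, while hardware_supports_flash_attn and supports_paged_attn are the helpers added in backends/exllamav2/utils.py further down.

# Sketch only: condenses the ordering this commit introduces; it is not the
# loader method from the diff. Helper imports match the new backends/exllamav2/utils.py.
from backends.exllamav2.utils import hardware_supports_flash_attn, supports_paged_attn


def resolve_paged_and_cache(config, gpu_device_list, requested_cache_size=None):
    """Hypothetical helper: paged-attention gates first, cache_size second."""

    paged = True
    max_batch_size = None  # left at the loader's default while paged

    # Gate 1: every GPU must support FA2 (ampere or newer, no ROCm)
    if not hardware_supports_flash_attn(gpu_device_list):
        paged = False
        max_batch_size = 1
    # Gate 2: flash-attn >= 2.5.7 must be installed for paged attention
    elif not supports_paged_attn():
        paged = False
        max_batch_size = 1

    # cache_size only matters once paged mode survived both gates
    if paged:
        cache_size = requested_cache_size or config.max_seq_len
        # Never let the cache drop below the desired context length
        cache_size = max(cache_size, config.max_seq_len)
        # Enforce a multiple of 256, rounding up (overestimate)
        if cache_size % 256 != 0:
            cache_size = 256 * (cache_size // 256 + 1)
    else:
        cache_size = config.max_seq_len

    return paged, max_batch_size, cache_size

The compatibility-mode warnings and the CFG cache-size advisory are omitted from this sketch; the diff below shows the full messages.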
@@ -5,6 +5,7 @@ import gc
import math
import pathlib
import traceback
from backends.exllamav2.utils import hardware_supports_flash_attn, supports_paged_attn
import torch
import uuid
from exllamav2 import (

@@ -196,112 +197,83 @@ class ExllamaV2Container:
kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
)

# Set k/v cache size
cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)

if cache_size < self.config.max_seq_len:
logger.warning(
f"The given cache_size ({cache_size}) is smaller than the "
"desired context length.\n"
"Overriding cache_size to max_seq_len. "
)

cache_size = self.config.max_seq_len

# Enforce a multiple of 256 for cache size
# Overestimate to ensure that the cache isn't below max_seq_len
cache_remainder = cache_size % 256
if cache_remainder != 0:
rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))

logger.warning(
f"The given cache size ({cache_size}) is "
"not a multiple of 256.\n"
"Overriding cache_size with an overestimated value of "
f"{rounded_cache_size} tokens."
)

cache_size = rounded_cache_size

# Warn user if cache size may be inadequate for CFG
if cache_size < 2 * self.config.max_seq_len:
logger.warning(
f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
"and may be too small for requests using CFG. \n"
"Ignore this warning if you do not plan on using CFG."
)

self.cache_size = cache_size

# Enable fasttensors loading if present
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)

# Check whether the user's configuration supports paged attention
if self.config.no_flash_attn:
if not hardware_supports_flash_attn(gpu_device_list):
logger.warning(
"Flash attention is disabled via config. "
"An unsupported GPU is found in this configuration. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG)."
"and features that rely on it (ex. CFG). \n"
"To disable compatability mode, all GPUs must be ampere "
"(30 series) or newer. AMD GPUs are not supported."
)
self.paged = False
self.max_batch_size = 1
elif not supports_paged_attn():
logger.warning(
"Flash attention version >=2.5.7 "
"is required to use paged attention. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"Please upgrade your environment by running a start script "
"(start.bat or start.sh)\n\n"
"Or you can manually run a requirements update "
"using the following command:\n\n"
"For CUDA 12.1:\n"
"pip install --upgrade .[cu121]\n\n"
"For CUDA 11.8:\n"
"pip install --upgrade .[cu118]\n\n"
"NOTE: Windows users must use CUDA 12.x to use flash-attn."
)
self.paged = False
self.max_batch_size = 1
else:
try:
# Disable paged mode if the user's min GPU isn't supported (ampere+)
min_compute_capability = min(
torch.cuda.get_device_capability(device=device_idx)[0]
for device_idx in gpu_device_list
)

# Compute capability < 8 is not supported by FA2
# AMD is also unsupported until ROCm updates its FA2 fork
if torch.version.hip or min_compute_capability < 8:
logger.warning(
"An unsupported GPU is found in this configuration. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"To disable compatability mode, all GPUs must be ampere "
"(30 series) or newer. AMD GPUs are not supported."
)
self.config.no_flash_attn = True
self.paged = False
self.max_batch_size = 1
else:
import flash_attn
# Set k/v cache size
# cache_size is only relevant when paged mode is enabled
if self.paged:
cache_size = unwrap(kwargs.get("cache_size"), self.config.max_seq_len)

flash_attn_ver = [
int(t) for t in flash_attn.__version__.split(".") if t.isdigit()
]

# Disable paged mode if the user's flash attention version < 2.5.7
if flash_attn_ver < [2, 5, 7]:
logger.warning(
"Flash attention version is older than 2.5.7 "
"which is required for paged attention. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"Please run start.bat or start.sh to update. \n"
"NOTE: Windows users must select CUDA 12.x to use FA2."
)
self.paged = False
self.max_batch_size = 1

except ModuleNotFoundError:
# Disable paged mode if flash attention is not installed
if cache_size < self.config.max_seq_len:
logger.warning(
"Flash attention is not installed. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG)."
"Please run start.bat or start.sh to install. \n"
"NOTE: Windows users must select CUDA 12.x to use FA2."
f"The given cache_size ({cache_size}) is smaller than the "
"desired context length.\n"
"Overriding cache_size to max_seq_len. "
)
self.config.no_flash_attn = True
self.paged = False
self.max_batch_size = 1

cache_size = self.config.max_seq_len

# Enforce a multiple of 256 for cache size
# Overestimate to ensure that the cache isn't below max_seq_len
cache_remainder = cache_size % 256
if cache_remainder != 0:
rounded_cache_size = int(
256 * ((cache_size - cache_remainder) / 256 + 1)
)

logger.warning(
f"The given cache size ({cache_size}) is "
"not a multiple of 256.\n"
"Overriding cache_size with an overestimated value of "
f"{rounded_cache_size} tokens."
)

cache_size = rounded_cache_size

# Warn user if cache size may be inadequate for CFG
if cache_size < 2 * self.config.max_seq_len:
logger.warning(
f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
"and may be too small for requests using CFG. \n"
"Ignore this warning if you do not plan on using CFG."
)

self.cache_size = cache_size
else:
self.cache_size = self.config.max_seq_len

# Try to set prompt template
self.prompt_template = self.find_prompt_template(

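For concreteness, the multiple-of-256 rule in the cache_size block above always rounds up rather than to the nearest multiple, so the cache can never end up smaller than max_seq_len after rounding. A standalone check of the same formula (the sample sizes are arbitrary):

# Same overestimating round-to-256 as the cache_size block above.
def round_up_to_256(cache_size: int) -> int:
    cache_remainder = cache_size % 256
    if cache_remainder != 0:
        # Drop the remainder, then add one full 256-token block
        cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
    return cache_size


assert round_up_to_256(4096) == 4096    # already a multiple of 256, unchanged
assert round_up_to_256(4097) == 4352    # rounded up, never down
assert round_up_to_256(16500) == 16640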
@@ -1,6 +1,7 @@
from packaging import version
from importlib.metadata import version as package_version
from importlib.metadata import PackageNotFoundError, version as package_version
from loguru import logger
import torch


def check_exllama_version():

@@ -26,3 +27,36 @@ def check_exllama_version():
        )
    else:
        logger.info(f"ExllamaV2 version: {current_version}")


def hardware_supports_flash_attn(gpu_device_list: list[int]):
    """
    Check whether all GPUs in list support FA2

    Compute capability < 8 is not supported by FA2
    AMD is also unsupported until ROCm updates its FA2 fork
    """

    min_compute_capability = min(
        torch.cuda.get_device_capability(device=device_idx)[0]
        for device_idx in gpu_device_list
    )
    if torch.version.hip or min_compute_capability < 8:
        return False
    else:
        return True


def supports_paged_attn():
    """Check whether the user's flash-attn version supports paged mode"""

    required_version = version.parse("2.5.7")
    try:
        current_version = version.parse(package_version("flash-attn").split("+")[0])
    except PackageNotFoundError:
        return False

    if current_version < required_version:
        return False
    else:
        return True

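supports_paged_attn compares the installed flash-attn against 2.5.7 using packaging.version, splitting on "+" so a local build suffix (for example a CUDA/torch tag) does not interfere with the comparison. A quick standalone illustration; the version strings here are made up:

from packaging import version

required = version.parse("2.5.7")

# Splitting on "+" keeps only the public version before parsing.
for raw in ["2.5.6", "2.5.7", "2.5.8+cu122torch2.2"]:
    current = version.parse(raw.split("+")[0])
    print(raw, "->", current >= required)
# Prints: 2.5.6 -> False, 2.5.7 -> True, 2.5.8+cu122torch2.2 -> True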
@@ -94,7 +94,6 @@ class ModelLoadRequest(BaseModel):
        default=None,
        examples=[1.0],
    )
    no_flash_attention: Optional[bool] = False
    # low_mem: Optional[bool] = False
    cache_mode: Optional[str] = "FP16"
    chunk_size: Optional[int] = 2048
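With no_flash_attention gone from ModelLoadRequest, clients no longer toggle flash attention per load request; it is detected from the installed flash-attn package and the GPUs, and the only way to disable it on supported hardware is to uninstall flash-attn. A hypothetical payload after this change; only cache_mode and chunk_size come from the hunk above, and the model name field is an assumed placeholder:

# Hypothetical load payload; field names other than cache_mode and chunk_size
# are illustrative, not taken from this diff.
payload = {
    "name": "my-exl2-model",          # assumed field name
    "cache_mode": "FP16",
    "chunk_size": 2048,
    # "no_flash_attention": True,     # removed by this commit; support is auto-detected
}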