Exl3: Add chunk size, cache size, and model info

Use the same algorithm for estimating and adjusting the cache size:
round it up to a multiple of 256 and keep it at or above max seq len.

The same applies to the chunk size.
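
For reference, the adjustment boils down to the sketch below. This is a minimal, standalone rendering of the logic only; the helper names are illustrative, and the commit itself implements this as methods on ExllamaV3Container.

def round_up_to_256(size: int) -> int:
    # Overestimate: round up to the next multiple of 256
    remainder = size % 256
    return size if remainder == 0 else size + (256 - remainder)

def adjust_cache_size(cache_size: int, max_seq_len: int) -> int:
    # The cache may never be smaller than the requested context length
    return round_up_to_256(max(cache_size, max_seq_len))

def adjust_chunk_size(chunk_size: int, max_seq_len: int) -> int:
    # Clamp to [256, max_seq_len], then round up to a multiple of 256
    return round_up_to_256(min(max(chunk_size, 256), max_seq_len))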

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-04-30 23:58:27 -04:00
parent 303e2dde12
commit bdc5189a4b
3 changed files with 130 additions and 83 deletions


@@ -33,11 +33,7 @@ from backends.exllamav2.grammar import (
ExLlamaV2Grammar,
clear_grammar_func_cache,
)
from backends.exllamav2.utils import (
exllama_disabled_flash_attn,
hardware_supports_flash_attn,
supports_paged_attn,
)
from backends.exllamav2.utils import exllama_disabled_flash_attn
from backends.exllamav2.vision import clear_image_embedding_cache
from common.concurrency import iterate_in_threadpool
from common.gen_logging import (
@@ -46,6 +42,7 @@ from common.gen_logging import (
log_prompt,
log_response,
)
from common.hardware import hardware_supports_flash_attn
from common.health import HealthManager
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
@@ -278,11 +275,20 @@ class ExllamaV2Container(BaseModelContainer):
# Check whether the user's configuration supports flash/paged attention
# Also check if exl2 has disabled flash attention
if (
exllama_disabled_flash_attn(self.config.no_flash_attn)
or not hardware_supports_flash_attn(gpu_device_list)
or not supports_paged_attn()
):
if exllama_disabled_flash_attn(
self.config.no_flash_attn
) or not hardware_supports_flash_attn(gpu_device_list):
gpu_unsupported_message = (
"An unsupported GPU is found in this configuration. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"To disable compatability mode, all GPUs must be ampere "
"(30 series) or newer. AMD GPUs are not supported."
)
logger.warning(gpu_unsupported_message)
self.config.no_flash_attn = True
if self.draft_config:
self.draft_config.no_flash_attn = True


@@ -1,74 +1,6 @@
import platform
import torch
from packaging import version
from importlib.metadata import PackageNotFoundError, version as package_version
from loguru import logger
def hardware_supports_flash_attn(gpu_device_list: list[int]):
"""
Check whether all GPUs in list support FA2
Compute capability < 8 is not supported by FA2
AMD is also unsupported until ROCm updates its FA2 fork
"""
# Logged message if unsupported
unsupported_message = (
"An unsupported GPU is found in this configuration. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"To disable compatability mode, all GPUs must be ampere "
"(30 series) or newer. AMD GPUs are not supported."
)
min_compute_capability = min(
torch.cuda.get_device_capability(device=device_idx)[0]
for device_idx in gpu_device_list
)
if torch.version.hip or min_compute_capability < 8:
logger.warning(unsupported_message)
return False
else:
return True
def supports_paged_attn():
"""Check whether the user's flash-attn version supports paged mode"""
# Logged message if unsupported
unsupported_message = (
"Flash attention version >=2.5.7 "
"is required to use paged attention. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"Please upgrade your environment by running an update script "
"(update_scripts/"
f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
"Or you can manually run a requirements update "
"using the following command:\n\n"
"For CUDA 12.1:\n"
"pip install --upgrade .[cu121]\n\n"
"NOTE: Windows users must use CUDA 12.x to use flash-attn."
)
required_version = version.parse("2.5.7")
try:
current_version = version.parse(package_version("flash-attn").split("+")[0])
except PackageNotFoundError:
logger.warning(unsupported_message)
return False
if current_version < required_version:
logger.warning(unsupported_message)
return False
else:
return True
def exllama_disabled_flash_attn(no_flash_attn: bool):
unsupported_message = (
"ExllamaV2 has disabled Flash Attention. \n"


@@ -27,13 +27,14 @@ from common.gen_logging import (
log_generation_params,
log_metrics,
)
from common.hardware import hardware_supports_flash_attn
from common.health import HealthManager
from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate, find_prompt_template
from common.transformers_utils import GenerationConfig
from common.utils import coalesce, unwrap
from endpoints.core.types.model import ModelCard
from endpoints.core.types.model import ModelCard, ModelCardParameters
class ExllamaV3Container(BaseModelContainer):
@@ -59,11 +60,16 @@ class ExllamaV3Container(BaseModelContainer):
tokenizer: Tokenizer
config: Config
generator: Optional[AsyncGenerator] = None
# Class-specific vars
gpu_split: List[float] | None = None
gpu_split_auto: bool = True
autosplit_reserve: List[float] = [96 / 1024]
max_seq_len: int
use_tp: bool = False
max_seq_len: int = 4096
cache_size: int = 4096
chunk_size: int = 2048
max_batch_size: Optional[int] = None
# Required methods
@classmethod
@@ -90,8 +96,8 @@ class ExllamaV3Container(BaseModelContainer):
self.model = Model.from_config(self.config)
self.tokenizer = Tokenizer.from_config(self.config)
self.max_seq_len = kwargs.get("max_seq_len")
self.cache = Cache(self.model, max_num_tokens=self.max_seq_len)
# Fallback to 4096 since exl3 can't fetch from HF's config.json
self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
# Try to set prompt template
self.prompt_template = await find_prompt_template(
@@ -102,6 +108,7 @@ class ExllamaV3Container(BaseModelContainer):
gpu_count = torch.cuda.device_count()
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
gpu_split = unwrap(kwargs.get("gpu_split"), None)
gpu_device_list = list(range(0, gpu_count))
# Set GPU split options
if gpu_count == 1:
@@ -114,6 +121,12 @@ class ExllamaV3Container(BaseModelContainer):
# Enable manual GPU split if provided
if gpu_split:
self.gpu_split = gpu_split
gpu_device_list = [
device_idx
for device_idx, memory in enumerate(self.gpu_split)
if memory > 0
]
elif gpu_split_auto and not self.use_tp:
# Otherwise fallback to autosplit settings
self.gpu_split_auto = gpu_split_auto
@@ -126,10 +139,87 @@ class ExllamaV3Container(BaseModelContainer):
self.autosplit_reserve = [
value / 1024 for value in autosplit_reserve_megabytes
]
if not hardware_supports_flash_attn(gpu_device_list):
gpu_unsupported_message = (
"Unable to run ExllamaV3 because an unsupported GPU is "
"found in this configuration. \n"
"All GPUs must be ampere "
"(30 series) or newer. AMD GPUs are not supported."
)
logger.warning(gpu_unsupported_message)
raise RuntimeError(gpu_unsupported_message)
# Cache
user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
self.cache_size = self.adjust_cache_size(user_cache_size)
self.cache = Cache(self.model, max_num_tokens=self.cache_size)
# Max batch size
self.max_batch_size = kwargs.get("max_batch_size")
# Make sure chunk size is >= 256, keep near or below max seq len
user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048)
self.chunk_size = self.adjust_chunk_size(user_chunk_size)
# TODO: speculative decoding
return self
def adjust_cache_size(self, cache_size):
if cache_size < self.max_seq_len:
logger.warning(
f"The given cache_size ({cache_size}) is smaller than the "
"desired context length.\n"
"Overriding cache_size to max_seq_len. "
)
cache_size = self.max_seq_len
# Enforce a multiple of 256 for cache size
# Overestimate to ensure that the cache isn't below max_seq_len
cache_remainder = cache_size % 256
if cache_remainder != 0:
rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
logger.warning(
f"The given cache size ({cache_size}) is "
"not a multiple of 256.\n"
"Overriding cache_size with an overestimated value of "
f"{rounded_cache_size} tokens."
)
cache_size = rounded_cache_size
# Warn user if cache size may be inadequate for CFG
if cache_size < 2 * self.max_seq_len:
logger.warning(
f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
"and may be too small for requests using CFG. \n"
"Ignore this warning if you do not plan on using CFG."
)
return cache_size
def adjust_chunk_size(self, user_chunk_size: int):
chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1]
chunk_remainder = chunk_size % 256
if chunk_remainder != 0:
rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1))
logger.warning(
f"The given chunk size ({chunk_size}) is "
"not a multiple of 256.\n"
"Overriding chunk_size with an overestimated value of "
f"{rounded_chunk_size} tokens."
)
chunk_size = rounded_chunk_size
return chunk_size
def model_info(self) -> ModelCard:
"""
Returns a dictionary of the current model's configuration parameters.
@@ -138,7 +228,25 @@ class ExllamaV3Container(BaseModelContainer):
Model parameters provided by the backend
"""
pass
model_params = ModelCardParameters(
max_seq_len=self.max_seq_len,
cache_size=self.cache_size,
max_batch_size=self.max_batch_size,
# cache_mode=self.cache_mode,
chunk_size=self.chunk_size,
use_vision=self.use_vision,
)
if self.prompt_template:
model_params.prompt_template = self.prompt_template.name
model_params.prompt_template_content = self.prompt_template.raw_template
model_card = ModelCard(
id=self.model_dir.name,
parameters=model_params,
)
return model_card
async def wait_for_jobs(self, skip_wait: bool = False):
"""
@@ -241,6 +349,7 @@ class ExllamaV3Container(BaseModelContainer):
cache=self.cache,
tokenizer=self.tokenizer,
max_batch_size=self.max_batch_size,
max_chunk_size=self.chunk_size,
)
# Update the state of the container var
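
Side note on the clamp used in adjust_chunk_size above: sorted((256, user_chunk_size, self.max_seq_len))[1] picks the middle of the three values, which is equivalent to clamping the user value into [256, max_seq_len]. The numbers below are made up purely for illustration:

# Hypothetical values, shown only to illustrate the sorted()-based clamp
assert sorted((256, 100, 8192))[1] == 256    # below the floor -> raised to 256
assert sorted((256, 3000, 8192))[1] == 3000  # in range -> unchanged
assert sorted((256, 9999, 8192))[1] == 8192  # above max_seq_len -> capped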