From bdc5189a4b460061386e24d53ae514edecc01b43 Mon Sep 17 00:00:00 2001 From: kingbri <8082010+kingbri1@users.noreply.github.com> Date: Wed, 30 Apr 2025 23:58:27 -0400 Subject: [PATCH] Exl3: Add chunk size, cache size, and model info Use the same algorithm for estimating and adjusting cache size based on multiples of 256 and above max seq len. Same applies for chunk size. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com> --- backends/exllamav2/model.py | 26 +++++--- backends/exllamav2/utils.py | 68 --------------------- backends/exllamav3/model.py | 119 ++++++++++++++++++++++++++++++++++-- 3 files changed, 130 insertions(+), 83 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 09da9a2..4745241 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -33,11 +33,7 @@ from backends.exllamav2.grammar import ( ExLlamaV2Grammar, clear_grammar_func_cache, ) -from backends.exllamav2.utils import ( - exllama_disabled_flash_attn, - hardware_supports_flash_attn, - supports_paged_attn, -) +from backends.exllamav2.utils import exllama_disabled_flash_attn from backends.exllamav2.vision import clear_image_embedding_cache from common.concurrency import iterate_in_threadpool from common.gen_logging import ( @@ -46,6 +42,7 @@ from common.gen_logging import ( log_prompt, log_response, ) +from common.hardware import hardware_supports_flash_attn from common.health import HealthManager from common.multimodal import MultimodalEmbeddingWrapper from common.sampling import BaseSamplerRequest @@ -278,11 +275,20 @@ class ExllamaV2Container(BaseModelContainer): # Check whether the user's configuration supports flash/paged attention # Also check if exl2 has disabled flash attention - if ( - exllama_disabled_flash_attn(self.config.no_flash_attn) - or not hardware_supports_flash_attn(gpu_device_list) - or not supports_paged_attn() - ): + if exllama_disabled_flash_attn( + self.config.no_flash_attn + ) or not hardware_supports_flash_attn(gpu_device_list): + gpu_unsupported_message = ( + "An unsupported GPU is found in this configuration. " + "Switching to compatibility mode. \n" + "This disables parallel batching " + "and features that rely on it (ex. CFG). \n" + "To disable compatability mode, all GPUs must be ampere " + "(30 series) or newer. AMD GPUs are not supported." + ) + + logger.warning(gpu_unsupported_message) + self.config.no_flash_attn = True if self.draft_config: self.draft_config.no_flash_attn = True diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py index 0fd1fcc..1648c62 100644 --- a/backends/exllamav2/utils.py +++ b/backends/exllamav2/utils.py @@ -1,74 +1,6 @@ -import platform -import torch -from packaging import version -from importlib.metadata import PackageNotFoundError, version as package_version from loguru import logger -def hardware_supports_flash_attn(gpu_device_list: list[int]): - """ - Check whether all GPUs in list support FA2 - - Compute capability < 8 is not supported by FA2 - AMD is also unsupported until ROCm updates its FA2 fork - """ - - # Logged message if unsupported - unsupported_message = ( - "An unsupported GPU is found in this configuration. " - "Switching to compatibility mode. \n" - "This disables parallel batching " - "and features that rely on it (ex. CFG). \n" - "To disable compatability mode, all GPUs must be ampere " - "(30 series) or newer. AMD GPUs are not supported." - ) - - min_compute_capability = min( - torch.cuda.get_device_capability(device=device_idx)[0] - for device_idx in gpu_device_list - ) - - if torch.version.hip or min_compute_capability < 8: - logger.warning(unsupported_message) - return False - else: - return True - - -def supports_paged_attn(): - """Check whether the user's flash-attn version supports paged mode""" - - # Logged message if unsupported - unsupported_message = ( - "Flash attention version >=2.5.7 " - "is required to use paged attention. " - "Switching to compatibility mode. \n" - "This disables parallel batching " - "and features that rely on it (ex. CFG). \n" - "Please upgrade your environment by running an update script " - "(update_scripts/" - f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n" - "Or you can manually run a requirements update " - "using the following command:\n\n" - "For CUDA 12.1:\n" - "pip install --upgrade .[cu121]\n\n" - "NOTE: Windows users must use CUDA 12.x to use flash-attn." - ) - - required_version = version.parse("2.5.7") - try: - current_version = version.parse(package_version("flash-attn").split("+")[0]) - except PackageNotFoundError: - logger.warning(unsupported_message) - return False - - if current_version < required_version: - logger.warning(unsupported_message) - return False - else: - return True - - def exllama_disabled_flash_attn(no_flash_attn: bool): unsupported_message = ( "ExllamaV2 has disabled Flash Attention. \n" diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 06d3b29..ce27a85 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -27,13 +27,14 @@ from common.gen_logging import ( log_generation_params, log_metrics, ) +from common.hardware import hardware_supports_flash_attn from common.health import HealthManager from common.multimodal import MultimodalEmbeddingWrapper from common.sampling import BaseSamplerRequest from common.templating import PromptTemplate, find_prompt_template from common.transformers_utils import GenerationConfig from common.utils import coalesce, unwrap -from endpoints.core.types.model import ModelCard +from endpoints.core.types.model import ModelCard, ModelCardParameters class ExllamaV3Container(BaseModelContainer): @@ -59,11 +60,16 @@ class ExllamaV3Container(BaseModelContainer): tokenizer: Tokenizer config: Config generator: Optional[AsyncGenerator] = None + + # Class-specific vars gpu_split: List[float] | None = None gpu_split_auto: bool = True autosplit_reserve: List[float] = [96 / 1024] - max_seq_len: int use_tp: bool = False + max_seq_len: int = 4096 + cache_size: int = 4096 + chunk_size: int = 2048 + max_batch_size: Optional[int] = None # Required methods @classmethod @@ -90,8 +96,8 @@ class ExllamaV3Container(BaseModelContainer): self.model = Model.from_config(self.config) self.tokenizer = Tokenizer.from_config(self.config) - self.max_seq_len = kwargs.get("max_seq_len") - self.cache = Cache(self.model, max_num_tokens=self.max_seq_len) + # Fallback to 4096 since exl3 can't fetch from HF's config.json + self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096) # Try to set prompt template self.prompt_template = await find_prompt_template( @@ -102,6 +108,7 @@ class ExllamaV3Container(BaseModelContainer): gpu_count = torch.cuda.device_count() gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True) gpu_split = unwrap(kwargs.get("gpu_split"), None) + gpu_device_list = list(range(0, gpu_count)) # Set GPU split options if gpu_count == 1: @@ -114,6 +121,12 @@ class ExllamaV3Container(BaseModelContainer): # Enable manual GPU split if provided if gpu_split: self.gpu_split = gpu_split + + gpu_device_list = [ + device_idx + for device_idx, memory in enumerate(self.gpu_split) + if memory > 0 + ] elif gpu_split_auto and not self.use_tp: # Otherwise fallback to autosplit settings self.gpu_split_auto = gpu_split_auto @@ -126,10 +139,87 @@ class ExllamaV3Container(BaseModelContainer): self.autosplit_reserve = [ value / 1024 for value in autosplit_reserve_megabytes ] + + if not hardware_supports_flash_attn(gpu_device_list): + gpu_unsupported_message = ( + "Unable to run ExllamaV3 because an unsupported GPU is " + "found in this configuration. \n" + "All GPUs must be ampere " + "(30 series) or newer. AMD GPUs are not supported." + ) + + logger.warning(gpu_unsupported_message) + + raise RuntimeError(gpu_unsupported_message) + + # Cache + user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len) + self.cache_size = self.adjust_cache_size(user_cache_size) + self.cache = Cache(self.model, max_num_tokens=self.cache_size) + + # Max batch size + self.max_batch_size = kwargs.get("max_batch_size") + + # Make sure chunk size is >= 256, keep near or below max seq len + user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048) + self.chunk_size = self.adjust_chunk_size(user_chunk_size) + # TODO: speculative decoding return self + def adjust_cache_size(self, cache_size): + if cache_size < self.max_seq_len: + logger.warning( + f"The given cache_size ({cache_size}) is smaller than the " + "desired context length.\n" + "Overriding cache_size to max_seq_len. " + ) + + cache_size = self.max_seq_len + + # Enforce a multiple of 256 for cache size + # Overestimate to ensure that the cache isn't below max_seq_len + cache_remainder = cache_size % 256 + if cache_remainder != 0: + rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1)) + + logger.warning( + f"The given cache size ({cache_size}) is " + "not a multiple of 256.\n" + "Overriding cache_size with an overestimated value of " + f"{rounded_cache_size} tokens." + ) + + cache_size = rounded_cache_size + + # Warn user if cache size may be inadequate for CFG + if cache_size < 2 * self.max_seq_len: + logger.warning( + f"The given cache_size ({cache_size}) is less than 2 * max_seq_len " + "and may be too small for requests using CFG. \n" + "Ignore this warning if you do not plan on using CFG." + ) + + return cache_size + + def adjust_chunk_size(self, user_chunk_size: int): + chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1] + chunk_remainder = chunk_size % 256 + if chunk_remainder != 0: + rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1)) + + logger.warning( + f"The given chunk size ({chunk_size}) is " + "not a multiple of 256.\n" + "Overriding chunk_size with an overestimated value of " + f"{rounded_chunk_size} tokens." + ) + + chunk_size = rounded_chunk_size + + return chunk_size + def model_info(self) -> ModelCard: """ Returns a dictionary of the current model's configuration parameters. @@ -138,7 +228,25 @@ class ExllamaV3Container(BaseModelContainer): Model parameters provided by the backend """ - pass + model_params = ModelCardParameters( + max_seq_len=self.max_seq_len, + cache_size=self.cache_size, + max_batch_size=self.max_batch_size, + # cache_mode=self.cache_mode, + chunk_size=self.chunk_size, + use_vision=self.use_vision, + ) + + if self.prompt_template: + model_params.prompt_template = self.prompt_template.name + model_params.prompt_template_content = self.prompt_template.raw_template + + model_card = ModelCard( + id=self.model_dir.name, + parameters=model_params, + ) + + return model_card async def wait_for_jobs(self, skip_wait: bool = False): """ @@ -241,6 +349,7 @@ class ExllamaV3Container(BaseModelContainer): cache=self.cache, tokenizer=self.tokenizer, max_batch_size=self.max_batch_size, + max_chunk_size=self.chunk_size, ) # Update the state of the container var