Exl3: Add chunk size, cache size, and model info
Use the same algorithm for estimating and adjusting cache size based on multiples of 256 and above max seq len. Same applies for chunk size. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
303e2dde12
commit
bdc5189a4b
3 changed files with 130 additions and 83 deletions
|
|
@ -33,11 +33,7 @@ from backends.exllamav2.grammar import (
|
|||
ExLlamaV2Grammar,
|
||||
clear_grammar_func_cache,
|
||||
)
|
||||
from backends.exllamav2.utils import (
|
||||
exllama_disabled_flash_attn,
|
||||
hardware_supports_flash_attn,
|
||||
supports_paged_attn,
|
||||
)
|
||||
from backends.exllamav2.utils import exllama_disabled_flash_attn
|
||||
from backends.exllamav2.vision import clear_image_embedding_cache
|
||||
from common.concurrency import iterate_in_threadpool
|
||||
from common.gen_logging import (
|
||||
|
|
@ -46,6 +42,7 @@ from common.gen_logging import (
|
|||
log_prompt,
|
||||
log_response,
|
||||
)
|
||||
from common.hardware import hardware_supports_flash_attn
|
||||
from common.health import HealthManager
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.sampling import BaseSamplerRequest
|
||||
|
|
@ -278,11 +275,20 @@ class ExllamaV2Container(BaseModelContainer):
|
|||
|
||||
# Check whether the user's configuration supports flash/paged attention
|
||||
# Also check if exl2 has disabled flash attention
|
||||
if (
|
||||
exllama_disabled_flash_attn(self.config.no_flash_attn)
|
||||
or not hardware_supports_flash_attn(gpu_device_list)
|
||||
or not supports_paged_attn()
|
||||
):
|
||||
if exllama_disabled_flash_attn(
|
||||
self.config.no_flash_attn
|
||||
) or not hardware_supports_flash_attn(gpu_device_list):
|
||||
gpu_unsupported_message = (
|
||||
"An unsupported GPU is found in this configuration. "
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
"To disable compatability mode, all GPUs must be ampere "
|
||||
"(30 series) or newer. AMD GPUs are not supported."
|
||||
)
|
||||
|
||||
logger.warning(gpu_unsupported_message)
|
||||
|
||||
self.config.no_flash_attn = True
|
||||
if self.draft_config:
|
||||
self.draft_config.no_flash_attn = True
|
||||
|
|
|
|||
|
|
@ -1,74 +1,6 @@
|
|||
import platform
|
||||
import torch
|
||||
from packaging import version
|
||||
from importlib.metadata import PackageNotFoundError, version as package_version
|
||||
from loguru import logger
|
||||
|
||||
|
||||
def hardware_supports_flash_attn(gpu_device_list: list[int]):
|
||||
"""
|
||||
Check whether all GPUs in list support FA2
|
||||
|
||||
Compute capability < 8 is not supported by FA2
|
||||
AMD is also unsupported until ROCm updates its FA2 fork
|
||||
"""
|
||||
|
||||
# Logged message if unsupported
|
||||
unsupported_message = (
|
||||
"An unsupported GPU is found in this configuration. "
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
"To disable compatability mode, all GPUs must be ampere "
|
||||
"(30 series) or newer. AMD GPUs are not supported."
|
||||
)
|
||||
|
||||
min_compute_capability = min(
|
||||
torch.cuda.get_device_capability(device=device_idx)[0]
|
||||
for device_idx in gpu_device_list
|
||||
)
|
||||
|
||||
if torch.version.hip or min_compute_capability < 8:
|
||||
logger.warning(unsupported_message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def supports_paged_attn():
|
||||
"""Check whether the user's flash-attn version supports paged mode"""
|
||||
|
||||
# Logged message if unsupported
|
||||
unsupported_message = (
|
||||
"Flash attention version >=2.5.7 "
|
||||
"is required to use paged attention. "
|
||||
"Switching to compatibility mode. \n"
|
||||
"This disables parallel batching "
|
||||
"and features that rely on it (ex. CFG). \n"
|
||||
"Please upgrade your environment by running an update script "
|
||||
"(update_scripts/"
|
||||
f"update_deps.{'bat' if platform.system() == 'Windows' else 'sh'})\n\n"
|
||||
"Or you can manually run a requirements update "
|
||||
"using the following command:\n\n"
|
||||
"For CUDA 12.1:\n"
|
||||
"pip install --upgrade .[cu121]\n\n"
|
||||
"NOTE: Windows users must use CUDA 12.x to use flash-attn."
|
||||
)
|
||||
|
||||
required_version = version.parse("2.5.7")
|
||||
try:
|
||||
current_version = version.parse(package_version("flash-attn").split("+")[0])
|
||||
except PackageNotFoundError:
|
||||
logger.warning(unsupported_message)
|
||||
return False
|
||||
|
||||
if current_version < required_version:
|
||||
logger.warning(unsupported_message)
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def exllama_disabled_flash_attn(no_flash_attn: bool):
|
||||
unsupported_message = (
|
||||
"ExllamaV2 has disabled Flash Attention. \n"
|
||||
|
|
|
|||
|
|
@ -27,13 +27,14 @@ from common.gen_logging import (
|
|||
log_generation_params,
|
||||
log_metrics,
|
||||
)
|
||||
from common.hardware import hardware_supports_flash_attn
|
||||
from common.health import HealthManager
|
||||
from common.multimodal import MultimodalEmbeddingWrapper
|
||||
from common.sampling import BaseSamplerRequest
|
||||
from common.templating import PromptTemplate, find_prompt_template
|
||||
from common.transformers_utils import GenerationConfig
|
||||
from common.utils import coalesce, unwrap
|
||||
from endpoints.core.types.model import ModelCard
|
||||
from endpoints.core.types.model import ModelCard, ModelCardParameters
|
||||
|
||||
|
||||
class ExllamaV3Container(BaseModelContainer):
|
||||
|
|
@ -59,11 +60,16 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
tokenizer: Tokenizer
|
||||
config: Config
|
||||
generator: Optional[AsyncGenerator] = None
|
||||
|
||||
# Class-specific vars
|
||||
gpu_split: List[float] | None = None
|
||||
gpu_split_auto: bool = True
|
||||
autosplit_reserve: List[float] = [96 / 1024]
|
||||
max_seq_len: int
|
||||
use_tp: bool = False
|
||||
max_seq_len: int = 4096
|
||||
cache_size: int = 4096
|
||||
chunk_size: int = 2048
|
||||
max_batch_size: Optional[int] = None
|
||||
|
||||
# Required methods
|
||||
@classmethod
|
||||
|
|
@ -90,8 +96,8 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
self.model = Model.from_config(self.config)
|
||||
self.tokenizer = Tokenizer.from_config(self.config)
|
||||
|
||||
self.max_seq_len = kwargs.get("max_seq_len")
|
||||
self.cache = Cache(self.model, max_num_tokens=self.max_seq_len)
|
||||
# Fallback to 4096 since exl3 can't fetch from HF's config.json
|
||||
self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
|
||||
|
||||
# Try to set prompt template
|
||||
self.prompt_template = await find_prompt_template(
|
||||
|
|
@ -102,6 +108,7 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
gpu_count = torch.cuda.device_count()
|
||||
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
|
||||
gpu_split = unwrap(kwargs.get("gpu_split"), None)
|
||||
gpu_device_list = list(range(0, gpu_count))
|
||||
|
||||
# Set GPU split options
|
||||
if gpu_count == 1:
|
||||
|
|
@ -114,6 +121,12 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
# Enable manual GPU split if provided
|
||||
if gpu_split:
|
||||
self.gpu_split = gpu_split
|
||||
|
||||
gpu_device_list = [
|
||||
device_idx
|
||||
for device_idx, memory in enumerate(self.gpu_split)
|
||||
if memory > 0
|
||||
]
|
||||
elif gpu_split_auto and not self.use_tp:
|
||||
# Otherwise fallback to autosplit settings
|
||||
self.gpu_split_auto = gpu_split_auto
|
||||
|
|
@ -126,10 +139,87 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
self.autosplit_reserve = [
|
||||
value / 1024 for value in autosplit_reserve_megabytes
|
||||
]
|
||||
|
||||
if not hardware_supports_flash_attn(gpu_device_list):
|
||||
gpu_unsupported_message = (
|
||||
"Unable to run ExllamaV3 because an unsupported GPU is "
|
||||
"found in this configuration. \n"
|
||||
"All GPUs must be ampere "
|
||||
"(30 series) or newer. AMD GPUs are not supported."
|
||||
)
|
||||
|
||||
logger.warning(gpu_unsupported_message)
|
||||
|
||||
raise RuntimeError(gpu_unsupported_message)
|
||||
|
||||
# Cache
|
||||
user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
|
||||
self.cache_size = self.adjust_cache_size(user_cache_size)
|
||||
self.cache = Cache(self.model, max_num_tokens=self.cache_size)
|
||||
|
||||
# Max batch size
|
||||
self.max_batch_size = kwargs.get("max_batch_size")
|
||||
|
||||
# Make sure chunk size is >= 256, keep near or below max seq len
|
||||
user_chunk_size = unwrap(kwargs.get("chunk_size"), 2048)
|
||||
self.chunk_size = self.adjust_chunk_size(user_chunk_size)
|
||||
|
||||
# TODO: speculative decoding
|
||||
|
||||
return self
|
||||
|
||||
def adjust_cache_size(self, cache_size):
|
||||
if cache_size < self.max_seq_len:
|
||||
logger.warning(
|
||||
f"The given cache_size ({cache_size}) is smaller than the "
|
||||
"desired context length.\n"
|
||||
"Overriding cache_size to max_seq_len. "
|
||||
)
|
||||
|
||||
cache_size = self.max_seq_len
|
||||
|
||||
# Enforce a multiple of 256 for cache size
|
||||
# Overestimate to ensure that the cache isn't below max_seq_len
|
||||
cache_remainder = cache_size % 256
|
||||
if cache_remainder != 0:
|
||||
rounded_cache_size = int(256 * ((cache_size - cache_remainder) / 256 + 1))
|
||||
|
||||
logger.warning(
|
||||
f"The given cache size ({cache_size}) is "
|
||||
"not a multiple of 256.\n"
|
||||
"Overriding cache_size with an overestimated value of "
|
||||
f"{rounded_cache_size} tokens."
|
||||
)
|
||||
|
||||
cache_size = rounded_cache_size
|
||||
|
||||
# Warn user if cache size may be inadequate for CFG
|
||||
if cache_size < 2 * self.max_seq_len:
|
||||
logger.warning(
|
||||
f"The given cache_size ({cache_size}) is less than 2 * max_seq_len "
|
||||
"and may be too small for requests using CFG. \n"
|
||||
"Ignore this warning if you do not plan on using CFG."
|
||||
)
|
||||
|
||||
return cache_size
|
||||
|
||||
def adjust_chunk_size(self, user_chunk_size: int):
|
||||
chunk_size = sorted((256, user_chunk_size, self.max_seq_len))[1]
|
||||
chunk_remainder = chunk_size % 256
|
||||
if chunk_remainder != 0:
|
||||
rounded_chunk_size = int(256 * ((chunk_size - chunk_remainder) / 256 + 1))
|
||||
|
||||
logger.warning(
|
||||
f"The given chunk size ({chunk_size}) is "
|
||||
"not a multiple of 256.\n"
|
||||
"Overriding chunk_size with an overestimated value of "
|
||||
f"{rounded_chunk_size} tokens."
|
||||
)
|
||||
|
||||
chunk_size = rounded_chunk_size
|
||||
|
||||
return chunk_size
|
||||
|
||||
def model_info(self) -> ModelCard:
|
||||
"""
|
||||
Returns a dictionary of the current model's configuration parameters.
|
||||
|
|
@ -138,7 +228,25 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
Model parameters provided by the backend
|
||||
"""
|
||||
|
||||
pass
|
||||
model_params = ModelCardParameters(
|
||||
max_seq_len=self.max_seq_len,
|
||||
cache_size=self.cache_size,
|
||||
max_batch_size=self.max_batch_size,
|
||||
# cache_mode=self.cache_mode,
|
||||
chunk_size=self.chunk_size,
|
||||
use_vision=self.use_vision,
|
||||
)
|
||||
|
||||
if self.prompt_template:
|
||||
model_params.prompt_template = self.prompt_template.name
|
||||
model_params.prompt_template_content = self.prompt_template.raw_template
|
||||
|
||||
model_card = ModelCard(
|
||||
id=self.model_dir.name,
|
||||
parameters=model_params,
|
||||
)
|
||||
|
||||
return model_card
|
||||
|
||||
async def wait_for_jobs(self, skip_wait: bool = False):
|
||||
"""
|
||||
|
|
@ -241,6 +349,7 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
cache=self.cache,
|
||||
tokenizer=self.tokenizer,
|
||||
max_batch_size=self.max_batch_size,
|
||||
max_chunk_size=self.chunk_size,
|
||||
)
|
||||
|
||||
# Update the state of the container var
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue