Model: Change FA2 and paged attention checks

The dynamic generator requires Flash Attention 2.5.7 or higher to
be installed. This is only supported on Nvidia's 30 series (Ampere) and newer GPUs.

If a card is AMD or older than the 30 series, switch to compatibility
mode, which functions the same way as the older generator except
without parallel batching and any features that depend on it, such as
CFG.
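
A minimal sketch of that check, assuming a single CUDA device at index 0 (the committed code below inspects every device the loaded model's modules are split across):

import torch

def supports_paged_attention(device_idx: int = 0) -> bool:
    # flash-attn >= 2.5.7 must also be installed (not verified in this sketch)
    # ROCm builds of torch expose torch.version.hip; paged FA2 kernels are CUDA-only
    if not torch.cuda.is_available() or torch.version.hip:
        return False
    # Compute capability major version 8+ covers Ampere (30 series) and newer
    major, _minor = torch.cuda.get_device_capability(device=device_idx)
    return major >= 8

paged = supports_paged_attention()
max_batch_size = 20 if paged else 1  # compatibility mode handles one prompt at a time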

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-05-24 22:33:47 -04:00 committed by Brian Dashore
parent c2d3675408
commit 408c66a1f2
3 changed files with 31 additions and 35 deletions

View file

@@ -3,8 +3,6 @@
import gc
import math
import pathlib
import threading
import time
import traceback
import torch
import uuid
@@ -57,10 +55,11 @@ class ExllamaV2Container:
generator: Optional[ExLlamaV2DynamicGeneratorAsync] = None
prompt_template: Optional[PromptTemplate] = None
active_loras: List[ExLlamaV2Lora] = []
paged: bool = True
# Internal config vars
cache_mode: str = "FP16"
use_cfg: bool = False
max_batch_size: int = 20
generation_config: Optional[GenerationConfig] = None
# GPU split vars
@@ -115,10 +114,6 @@ class ExllamaV2Container:
available devices (default: True)
'gpu_split' (list[float]): Allocation for weights and (some)
tensors, per device
'no_flash_attn' (bool): Turns off flash attention
(increases vram usage) (default: False)
'use_cfg" (bool): Enables CFG support. Disables flash attention
(default: False)
"""
self.quiet = quiet
@@ -184,18 +179,9 @@ class ExllamaV2Container:
kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
)
# Enable CFG if present
self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
# Enable fasttensors loading if present
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
# Turn off flash attention if CFG is on
# Workaround until batched FA2 is fixed in exllamav2 upstream
# self.config.no_flash_attn = (
# True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
# )
# Try to set prompt template
self.prompt_template = self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
@@ -345,7 +331,6 @@ class ExllamaV2Container:
"cache_mode": self.cache_mode,
"chunk_size": self.config.max_input_len,
"num_experts_per_token": self.config.num_experts_per_token,
"use_cfg": self.use_cfg,
"prompt_template": self.prompt_template.name
if self.prompt_template
else None,
@@ -420,10 +405,24 @@ class ExllamaV2Container:
async for value in iterate_in_threadpool(model_load_generator):
yield value
# TODO: Change these!
# Set the max batch size and check if paged support is available
max_batch_size = 1 if self.config.no_flash_attn else 20
paged = not self.config.no_flash_attn
# Disable paged mode if the user's minimum GPU is unsupported (paged attention requires Ampere and above)
min_compute_capability = min(
set(
[
torch.cuda.get_device_capability(device=module.device_idx)[0]
for module in self.model.modules
if module.device_idx >= 0
]
)
)
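# Compute capability major version 8 corresponds to Ampere (30 series); torch.version.hip is set on ROCm (AMD) builds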
if torch.version.hip or min_compute_capability < 8:
logger.warning(
"An unsupported GPU is found in this configuration. "
"Switching to compatibility mode. This disables parallel batching."
)
self.paged = False
self.max_batch_size = 1
# Create async generator
self.generator = ExLlamaV2DynamicGeneratorAsync(
@@ -432,8 +431,8 @@
draft_model=self.draft_model,
draft_cache=self.draft_cache,
tokenizer=self.tokenizer,
max_batch_size=max_batch_size,
paged=paged,
max_batch_size=self.max_batch_size,
paged=self.paged,
)
# Clean up any extra vram usage from torch and cuda
@@ -741,7 +740,7 @@ class ExllamaV2Container:
cfg_scale = unwrap(kwargs.get("cfg_scale"), 1.0)
negative_prompt = None
if cfg_scale not in [None, 1.0]:
if self.use_cfg:
if self.paged:
gen_settings.cfg_scale = cfg_scale
# If the negative prompt is empty, use the BOS token
@ -752,8 +751,8 @@ class ExllamaV2Container:
prompts.append(negative_prompt)
else:
logger.warning(
"CFG is currently disabled. "
"If your GPU is supported, reload your model with use_cfg = True"
"CFG is currently disabled because paged mode is disabled. "
"Please use an ampere (30 series) or higher GPU for CFG support."
)
gen_settings.token_repetition_penalty = unwrap(

View file

@@ -100,9 +100,6 @@ model:
# Leave blank to automatically calculate alpha
#rope_alpha: 1.0
# Disable Flash-attention 2. Set to True for GPUs lower than Nvidia's 3000 series. (default: False)
#no_flash_attention: False
# Enable different cache modes for VRAM savings (slight performance hit).
# Possible values FP16, FP8, Q4. (default: FP16)
#cache_mode: FP16
@@ -111,6 +108,12 @@ model:
# NOTE: Effects vary depending on the model. An ideal value is between 512 and 4096
#chunk_size: 2048
# Set the maximum number of prompts to process at one time (batch)
# This will be automatically adjusted depending on the cache size.
# A max batch size of 1 processes prompts one at a time.
# NOTE: Only available for Nvidia Ampere (30 series) and above GPUs
#max_batch_size: 20
# Set the prompt template for this model. If empty, attempts to look for the model's chat template. (default: None)
# If a model contains multiple templates in its tokenizer_config.json, set prompt_template to the name
# of the template you want to use.
@@ -122,10 +125,6 @@ model:
# NOTE: For MoE models (ex. Mixtral) only!
#num_experts_per_token:
# Enables CFG support (default: False)
# WARNING: This flag disables Flash Attention! (a stopgap fix until it's fixed in upstream)
#use_cfg: False
# Enables fasttensors to possibly increase model loading speeds (default: False)
#fasttensors: true

View file

@@ -19,7 +19,6 @@ class ModelCardParameters(BaseModel):
chunk_size: Optional[int] = 2048
prompt_template: Optional[str] = None
num_experts_per_token: Optional[int] = None
use_cfg: Optional[bool] = None
# Draft is another model, so include it in the card params
draft: Optional["ModelCard"] = None
@@ -94,7 +93,6 @@ class ModelLoadRequest(BaseModel):
chunk_size: Optional[int] = 2048
prompt_template: Optional[str] = None
num_experts_per_token: Optional[int] = None
use_cfg: Optional[bool] = None
fasttensors: Optional[bool] = False
draft: Optional[DraftModelLoadRequest] = None
skip_queue: Optional[bool] = False