Model: Add CFG support

CFG, or classifier-free guidance, helps push a model's output in
different directions based on the negative prompt and guidance scale
the user provides.

Currently, CFG is ignored if the negative prompt is blank (it
shouldn't be used that way anyway).

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-01-02 01:09:26 -05:00 committed by Brian Dashore
parent bb7a8e4614
commit b378773d0a
6 changed files with 96 additions and 18 deletions
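
For context, a completion request that exercises the new fields might look like the sketch below. The endpoint path, port, and x-api-key header are assumptions about a typical local deployment rather than anything this commit defines.

# Hypothetical client call; URL, port, and auth header are assumptions.
import requests

payload = {
    "prompt": "Write a short story about a lighthouse.",
    "max_tokens": 200,
    "cfg_scale": 1.5,             # any value other than 1.0 activates CFG
    "negative_prompt": "poetry",  # text to steer the model away from
}

response = requests.post(
    "http://localhost:5000/v1/completions",
    headers={"x-api-key": "<your-api-key>"},
    json=payload,
    timeout=120,
)
print(response.json())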

View file

@@ -75,6 +75,7 @@ class CommonCompletionRequest(BaseModel):
add_bos_token: Optional[bool] = True
ban_eos_token: Optional[bool] = False
logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]])
negative_prompt: Optional[str] = None
# Aliased variables
penalty_range: Optional[int] = Field(
@@ -86,6 +87,10 @@ class CommonCompletionRequest(BaseModel):
),
)
cfg_scale: Optional[float] = Field(
default=1.0, validation_alias=AliasChoices("cfg_scale", "guidance_scale")
)
def to_gen_params(self):
"""Converts to internal generation parameters."""
# Convert stop to an array of strings
@@ -115,4 +120,6 @@ class CommonCompletionRequest(BaseModel):
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,
"mirostat_eta": self.mirostat_eta,
"cfg_scale": self.cfg_scale,
"negative_prompt": self.negative_prompt,
}
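
The guidance_scale alias above means clients can use either name for the same field; a standalone sketch of that behavior follows (a toy model, not the real request class).

# Toy model reproducing the alias: both "cfg_scale" and "guidance_scale"
# populate the same field thanks to AliasChoices (pydantic v2).
from typing import Optional
from pydantic import AliasChoices, BaseModel, Field

class ToyRequest(BaseModel):
    cfg_scale: Optional[float] = Field(
        default=1.0, validation_alias=AliasChoices("cfg_scale", "guidance_scale")
    )
    negative_prompt: Optional[str] = None

print(ToyRequest(guidance_scale=1.5).cfg_scale)  # 1.5
print(ToyRequest(cfg_scale=2.0).cfg_scale)       # 2.0
print(ToyRequest().cfg_scale)                    # 1.0 (CFG effectively off)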

View file

@@ -83,6 +83,7 @@ class ModelLoadRequest(BaseModel):
cache_mode: Optional[str] = "FP16"
prompt_template: Optional[str] = None
num_experts_per_token: Optional[int] = None
use_cfg: Optional[bool] = None
draft: Optional[DraftModelLoadRequest] = None
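
Clients that load models through the API can now pass the flag alongside the existing options; a minimal sketch of such a request body is below. The model name is a placeholder, and the endpoint and auth details depend on the deployment.

import json

# Sketch of a model-load request body using the new flag; the model name is
# a placeholder and endpoint/auth details depend on the deployment.
load_body = {
    "name": "MyModel-exl2",   # placeholder model folder name
    "cache_mode": "FP16",
    "use_cfg": True,          # enables the CFG code path at load time
}
print(json.dumps(load_body, indent=2))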

View file

@@ -106,6 +106,11 @@ def add_model_args(parser: argparse.ArgumentParser):
type=int,
help="Number of experts to use per token in MoE models",
)
model_group.add_argument(
"--use-cfg",
type=str_to_bool,
help="Enables CFG support",
)
def add_logging_args(parser: argparse.ArgumentParser):
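
The flag parses through str_to_bool, a helper defined elsewhere in the repo; the sketch below guesses at its shape purely so the example is self-contained.

# str_to_bool is assumed to look roughly like this; the real helper lives
# elsewhere in the repository.
import argparse

def str_to_bool(value: str) -> bool:
    if value.lower() in ("true", "1", "yes"):
        return True
    if value.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"Cannot interpret '{value}' as a boolean")

parser = argparse.ArgumentParser()
parser.add_argument("--use-cfg", type=str_to_bool, help="Enables CFG support")
print(parser.parse_args(["--use-cfg", "true"]).use_cfg)  # True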

View file

@@ -85,6 +85,10 @@ model:
# NOTE: For MoE models (ex. Mixtral) only!
#num_experts_per_token:
# Enables CFG support (default: False)
# WARNING: This flag disables Flash Attention! (a stopgap until batched FA2 is fixed upstream)
use_cfg: False
# Options for draft models (speculative decoding). This will use more VRAM!
#draft:
# Overrides the directory to look for draft (default: models)

View file

@@ -1,8 +1,8 @@
"""
Functions for logging generation events.
"""
from typing import Dict
from pydantic import BaseModel
from typing import Dict, Optional
from logger import init_logger
@@ -53,12 +53,16 @@ def log_generation_params(**kwargs):
logger.info(f"Generation options: {kwargs}\n")
def log_prompt(prompt: str):
def log_prompt(prompt: str, negative_prompt: Optional[str]):
"""Logs the prompt to console."""
if PREFERENCES.prompt:
formatted_prompt = "\n" + prompt
logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")
if negative_prompt:
formatted_negative_prompt = "\n" + negative_prompt
logger.info(f"Negative Prompt: {formatted_negative_prompt}\n")
def log_response(response: str):
"""Logs the response to console."""

View file

@@ -47,6 +47,7 @@ class ModelContainer:
cache_fp8: bool = False
gpu_split_auto: bool = True
gpu_split: Optional[list] = None
use_cfg: bool = False
active_loras: List[ExLlamaV2Lora] = []
@@ -95,6 +96,8 @@ class ModelContainer:
tensors, per device
'no_flash_attn' (bool): Turns off flash attention
(increases vram usage) (default: False)
'use_cfg" (bool): Enables CFG support. Disables flash attention
(default: False)
"""
self.quiet = quiet
@@ -135,8 +138,18 @@ class ModelContainer:
kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
)
# Turn off flash attention?
self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attention"), False)
if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"):
self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
else:
logger.warning(
"CFG is not supported by the currently installed ExLlamaV2 version."
)
# Turn off flash attention if CFG is on
# Workaround until batched FA2 is fixed in exllamav2 upstream
self.config.no_flash_attn = (
True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
)
# low_mem is currently broken in exllamav2. Don't use it until it's
# fixed.
@@ -348,10 +361,15 @@ class ModelContainer:
if isinstance(value, str):
yield value
batch_size = 2 if self.use_cfg else 1
if self.cache_fp8:
self.cache = ExLlamaV2Cache_8bit(self.model, lazy=self.gpu_split_auto)
self.cache = ExLlamaV2Cache_8bit(
self.model, lazy=self.gpu_split_auto, batch_size=batch_size
)
else:
self.cache = ExLlamaV2Cache(self.model, lazy=self.gpu_split_auto)
self.cache = ExLlamaV2Cache(
self.model, lazy=self.gpu_split_auto, batch_size=batch_size
)
if self.gpu_split_auto:
reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
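
The batch size doubles here because, with CFG, the positive and negative prompts run as two rows of a single batch, and each row needs its own slot in the KV cache; a toy illustration of the shapes involved:

# Illustrative shapes only: two prompts stacked into one (2, seq_len) batch,
# which is why the cache is built with batch_size=2 when CFG is enabled.
import torch

prompt_ids = torch.tensor([[1, 5, 9, 4]])    # (1, 4) positive prompt
negative_ids = torch.tensor([[1, 7, 2, 4]])  # (1, 4) negative prompt

batch = torch.cat((prompt_ids, negative_ids), dim=0)
print(batch.shape)  # torch.Size([2, 4]) -> the cache must hold two sequences
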
@@ -561,6 +579,19 @@ class ModelContainer:
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
# Set CFG scale and negative prompt
cfg_scale = unwrap(kwargs.get("cfg_scale"), 1.0)
negative_prompt = None
if cfg_scale not in [None, 1.0]:
if self.use_cfg:
gen_settings.cfg_scale = cfg_scale
negative_prompt = kwargs.get("negative_prompt")
else:
logger.warn(
"CFG is currently disabled. "
+ "Please reload your model with use_cfg = True.",
)
gen_settings.token_presence_penalty = unwrap(
kwargs.get("presence_penalty"), 0.0
)
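
For reference, the guidance itself happens inside exllamav2's sampler once cfg_scale is set; the sketch below is the standard classifier-free guidance mixing, shown for intuition only and not copied from the library.

# Generic CFG mixing: the two batch rows yield logits for the prompt and the
# negative prompt, and cfg_scale pushes the distribution toward the prompt
# and away from the negative prompt.
import torch
import torch.nn.functional as F

def mix_cfg(prompt_logits, negative_logits, cfg_scale):
    pos = F.log_softmax(prompt_logits, dim=-1)
    neg = F.log_softmax(negative_logits, dim=-1)
    # cfg_scale == 1.0 collapses to the plain prompt distribution,
    # matching the "not in [None, 1.0]" guard above.
    return neg + cfg_scale * (pos - neg)

mixed = mix_cfg(torch.randn(1, 32000), torch.randn(1, 32000), cfg_scale=1.5)
print(mixed.shape)  # torch.Size([1, 32000])
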
@@ -635,7 +666,7 @@ class ModelContainer:
)
# Log prompt to console
log_prompt(prompt)
log_prompt(prompt, negative_prompt)
# Set logit bias
if logit_bias:
@@ -663,8 +694,18 @@ class ModelContainer:
self.generator.set_stop_conditions(stop_conditions)
# Tokenized context
ids = self.tokenizer.encode(
prompt, add_bos=add_bos_token, encode_special_tokens=True
ids, offsets = self.tokenizer.encode(
[prompt, negative_prompt]
if negative_prompt and gen_settings.cfg_scale not in [None, 1.0]
else prompt,
add_bos=add_bos_token,
encode_special_tokens=True,
return_offsets=True,
)
mask = (
self.tokenizer.padding_mask(ids)
if self.use_cfg and gen_settings.cfg_scale not in [None, 1.0]
else None
)
context_len = len(ids[0])
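
Because the prompt and the negative prompt rarely tokenize to the same length, the batched encode above also returns position offsets, and a padding mask keeps padded positions out of attention. The toy tensors below only illustrate the idea; the actual conventions come from exllamav2's tokenizer (padding_mask, return_offsets).

# Toy illustration of left-padding two unequal sequences into one batch;
# the real mask/offset conventions are exllamav2's, not these exact values.
import torch

pad_id = 0
rows = [torch.tensor([3, 8, 2, 9, 5]), torch.tensor([3, 7, 5])]
max_len = max(r.shape[-1] for r in rows)

ids = torch.full((len(rows), max_len), pad_id)
offsets = torch.zeros((len(rows), 1), dtype=torch.long)
for i, row in enumerate(rows):
    pad = max_len - row.shape[-1]
    ids[i, pad:] = row   # left-pad so both rows end at the same position
    offsets[i] = -pad    # shift positions so real tokens still start at 0

mask = ids.eq(pad_id)    # True where a position is padding
print(ids, offsets, mask, sep="\n")
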
@@ -683,7 +724,7 @@ class ModelContainer:
start_time = time.time()
last_chunk_time = start_time
save_tokens = torch.empty((1, 0), dtype=torch.bool)
save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
chunk_buffer = ""
chunk_tokens = 0
@@ -691,17 +732,31 @@ class ModelContainer:
# Ingest prompt
if chunk_tokens == 0:
ids = torch.cat((ids, save_tokens), dim=-1)
save_tokens = torch.empty((1, 0), dtype=torch.bool)
save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
active_ids = ids[:, max(0, overflow) :]
chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]
self.generator.begin_stream(
active_ids,
gen_settings,
token_healing=token_healing,
loras=self.active_loras,
)
# Split for exllama versions that have CFG
if self.use_cfg:
self.generator.begin_stream(
active_ids,
gen_settings,
token_healing=token_healing,
loras=self.active_loras,
input_mask=mask,
position_offsets=offsets,
)
else:
self.generator.begin_stream(
active_ids,
gen_settings,
token_healing=token_healing,
loras=self.active_loras,
)
# Reset offsets for subsequent passes if the context is truncated
offsets = None
if auto_scale_penalty_range:
gen_settings.token_repetition_range = generated_tokens
@@ -714,7 +769,9 @@ class ModelContainer:
ids[:, -1] = self.generator.sequence_ids[:, -2]
token_healing = False
save_tokens = torch.cat((save_tokens, tokens), dim=-1)
save_tokens = torch.cat(
(save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1
)
chunk_buffer += chunk
generated_tokens += 1
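
The expand at the end is presumably there because the sampler returns a single new token per step, while the CFG batch tracks two sequences; the token has to be appended to both rows so they stay in lockstep. A toy version of that append:

# Toy version of the batched append: one sampled token broadcast across both
# cache rows (prompt and negative prompt) when CFG is active.
import torch

save_tokens = torch.empty((2, 0), dtype=torch.long)  # two rows under CFG
new_token = torch.tensor([[42]])                     # (1, 1) from the sampler

save_tokens = torch.cat(
    (save_tokens, new_token.expand(save_tokens.shape[0], -1)), dim=-1
)
print(save_tokens)  # tensor([[42], [42]])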