Model: Add CFG support
CFG, or classifier-free guidance, helps push a model in different directions based on what the user provides. Currently, CFG is ignored if the negative prompt is blank (it shouldn't be used that way anyway).

Signed-off-by: kingbri <bdashore3@proton.me>
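To make that behavior concrete, here is a minimal sketch of the gating this commit spreads across model.py, condensed into one hypothetical helper (should_apply_cfg is not part of the codebase, just an illustration): CFG only takes effect when the model was loaded with use_cfg, the request's cfg_scale is something other than 1.0, and a non-blank negative_prompt is supplied.

    from typing import Optional

    def should_apply_cfg(use_cfg: bool, cfg_scale: Optional[float],
                         negative_prompt: Optional[str]) -> bool:
        """Hypothetical helper summarizing the intended gating; the real
        checks live in gen-settings setup and prompt encoding."""
        if not use_cfg:
            return False
        if cfg_scale in (None, 1.0):
            return False
        return bool(negative_prompt)

    # Behavior described in the commit message:
    assert should_apply_cfg(True, 3.0, "low quality") is True
    assert should_apply_cfg(True, 3.0, "") is False   # blank negative prompt -> ignored
    assert should_apply_cfg(True, 1.0, "low quality") is False
    assert should_apply_cfg(False, 3.0, "low quality") is False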
parent bb7a8e4614
commit b378773d0a

6 changed files with 96 additions and 18 deletions
@@ -75,6 +75,7 @@ class CommonCompletionRequest(BaseModel):
     add_bos_token: Optional[bool] = True
     ban_eos_token: Optional[bool] = False
     logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]])
+    negative_prompt: Optional[str] = None

     # Aliased variables
     penalty_range: Optional[int] = Field(

@@ -86,6 +87,10 @@ class CommonCompletionRequest(BaseModel):
         ),
     )

+    cfg_scale: Optional[float] = Field(
+        default=1.0, validation_alias=AliasChoices("cfg_scale", "guidance_scale")
+    )
+
     def to_gen_params(self):
         """Converts to internal generation parameters."""
         # Convert stop to an array of strings

@@ -115,4 +120,6 @@ class CommonCompletionRequest(BaseModel):
             "mirostat": self.mirostat_mode == 2,
             "mirostat_tau": self.mirostat_tau,
             "mirostat_eta": self.mirostat_eta,
+            "cfg_scale": self.cfg_scale,
+            "negative_prompt": self.negative_prompt,
         }
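For reference, the new request fields above could be exercised with a body like the following. This is a hypothetical example, not taken from the repository; prompt and max_tokens stand in for the existing completion fields, and guidance_scale is accepted as an alias for cfg_scale per the AliasChoices declaration.

    # Hypothetical completion request body; only the CFG-related fields are new here.
    request_body = {
        "prompt": "Write a short story about a lighthouse keeper.",
        "max_tokens": 200,                       # assumed existing field
        "cfg_scale": 2.5,                        # or "guidance_scale": 2.5 (alias)
        "negative_prompt": "poor grammar, repetition",
    }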
@@ -83,6 +83,7 @@ class ModelLoadRequest(BaseModel):
     cache_mode: Optional[str] = "FP16"
     prompt_template: Optional[str] = None
     num_experts_per_token: Optional[int] = None
+    use_cfg: Optional[bool] = None
     draft: Optional[DraftModelLoadRequest] = None
args.py (5 changed lines)
@@ -106,6 +106,11 @@ def add_model_args(parser: argparse.ArgumentParser):
         type=int,
         help="Number of experts to use per token in MoE models",
     )
+    model_group.add_argument(
+        "--use-cfg",
+        type=str_to_bool,
+        help="Enables CFG support",
+    )


 def add_logging_args(parser: argparse.ArgumentParser):
@@ -85,6 +85,10 @@ model:
   # NOTE: For MoE models (ex. Mixtral) only!
   #num_experts_per_token:

+  # Enables CFG support (default: False)
+  # WARNING: This flag disables Flash Attention! (a stopgap fix until it's fixed in upstream)
+  use_cfg: False
+
   # Options for draft models (speculative decoding). This will use more VRAM!
   #draft:
     # Overrides the directory to look for draft (default: models)
@@ -1,8 +1,8 @@
 """
 Functions for logging generation events.
 """
-from typing import Dict
 from pydantic import BaseModel
+from typing import Dict, Optional

 from logger import init_logger
@@ -53,12 +53,16 @@ def log_generation_params(**kwargs):
         logger.info(f"Generation options: {kwargs}\n")


-def log_prompt(prompt: str):
+def log_prompt(prompt: str, negative_prompt: Optional[str]):
     """Logs the prompt to console."""
     if PREFERENCES.prompt:
         formatted_prompt = "\n" + prompt
         logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")

+        if negative_prompt:
+            formatted_negative_prompt = "\n" + negative_prompt
+            logger.info(f"Negative Prompt: {formatted_negative_prompt}\n")
+

 def log_response(response: str):
     """Logs the response to console."""
model.py (89 changed lines)
@@ -47,6 +47,7 @@ class ModelContainer:
     cache_fp8: bool = False
     gpu_split_auto: bool = True
     gpu_split: Optional[list] = None
+    use_cfg: bool = False

     active_loras: List[ExLlamaV2Lora] = []
@@ -95,6 +96,8 @@ class ModelContainer:
                 tensors, per device
             'no_flash_attn' (bool): Turns off flash attention
                 (increases vram usage) (default: False)
+            'use_cfg' (bool): Enables CFG support. Disables flash attention
+                (default: False)
         """

         self.quiet = quiet
@@ -135,8 +138,18 @@ class ModelContainer:
             kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
         )

-        # Turn off flash attention?
-        self.config.no_flash_attn = unwrap(kwargs.get("no_flash_attention"), False)
+        if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"):
+            self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
+        else:
+            logger.warning(
+                "CFG is not supported by the currently installed ExLlamaV2 version."
+            )
+
+        # Turn off flash attention if CFG is on
+        # Workaround until batched FA2 is fixed in exllamav2 upstream
+        self.config.no_flash_attn = (
+            True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
+        )

         # low_mem is currently broken in exllamav2. Don't use it until it's
         # fixed.
@@ -348,10 +361,15 @@ class ModelContainer:
            if isinstance(value, str):
                yield value

+        batch_size = 2 if self.use_cfg else 1
         if self.cache_fp8:
-            self.cache = ExLlamaV2Cache_8bit(self.model, lazy=self.gpu_split_auto)
+            self.cache = ExLlamaV2Cache_8bit(
+                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+            )
         else:
-            self.cache = ExLlamaV2Cache(self.model, lazy=self.gpu_split_auto)
+            self.cache = ExLlamaV2Cache(
+                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+            )

         if self.gpu_split_auto:
             reserve = [AUTO_SPLIT_RESERVE_BYTES] + [0] * 16
@@ -561,6 +579,19 @@ class ModelContainer:
         gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
         gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)

+        # Set CFG scale and negative prompt
+        cfg_scale = unwrap(kwargs.get("cfg_scale"), 1.0)
+        negative_prompt = None
+        if cfg_scale not in [None, 1.0]:
+            if self.use_cfg:
+                gen_settings.cfg_scale = cfg_scale
+                negative_prompt = kwargs.get("negative_prompt")
+            else:
+                logger.warn(
+                    "CFG is currently disabled. "
+                    + "Please reload your model with use_cfg = True.",
+                )
+
         gen_settings.token_presence_penalty = unwrap(
             kwargs.get("presence_penalty"), 0.0
         )
@@ -635,7 +666,7 @@ class ModelContainer:
         )

         # Log prompt to console
-        log_prompt(prompt)
+        log_prompt(prompt, negative_prompt)

         # Set logit bias
         if logit_bias:
@@ -663,8 +694,18 @@ class ModelContainer:
         self.generator.set_stop_conditions(stop_conditions)

         # Tokenized context
-        ids = self.tokenizer.encode(
-            prompt, add_bos=add_bos_token, encode_special_tokens=True
+        ids, offsets = self.tokenizer.encode(
+            [prompt, negative_prompt]
+            if negative_prompt and gen_settings.cfg_scale not in [None, 1.0]
+            else prompt,
+            add_bos=add_bos_token,
+            encode_special_tokens=True,
+            return_offsets=True,
         )
+        mask = (
+            self.tokenizer.padding_mask(ids)
+            if self.use_cfg and gen_settings.cfg_scale not in [None, 1.0]
+            else None
+        )
         context_len = len(ids[0])
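When CFG is active, the prompt and negative prompt are encoded together as a batch of two rows (which is why the cache earlier is allocated with batch_size = 2), and because the two sequences usually differ in length, the padding mask and position offsets keep the shorter row aligned. Below is a rough, illustrative sketch of that padding idea in plain torch; it is not the tokenizer's real padding_mask implementation, and the pad id and boolean mask format are assumptions.

    import torch

    # Two hypothetical tokenized sequences of different lengths
    prompt_ids = torch.tensor([1, 15, 27, 99, 3])
    negative_ids = torch.tensor([1, 42, 3])

    pad_id = 0  # assumed pad token id
    max_len = max(prompt_ids.shape[-1], negative_ids.shape[-1])

    def left_pad(seq, length, pad):
        # Left-pad a 1D id tensor so real tokens stay adjacent to the generated text
        return torch.cat((torch.full((length - seq.shape[-1],), pad), seq))

    # Batch of 2: row 0 is the prompt, row 1 is the negative prompt
    ids = torch.stack((left_pad(prompt_ids, max_len, pad_id),
                       left_pad(negative_ids, max_len, pad_id)))

    mask = ids != pad_id  # boolean mask marking real (non-pad) positions
    print(ids.shape)      # torch.Size([2, 5])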
@@ -683,7 +724,7 @@ class ModelContainer:
         start_time = time.time()
         last_chunk_time = start_time

-        save_tokens = torch.empty((1, 0), dtype=torch.bool)
+        save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
         chunk_buffer = ""
         chunk_tokens = 0
@@ -691,17 +732,31 @@ class ModelContainer:
             # Ingest prompt
             if chunk_tokens == 0:
                 ids = torch.cat((ids, save_tokens), dim=-1)
-                save_tokens = torch.empty((1, 0), dtype=torch.bool)
+                save_tokens = torch.empty((ids.shape[0], 0), dtype=torch.bool)
                 overflow = ids.shape[-1] + generate_window - self.config.max_seq_len
                 active_ids = ids[:, max(0, overflow) :]
                 chunk_tokens = self.config.max_seq_len - active_ids.shape[-1]

-                self.generator.begin_stream(
-                    active_ids,
-                    gen_settings,
-                    token_healing=token_healing,
-                    loras=self.active_loras,
-                )
+                # Split for exllama versions that have CFG
+                if self.use_cfg:
+                    self.generator.begin_stream(
+                        active_ids,
+                        gen_settings,
+                        token_healing=token_healing,
+                        loras=self.active_loras,
+                        input_mask=mask,
+                        position_offsets=offsets,
+                    )
+                else:
+                    self.generator.begin_stream(
+                        active_ids,
+                        gen_settings,
+                        token_healing=token_healing,
+                        loras=self.active_loras,
+                    )
+
+                # Reset offsets for subsequent passes if the context is truncated
+                offsets = None

             if auto_scale_penalty_range:
                 gen_settings.token_repetition_range = generated_tokens
@@ -714,7 +769,9 @@ class ModelContainer:
                 ids[:, -1] = self.generator.sequence_ids[:, -2]
                 token_healing = False

-            save_tokens = torch.cat((save_tokens, tokens), dim=-1)
+            save_tokens = torch.cat(
+                (save_tokens, tokens.expand(save_tokens.shape[0], -1)), dim=-1
+            )
             chunk_buffer += chunk

             generated_tokens += 1