Model: Add CFG support

The dynamic generator needs multiple prompts to be tokenized and sent
as a batch so they can be sampled in serial, but generated in parallel.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-05-24 21:03:33 -04:00 committed by Brian Dashore
parent 06ff47e2b4
commit 5f0fb9c4ff

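Before the diff, a minimal sketch of the flow the commit message describes, assuming exllamav2's dynamic generator API (ExLlamaV2DynamicJobAsync, as used in the diff below) and an already loaded tokenizer/generator pair; use_cfg, cfg_scale, and the other free variables are illustrative stand-ins, not the container's exact code:

# Sketch only: batch a positive prompt with an optional CFG negative prompt
# and hand the whole list to a single dynamic-generator job.
from exllamav2.generator import ExLlamaV2DynamicJobAsync

prompts = [prompt]  # the positive prompt is always first
if use_cfg and cfg_scale not in [None, 1.0]:
    prompts.append(negative_prompt)  # CFG adds a second, negative prompt

# Tokenize each prompt separately; no shared padding mask or offsets needed
input_ids = [
    tokenizer.encode(p, add_bos=True, encode_special_tokens=True)
    for p in prompts
]

# The dynamic generator samples the prompts in serial
# but generates their tokens in parallel
job = ExLlamaV2DynamicJobAsync(
    generator,
    input_ids=input_ids,
    max_new_tokens=max_tokens,
    gen_settings=gen_settings,
    stop_conditions=stop_conditions,
)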

@@ -192,9 +192,9 @@ class ExllamaV2Container:
         # Turn off flash attention if CFG is on
         # Workaround until batched FA2 is fixed in exllamav2 upstream
-        self.config.no_flash_attn = (
-            True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
-        )
+        # self.config.no_flash_attn = (
+        #     True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
+        # )

         # Try to set prompt template
         self.prompt_template = self.find_prompt_template(
@@ -509,19 +509,17 @@ class ExllamaV2Container:
                 if value:
                     yield value

-        batch_size = 2 if self.use_cfg else 1
         if self.cache_mode == "Q4":
             self.cache = ExLlamaV2Cache_Q4(
-                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+                self.model, lazy=self.gpu_split_auto, batch_size=1
             )
         elif self.cache_mode == "FP8":
             self.cache = ExLlamaV2Cache_8bit(
-                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+                self.model, lazy=self.gpu_split_auto, batch_size=1
             )
         else:
             self.cache = ExLlamaV2Cache(
-                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+                self.model, lazy=self.gpu_split_auto, batch_size=1
             )

         # Load model with autosplit
@@ -686,6 +684,8 @@ class ExllamaV2Container:
         for kwargs, check common/sampling.py
         """

+        prompts = [prompt]
+
         token_healing = unwrap(kwargs.get("token_healing"), False)
         generate_window = max(
             unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8
@@ -748,10 +748,12 @@ class ExllamaV2Container:
                 negative_prompt = unwrap(
                     kwargs.get("negative_prompt"), self.tokenizer.bos_token
                 )
+                prompts.append(negative_prompt)
             else:
                 logger.warning(
                     "CFG is currently disabled. "
-                    "Please reload your model with use_cfg = True.",
+                    "If your GPU is supported, reload your model with use_cfg = True"
                 )

         gen_settings.token_repetition_penalty = unwrap(
@@ -873,22 +875,16 @@ class ExllamaV2Container:
         else:
             stop_conditions += eos_tokens

-        # Tokenized context
-        ids, offsets = self.tokenizer.encode(
-            [prompt, negative_prompt]
-            if negative_prompt and gen_settings.cfg_scale not in [None, 1.0]
-            else prompt,
-            add_bos=add_bos_token,
-            encode_special_tokens=True,
-            return_offsets=True,
-        )
-        mask = (
-            self.tokenizer.padding_mask(ids)
-            if self.use_cfg and gen_settings.cfg_scale not in [None, 1.0]
-            else None
-        )
-        context_len = len(ids[0])
+        # Encode both positive and negative prompts
+        input_ids = [
+            self.tokenizer.encode(
+                prompt, add_bos=add_bos_token, encode_special_tokens=True
+            )
+            for prompt in prompts
+        ]
+
+        # The first index will always be the positive prompt
+        context_len = input_ids[0].size(dim=-1)

         if context_len > self.config.max_seq_len:
             logger.warning(
                 f"Context length {context_len} is greater than max_seq_len "
@@ -896,12 +892,10 @@ class ExllamaV2Container:
                 "metrics may not be accurate."
             )

-        prompt_tokens = ids.shape[-1]
-
         # Automatically set max_tokens to fill up the context
         # This should be an OK default, but may be changed in the future
         max_tokens = unwrap(
-            kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens
+            kwargs.get("max_tokens"), self.config.max_seq_len - context_len
         )

         # Set min_tokens to generate while keeping EOS banned
@@ -941,7 +935,7 @@ class ExllamaV2Container:
         job_id = uuid.uuid4().hex
         job = ExLlamaV2DynamicJobAsync(
             self.generator,
-            input_ids=ids,
+            input_ids=input_ids,
             max_new_tokens=max_tokens,
             gen_settings=gen_settings,
             stop_conditions=stop_conditions,
@@ -976,7 +970,7 @@ class ExllamaV2Container:
                 generation = {
                     "text": chunk,
-                    "prompt_tokens": prompt_tokens,
+                    "prompt_tokens": context_len,
                     "generated_tokens": generated_tokens,
                     "offset": len(full_response),
                 }
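
As a closing note, a hedged illustration of the token-accounting change above: context_len is taken from the first entry of input_ids (the positive prompt), so the prompt_tokens value reported with each chunk never counts the CFG negative prompt. The prompt strings below are made up, and tokenizer/max_seq_len are assumed from the container:

# Hypothetical prompt pair; the negative prompt defaults to the BOS token
prompts = ["Tell me a story.", "<s>"]
input_ids = [
    tokenizer.encode(p, add_bos=True, encode_special_tokens=True)
    for p in prompts
]

# Only the positive prompt (index 0) sets the reported context length
context_len = input_ids[0].size(dim=-1)
max_tokens = max_seq_len - context_len  # default: fill the remaining context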