Model: Add CFG support
The dynamic generator needed multiple prompts to be tokenized and sent in so that they could be sampled serially but generated in parallel.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent 06ff47e2b4
commit 5f0fb9c4ff

1 changed file with 23 additions and 29 deletions
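Before the hunks: a minimal sketch of the flow the commit message describes, assuming the exllamav2 dynamic-generator API already used in this file (ExLlamaV2DynamicJobAsync, ExLlamaV2Sampler). The standalone helper, its defaults, and the exact import path are illustrative assumptions rather than the actual container code; the attribute names (use_cfg, tokenizer, generator) and the list-of-tensors input_ids mirror the diff below.

# Illustrative sketch only: tokenize the positive and (optional) negative
# prompt separately and hand both to a single dynamic-generator job. The
# generator batches the two sequences itself, which is why the cache in the
# diff is now built with batch_size=1.
from exllamav2.generator import ExLlamaV2DynamicJobAsync, ExLlamaV2Sampler


async def generate_with_cfg(container, prompt, negative_prompt=None, cfg_scale=1.5):
    prompts = [prompt]

    gen_settings = ExLlamaV2Sampler.Settings()
    if container.use_cfg and negative_prompt:
        gen_settings.cfg_scale = cfg_scale
        prompts.append(negative_prompt)

    # Index 0 is always the positive prompt
    input_ids = [
        container.tokenizer.encode(p, add_bos=True, encode_special_tokens=True)
        for p in prompts
    ]

    job = ExLlamaV2DynamicJobAsync(
        container.generator,
        input_ids=input_ids,
        max_new_tokens=256,
        gen_settings=gen_settings,
    )

    text = ""
    async for result in job:
        # Result keys are hedged from the diff's usage of streamed "text" chunks
        text += result.get("text", "")
    return text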
@@ -192,9 +192,9 @@ class ExllamaV2Container:
 
         # Turn off flash attention if CFG is on
         # Workaround until batched FA2 is fixed in exllamav2 upstream
-        self.config.no_flash_attn = (
-            True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
-        )
+        # self.config.no_flash_attn = (
+        #     True if self.use_cfg else unwrap(kwargs.get("no_flash_attention"), False)
+        # )
 
         # Try to set prompt template
         self.prompt_template = self.find_prompt_template(
@@ -509,19 +509,17 @@ class ExllamaV2Container:
             if value:
                 yield value
 
-        batch_size = 2 if self.use_cfg else 1
-
         if self.cache_mode == "Q4":
             self.cache = ExLlamaV2Cache_Q4(
-                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+                self.model, lazy=self.gpu_split_auto, batch_size=1
             )
         elif self.cache_mode == "FP8":
             self.cache = ExLlamaV2Cache_8bit(
-                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+                self.model, lazy=self.gpu_split_auto, batch_size=1
             )
         else:
             self.cache = ExLlamaV2Cache(
-                self.model, lazy=self.gpu_split_auto, batch_size=batch_size
+                self.model, lazy=self.gpu_split_auto, batch_size=1
            )
 
         # Load model with autosplit
@@ -686,6 +684,8 @@ class ExllamaV2Container:
         for kwargs, check common/sampling.py
         """
 
+        prompts = [prompt]
+
         token_healing = unwrap(kwargs.get("token_healing"), False)
         generate_window = max(
             unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8
@@ -748,10 +748,12 @@ class ExllamaV2Container:
                 negative_prompt = unwrap(
                     kwargs.get("negative_prompt"), self.tokenizer.bos_token
                 )
+
+                prompts.append(negative_prompt)
             else:
                 logger.warning(
                     "CFG is currently disabled. "
-                    "Please reload your model with use_cfg = True.",
+                    "If your GPU is supported, reload your model with use_cfg = True"
                 )
 
         gen_settings.token_repetition_penalty = unwrap(
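For reference, what gen_settings.cfg_scale controls downstream: classifier-free guidance blends the next-token distributions produced by the positive and negative prompts. The blend below is the standard CFG formulation in log space, shown only as a sketch of the technique, not exllamav2's exact internal implementation.

import torch


def cfg_blend(positive_logits, negative_logits, cfg_scale):
    # Standard classifier-free guidance mix (illustrative): cfg_scale == 1.0
    # reduces to the positive distribution, while larger values push generation
    # further away from what the negative prompt would produce.
    log_pos = torch.log_softmax(positive_logits, dim=-1)
    log_neg = torch.log_softmax(negative_logits, dim=-1)
    return log_neg + cfg_scale * (log_pos - log_neg)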
@@ -873,22 +875,16 @@ class ExllamaV2Container:
         else:
             stop_conditions += eos_tokens
 
-        # Tokenized context
-        ids, offsets = self.tokenizer.encode(
-            [prompt, negative_prompt]
-            if negative_prompt and gen_settings.cfg_scale not in [None, 1.0]
-            else prompt,
-            add_bos=add_bos_token,
-            encode_special_tokens=True,
-            return_offsets=True,
-        )
-        mask = (
-            self.tokenizer.padding_mask(ids)
-            if self.use_cfg and gen_settings.cfg_scale not in [None, 1.0]
-            else None
-        )
-        context_len = len(ids[0])
+        # Encode both positive and negative prompts
+        input_ids = [
+            self.tokenizer.encode(
+                prompt, add_bos=add_bos_token, encode_special_tokens=True
+            )
+            for prompt in prompts
+        ]
+
+        # The first index will always be the positive prompt
+        context_len = input_ids[0].size(dim=-1)
 
         if context_len > self.config.max_seq_len:
             logger.warning(
                 f"Context length {context_len} is greater than max_seq_len "
@@ -896,12 +892,10 @@ class ExllamaV2Container:
                 "metrics may not be accurate."
             )
 
-        prompt_tokens = ids.shape[-1]
-
         # Automatically set max_tokens to fill up the context
         # This should be an OK default, but may be changed in the future
         max_tokens = unwrap(
-            kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens
+            kwargs.get("max_tokens"), self.config.max_seq_len - context_len
         )
 
         # Set min_tokens to generate while keeping EOS banned
@@ -941,7 +935,7 @@ class ExllamaV2Container:
         job_id = uuid.uuid4().hex
         job = ExLlamaV2DynamicJobAsync(
             self.generator,
-            input_ids=ids,
+            input_ids=input_ids,
             max_new_tokens=max_tokens,
             gen_settings=gen_settings,
             stop_conditions=stop_conditions,
@@ -976,7 +970,7 @@ class ExllamaV2Container:
 
                 generation = {
                     "text": chunk,
-                    "prompt_tokens": prompt_tokens,
+                    "prompt_tokens": context_len,
                     "generated_tokens": generated_tokens,
                     "offset": len(full_response),
                 }
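A hypothetical call site after this commit. The streaming entry point name generate_gen is an assumption about the container's public API; the kwargs (negative_prompt, max_tokens) and the yielded dict keys ("text", "prompt_tokens", "generated_tokens") come straight from the diff, and CFG only activates if the model was loaded with use_cfg = True.

async def run_cfg_request(container):
    # Hypothetical caller sketch; assumes generate_gen is the async streaming
    # generator whose kwargs are read in the hunks above.
    response = ""
    async for generation in container.generate_gen(
        "Write a short story about a lighthouse keeper.",
        negative_prompt="Do not mention the sea.",
        max_tokens=300,
    ):
        response += generation["text"]
    return response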