Model: Add EOS token support from generation_config.json

GenerationConfig is meant to override various parts of the model
during generation within the transformers library. Rather than
implementing the entire GenerationConfig framework (most of it is
redundant here), add support for multiple EOS tokens, like vLLM does.

The GenerationConfig is currently used only for generation, but it
can be extended to other purposes if needed.

If more parameters become necessary in the future, they can be
added as well.
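
For reference, the file being read is the standard Hugging Face
generation_config.json that ships beside the model weights. A model
with multiple EOS tokens declares them as a list; the values below
are illustrative, not taken from a specific model:

    {
      "bos_token_id": 128000,
      "eos_token_id": [128001, 128009],
      "temperature": 0.6,
      "top_p": 0.9
    }

Older models often store a single integer in eos_token_id instead of
a list, so the parsing code needs to accept both forms.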

Signed-off-by: kingbri <bdashore3@proton.me>
commit 8824ea0205
parent 933c5afef0
Author: kingbri <bdashore3@proton.me>
Date:   2024-04-19 22:52:32 -04:00


@@ -36,6 +36,7 @@ from common.templating import (
     get_template_from_model_json,
     get_template_from_file,
 )
+from common.transformers_utils import GenerationConfig
 from common.utils import coalesce, unwrap
@@ -57,6 +58,7 @@ class ExllamaV2Container:
     # Internal config vars
     cache_mode: str = "FP16"
     use_cfg: bool = False
+    generation_config: Optional[GenerationConfig] = None

     # GPU split vars
     gpu_split: Optional[list] = None
@@ -193,6 +195,21 @@ class ExllamaV2Container:
             kwargs.get("prompt_template"), model_directory
         )

+        # Load generation config overrides
+        generation_config_path = (
+            pathlib.Path(self.config.model_dir) / "generation_config.json"
+        )
+        if generation_config_path.exists():
+            try:
+                self.generation_config = GenerationConfig.from_file(
+                    generation_config_path.parent
+                )
+            except Exception:
+                logger.error(traceback.format_exc())
+                logger.warning(
+                    "Skipping generation config load because of an unexpected error."
+                )
+
         # Catch all for template lookup errors
         if self.prompt_template:
             logger.info(
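
Note: the GenerationConfig helper imported above lives in
common/transformers_utils and is not part of this diff. A minimal
sketch of the interface the loading code relies on, assuming only the
two members used here (from_file and eos_tokens); the real class may
carry more fields:

    import json
    import pathlib
    from typing import List, Optional, Union

    class GenerationConfig:
        """Sketch of the generation_config.json wrapper; not the actual source."""

        def __init__(self, eos_token_id: Optional[Union[int, List[int]]] = None, **kwargs):
            self.eos_token_id = eos_token_id

        @classmethod
        def from_file(cls, model_directory: pathlib.Path):
            # The loader passes the model directory, not the file path itself
            with open(model_directory / "generation_config.json", encoding="utf8") as f:
                return cls(**json.load(f))

        def eos_tokens(self) -> List[int]:
            # HF allows eos_token_id to be a single int or a list; normalize
            if self.eos_token_id is None:
                return []
            if isinstance(self.eos_token_id, list):
                return self.eos_token_id
            return [self.eos_token_id]
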
@@ -566,6 +583,7 @@
             decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
         )[0]

+    # TODO: Maybe support generation_config for eos_token
     def get_special_tokens(
         self, add_bos_token: bool = True, ban_eos_token: bool = False
     ):
@@ -840,13 +858,20 @@
             grammar_string, gen_settings, self.model, self.tokenizer
         )

+        # Fetch EOS tokens from generation_config if they exist
+        eos_tokens = (
+            self.generation_config.eos_tokens()
+            if self.generation_config
+            else [self.tokenizer.eos_token_id]
+        )
+
         # Ban the EOS token if specified. If not, append to stop conditions
         # as well.
         # Set this below logging to avoid polluting the stop strings array
         if ban_eos_token:
-            gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
+            gen_settings.disallow_tokens(self.tokenizer, eos_tokens)
         else:
-            stop_conditions.append(self.tokenizer.eos_token_id)
+            stop_conditions += eos_tokens

         # Stop conditions
         self.generator.set_stop_conditions(stop_conditions)
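
With that change, every EOS id from generation_config is either banned
or appended to the stop conditions. A quick illustration with
hypothetical ids; ExLlamaV2's set_stop_conditions accepts a mix of
token ids and stop strings:

    # Hypothetical values; real ids come from the model's generation_config.json
    eos_tokens = [128001, 128009]
    stop_conditions = ["\n###"]    # user-supplied stop strings
    stop_conditions += eos_tokens  # -> ["\n###", 128001, 128009]
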
@@ -891,6 +916,8 @@
             token_healing=token_healing,
             auto_scale_penalty_range=auto_scale_penalty_range,
             generate_window=generate_window,
+            bos_token_id=self.tokenizer.bos_token_id,
+            eos_token_id=eos_tokens,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
             speculative_ngram=self.generator.speculative_ngram,