Model: Add EOS token support from generation_config.json
GenerationConfig is meant to override various model settings at generation time within the transformers library. Rather than implementing the entire GenerationConfig framework (since it's largely redundant), add support for multiple EOS tokens, as vLLM does. The GenerationConfig is currently used only for generation, but it can be used for other purposes if needed. If more parameters become necessary in the future, add those in as well. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
933c5afef0
commit
8824ea0205
1 changed file with 29 additions and 2 deletions
|
|
@ -36,6 +36,7 @@ from common.templating import (
|
|||
get_template_from_model_json,
|
||||
get_template_from_file,
|
||||
)
|
||||
from common.transformers_utils import GenerationConfig
|
||||
from common.utils import coalesce, unwrap
|
||||
|
||||
|
||||
|
|
@ -57,6 +58,7 @@ class ExllamaV2Container:
|
|||
# Internal config vars
|
||||
cache_mode: str = "FP16"
|
||||
use_cfg: bool = False
|
||||
generation_config: Optional[GenerationConfig] = None
|
||||
|
||||
# GPU split vars
|
||||
gpu_split: Optional[list] = None
|
||||
|
|
@ -193,6 +195,21 @@ class ExllamaV2Container:
|
|||
kwargs.get("prompt_template"), model_directory
|
||||
)
|
||||
|
||||
# Load generation config overrides
|
||||
generation_config_path = (
|
||||
pathlib.Path(self.config.model_dir) / "generation_config.json"
|
||||
)
|
||||
if generation_config_path.exists():
|
||||
try:
|
||||
self.generation_config = GenerationConfig.from_file(
|
||||
generation_config_path.parent
|
||||
)
|
||||
except Exception:
|
||||
logger.error(traceback.format_exc())
|
||||
logger.warning(
|
||||
"Skipping generation config load because of an unexpected error."
|
||||
)
|
||||
|
||||
# Catch all for template lookup errors
|
||||
if self.prompt_template:
|
||||
logger.info(
|
||||
|
|
@ -566,6 +583,7 @@ class ExllamaV2Container:
|
|||
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
|
||||
)[0]
|
||||
|
||||
# TODO: Maybe support generation_config for eos_token
|
||||
def get_special_tokens(
|
||||
self, add_bos_token: bool = True, ban_eos_token: bool = False
|
||||
):
|
||||
|
|
@ -840,13 +858,20 @@ class ExllamaV2Container:
|
|||
grammar_string, gen_settings, self.model, self.tokenizer
|
||||
)
|
||||
|
||||
# Fetch EOS tokens from generation_config if they exist
|
||||
eos_tokens = (
|
||||
self.generation_config.eos_tokens()
|
||||
if self.generation_config
|
||||
else [self.tokenizer.eos_token_id]
|
||||
)
|
||||
|
||||
# Ban the EOS token if specified. If not, append to stop conditions
|
||||
# as well.
|
||||
# Set this below logging to avoid polluting the stop strings array
|
||||
if ban_eos_token:
|
||||
gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
|
||||
gen_settings.disallow_tokens(self.tokenizer, eos_tokens)
|
||||
else:
|
||||
stop_conditions.append(self.tokenizer.eos_token_id)
|
||||
stop_conditions += eos_tokens
|
||||
|
||||
# Stop conditions
|
||||
self.generator.set_stop_conditions(stop_conditions)
|
||||
|
|
@ -891,6 +916,8 @@ class ExllamaV2Container:
|
|||
token_healing=token_healing,
|
||||
auto_scale_penalty_range=auto_scale_penalty_range,
|
||||
generate_window=generate_window,
|
||||
bos_token_id=self.tokenizer.bos_token_id,
|
||||
eos_token_id=eos_tokens,
|
||||
add_bos_token=add_bos_token,
|
||||
ban_eos_token=ban_eos_token,
|
||||
speculative_ngram=self.generator.speculative_ngram,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue