Model: Fix parsing of stop conditions

Add the EOS token to the stop conditions after reading them from kwargs.
If ban_eos_token is on, the EOS token is banned from generation instead
and is not added to the stop conditions.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2023-11-16 17:15:33 -05:00
parent 282b5b2931
commit d5551352bf

View file

@ -11,7 +11,7 @@ from exllamav2.generator import(
ExLlamaV2StreamingGenerator,
ExLlamaV2Sampler
)
from typing import List, Optional
from typing import List, Optional, Union
# Bytes to reserve on first device when loading with auto split
auto_split_reserve_bytes = 96 * 1024**2
@ -237,7 +237,7 @@ class ModelContainer:
'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
'repetition_range' (int): Repetition penalty range (default: whole context)
'repetition_decay' (int): Repetition penalty range (default: same as range)
'stop' (list): List of stop strings/tokens to end response (default: [EOS])
'stop' (List[Union[str, int]]): List of stop strings/tokens to end response (default: [EOS])
'max_tokens' (int): Max no. tokens in response (default: 150)
'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
@ -271,9 +271,15 @@ class ModelContainer:
gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)
# Ban the EOS token if specified
if kwargs.get("ban_eos_token", False):
stop_conditions: List[Union[str, int]] = kwargs.get("stop", [])
ban_eos_token = kwargs.get("ban_eos_token", False)
# Ban the EOS token if specified. If not, append to stop conditions as well.
if ban_eos_token:
gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
else:
stop_conditions.append(self.tokenizer.eos_token_id)
# Override sampler settings for temp = 0
@ -283,9 +289,9 @@ class ModelContainer:
gen_settings.top_p = 0
gen_settings.typical = 0
# Stop conditions
# Stop conditions
self.generator.set_stop_conditions(kwargs.get("stop", [self.tokenizer.eos_token_id]))
self.generator.set_stop_conditions(stop_conditions)
# Tokenized context