Model: Fix parsing of stop conditions
Add the EOS token to the stop strings after checking kwargs. If ban_eos_token is on, skip adding the EOS token, for good measure.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 282b5b2931
commit d5551352bf

1 changed file with 12 additions and 6 deletions
model.py (+12, -6)
@@ -11,7 +11,7 @@ from exllamav2.generator import(
     ExLlamaV2StreamingGenerator,
     ExLlamaV2Sampler
 )
-from typing import List, Optional
+from typing import List, Optional, Union

 # Bytes to reserve on first device when loading with auto split
 auto_split_reserve_bytes = 96 * 1024**2
@@ -237,7 +237,7 @@ class ModelContainer:
             'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
             'repetition_range' (int): Repetition penalty range (default: whole context)
             'repetition_decay' (int): Repetition penalty range (default: same as range)
-            'stop' (list): List of stop strings/tokens to end response (default: [EOS])
+            'stop' (List[Union[str, int]]): List of stop strings/tokens to end response (default: [EOS])
             'max_tokens' (int): Max no. tokens in response (default: 150)
             'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
             'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
@@ -271,9 +271,15 @@ class ModelContainer:
         gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
         gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)

-        # Ban the EOS token if specified
-        if kwargs.get("ban_eos_token", False):
+        stop_conditions: List[Union[str, int]] = kwargs.get("stop", [])
+        ban_eos_token = kwargs.get("ban_eos_token", False)
+
+        # Ban the EOS token if specified. If not, append to stop conditions as well.
+        if ban_eos_token:
             gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
+        else:
+            stop_conditions.append(self.tokenizer.eos_token_id)

         # Override sampler settings for temp = 0
@@ -283,9 +289,9 @@ class ModelContainer:
             gen_settings.top_p = 0
             gen_settings.typical = 0

         # Stop conditions
-        self.generator.set_stop_conditions(kwargs.get("stop", [self.tokenizer.eos_token_id]))
+        self.generator.set_stop_conditions(stop_conditions)

         # Tokenized context
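For reference, the net effect of the patch as a standalone sketch. Previously, passing a custom "stop" list replaced the default [EOS] entirely, so generation could run past the EOS token; now EOS is appended to whatever the caller supplies unless ban_eos_token is set (a banned token can never be sampled, so it is disallowed instead). The helper name and the literal EOS id below are illustrative only, not part of the codebase:

    from typing import List, Union

    def build_stop_conditions(kwargs: dict, eos_token_id: int) -> List[Union[str, int]]:
        # Start from the caller-supplied stop strings/token ids, if any
        stop_conditions: List[Union[str, int]] = kwargs.get("stop", [])

        # A banned EOS token can never be sampled, so adding it as a stop
        # condition would be pointless; otherwise always include it so a
        # custom stop list no longer drops EOS (the old bug).
        if not kwargs.get("ban_eos_token", False):
            stop_conditions.append(eos_token_id)

        return stop_conditions

    # A custom stop string plus a hypothetical EOS id of 2
    print(build_stop_conditions({"stop": ["\n\n"]}, eos_token_id=2))  # ['\n\n', 2]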