From d5551352bf306d34d41f44741a0f6b2dee8706ea Mon Sep 17 00:00:00 2001
From: kingbri
Date: Thu, 16 Nov 2023 17:15:33 -0500
Subject: [PATCH] Model: Fix parsing of stop conditions

Add the EOS token to the stop conditions after checking kwargs. If
ban_eos_token is on, don't add the EOS token, as an extra measure.

Signed-off-by: kingbri
---
 model.py | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/model.py b/model.py
index 724628f..69d9c2d 100644
--- a/model.py
+++ b/model.py
@@ -11,7 +11,7 @@ from exllamav2.generator import(
     ExLlamaV2StreamingGenerator,
     ExLlamaV2Sampler
 )
-from typing import List, Optional
+from typing import List, Optional, Union
 
 # Bytes to reserve on first device when loading with auto split
 auto_split_reserve_bytes = 96 * 1024**2
@@ -237,7 +237,7 @@ class ModelContainer:
             'repetition_penalty' (float): Token repetition/presence penalty (default: 1.15)
             'repetition_range' (int): Repetition penalty range (default: whole context)
             'repetition_decay' (int): Repetition penalty range (default: same as range)
-            'stop' (list): List of stop strings/tokens to end response (default: [EOS])
+            'stop' (List[Union[str, int]]): List of stop strings/tokens to end response (default: [EOS])
             'max_tokens' (int): Max no. tokens in response (default: 150)
             'add_bos_token' (bool): Adds the BOS token to the start of the prompt (default: True)
             'ban_eos_token' (bool): Bans the EOS token from generation (default: False)
@@ -271,9 +271,15 @@ class ModelContainer:
         gen_settings.token_repetition_range = kwargs.get("repetition_range", self.config.max_seq_len)
         gen_settings.token_repetition_decay = kwargs.get("repetition_decay", gen_settings.token_repetition_range)
 
-        # Ban the EOS token if specified
-        if kwargs.get("ban_eos_token", False):
+        stop_conditions: List[Union[str, int]] = kwargs.get("stop", [])
+        ban_eos_token = kwargs.get("ban_eos_token", False)
+
+        # Ban the EOS token if specified. If not, append to stop conditions as well.
+
+        if ban_eos_token:
             gen_settings.disallow_tokens(self.tokenizer, [self.tokenizer.eos_token_id])
+        else:
+            stop_conditions.append(self.tokenizer.eos_token_id)
 
         # Override sampler settings for temp = 0
 
@@ -283,9 +289,9 @@ class ModelContainer:
             gen_settings.top_p = 0
             gen_settings.typical = 0
 
-        # Stop conditions
+        # Stop conditions
 
-        self.generator.set_stop_conditions(kwargs.get("stop", [self.tokenizer.eos_token_id]))
+        self.generator.set_stop_conditions(stop_conditions)
 
         # Tokenized context
 
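
For reference, below is a minimal standalone sketch of the stop-condition logic
this patch introduces. The helper name resolve_stop_conditions and the toy
eos_token_id value are hypothetical, for illustration only; the real code reads
self.tokenizer.eos_token_id and passes the result to
self.generator.set_stop_conditions().

    from typing import List, Union

    def resolve_stop_conditions(kwargs: dict, eos_token_id: int) -> List[Union[str, int]]:
        # Start from the user-supplied stop strings/tokens, defaulting to empty
        stop_conditions: List[Union[str, int]] = kwargs.get("stop", [])

        # A banned EOS token can never be generated, so it is only treated
        # as a stop condition when it is not banned
        if not kwargs.get("ban_eos_token", False):
            stop_conditions.append(eos_token_id)

        return stop_conditions

    # EOS not banned: user stops plus the EOS token id
    print(resolve_stop_conditions({"stop": ["###"]}, eos_token_id=2))  # ['###', 2]

    # EOS banned: only the user stops remain
    print(resolve_stop_conditions({"stop": ["###"], "ban_eos_token": True}, eos_token_id=2))  # ['###']

One side effect worth noting: kwargs.get("stop", []) returns the caller's own
list when one is supplied, so the append mutates it; copying first, e.g.
list(kwargs.get("stop", [])), would avoid that.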