From 242f6b7d2a55f150e3dcc19bae38a6f11b3a000c Mon Sep 17 00:00:00 2001 From: kingbri <8082010+kingbri1@users.noreply.github.com> Date: Fri, 2 May 2025 21:30:18 -0400 Subject: [PATCH] Model: Simplify add_bos_token handling Set add_bos_token to True by default in the tokenizer_config stub. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com> --- backends/base_model_container.py | 3 --- backends/exllamav2/model.py | 11 +++++------ common/transformers_utils.py | 2 +- 3 files changed, 6 insertions(+), 10 deletions(-) diff --git a/backends/base_model_container.py b/backends/base_model_container.py index 631bfbc..5c79867 100644 --- a/backends/base_model_container.py +++ b/backends/base_model_container.py @@ -127,9 +127,6 @@ class BaseModelContainer(abc.ABC): """ Gets special tokens used by the model/tokenizer. - Args: - **kwargs: Options like add_bos_token, ban_eos_token. - Returns: A dictionary mapping special token names (e.g., 'bos_token', 'eos_token') to their string or ID representation. diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index e677f04..b821d1a 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -838,7 +838,9 @@ class ExllamaV2Container(BaseModelContainer): return ( self.tokenizer.encode( text, - add_bos=unwrap(kwargs.get("add_bos_token"), True), + add_bos=unwrap( + kwargs.get("add_bos_token"), self.tokenizer_config.add_bos_token + ), encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True), embeddings=mm_embeddings_content, ) @@ -1254,14 +1256,11 @@ class ExllamaV2Container(BaseModelContainer): stop_conditions = params.stop ban_eos_token = params.ban_eos_token - print(self.tokenizer_config.add_bos_token) # Set add_bos_token for generation - add_bos_token = coalesce( - params.add_bos_token, self.tokenizer_config.add_bos_token, True + add_bos_token = unwrap( + params.add_bos_token, self.tokenizer_config.add_bos_token ) - print(add_bos_token) - # Fetch EOS tokens from generation_config if they exist eos_tokens = ( self.generation_config.eos_tokens() diff --git a/common/transformers_utils.py b/common/transformers_utils.py index 32ff52e..045312c 100644 --- a/common/transformers_utils.py +++ b/common/transformers_utils.py @@ -60,7 +60,7 @@ class TokenizerConfig(BaseModel): An abridged version of HuggingFace's tokenizer config. """ - add_bos_token: Optional[bool] = None + add_bos_token: Optional[bool] = True @classmethod async def from_file(cls, model_directory: pathlib.Path):