Model: Simplify add_bos_token handling
Set add_bos_token to True by default in the tokenizer_config stub. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
4cb3e5d5b1
commit
242f6b7d2a
3 changed files with 6 additions and 10 deletions
|
|
@ -127,9 +127,6 @@ class BaseModelContainer(abc.ABC):
|
||||||
"""
|
"""
|
||||||
Gets special tokens used by the model/tokenizer.
|
Gets special tokens used by the model/tokenizer.
|
||||||
|
|
||||||
Args:
|
|
||||||
**kwargs: Options like add_bos_token, ban_eos_token.
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary mapping special token names (e.g., 'bos_token', 'eos_token')
|
A dictionary mapping special token names (e.g., 'bos_token', 'eos_token')
|
||||||
to their string or ID representation.
|
to their string or ID representation.
|
||||||
|
|
|
||||||
|
|
@ -838,7 +838,9 @@ class ExllamaV2Container(BaseModelContainer):
|
||||||
return (
|
return (
|
||||||
self.tokenizer.encode(
|
self.tokenizer.encode(
|
||||||
text,
|
text,
|
||||||
add_bos=unwrap(kwargs.get("add_bos_token"), True),
|
add_bos=unwrap(
|
||||||
|
kwargs.get("add_bos_token"), self.tokenizer_config.add_bos_token
|
||||||
|
),
|
||||||
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
|
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
|
||||||
embeddings=mm_embeddings_content,
|
embeddings=mm_embeddings_content,
|
||||||
)
|
)
|
||||||
|
|
@ -1254,14 +1256,11 @@ class ExllamaV2Container(BaseModelContainer):
|
||||||
stop_conditions = params.stop
|
stop_conditions = params.stop
|
||||||
ban_eos_token = params.ban_eos_token
|
ban_eos_token = params.ban_eos_token
|
||||||
|
|
||||||
print(self.tokenizer_config.add_bos_token)
|
|
||||||
# Set add_bos_token for generation
|
# Set add_bos_token for generation
|
||||||
add_bos_token = coalesce(
|
add_bos_token = unwrap(
|
||||||
params.add_bos_token, self.tokenizer_config.add_bos_token, True
|
params.add_bos_token, self.tokenizer_config.add_bos_token
|
||||||
)
|
)
|
||||||
|
|
||||||
print(add_bos_token)
|
|
||||||
|
|
||||||
# Fetch EOS tokens from generation_config if they exist
|
# Fetch EOS tokens from generation_config if they exist
|
||||||
eos_tokens = (
|
eos_tokens = (
|
||||||
self.generation_config.eos_tokens()
|
self.generation_config.eos_tokens()
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,7 @@ class TokenizerConfig(BaseModel):
|
||||||
An abridged version of HuggingFace's tokenizer config.
|
An abridged version of HuggingFace's tokenizer config.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
add_bos_token: Optional[bool] = None
|
add_bos_token: Optional[bool] = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_file(cls, model_directory: pathlib.Path):
|
async def from_file(cls, model_directory: pathlib.Path):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue