API: Ignore add_bos_token in chat completions

When fetching special tokens from the model, don't factor in the
add_bos_token and ban_eos_token parameters as switches.

In addition, change the internal handling of add_bos_token to an optional
boolean. This allows us to fall back to the model when selecting whether
or not to add the BOS token, especially for chat completions.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-05-01 22:51:15 -04:00
parent 3960612d38
commit aa657fa6e9
4 changed files with 13 additions and 17 deletions

View file

@@ -123,7 +123,7 @@ class BaseModelContainer(abc.ABC):
pass
@abc.abstractmethod
def get_special_tokens(self, **kwargs) -> Dict[str, Any]:
def get_special_tokens(self) -> Dict[str, Any]:
"""
Gets special tokens used by the model/tokenizer.

View file

@@ -843,12 +843,10 @@ class ExllamaV2Container(BaseModelContainer):
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
)[0]
def get_special_tokens(
self, add_bos_token: bool = True, ban_eos_token: bool = False
):
def get_special_tokens(self):
return {
"bos_token": self.tokenizer.bos_token if add_bos_token else "",
"eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
"bos_token": self.tokenizer.bos_token,
"eos_token": self.tokenizer.eos_token,
"pad_token": self.tokenizer.pad_token,
"unk_token": self.tokenizer.unk_token,
}
@@ -1242,7 +1240,7 @@ class ExllamaV2Container(BaseModelContainer):
) and gen_settings.token_repetition_range == -1
stop_conditions = params.stop
add_bos_token = params.add_bos_token
add_bos_token = unwrap(params.add_bos_token, True)
ban_eos_token = params.ban_eos_token
# Fetch EOS tokens from generation_config if they exist

View file

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from pydantic.json_schema import SkipJsonSchema
from time import time
from typing import Literal, Union, List, Optional, Dict
@@ -82,6 +82,12 @@ class ChatCompletionRequest(CommonCompletionRequest):
tool_call_end: SkipJsonSchema[Optional[str]] = None
tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
@field_validator("add_bos_token", mode="after")
def force_bos_token(cls, v):
"""Always disable add_bos_token with chat completions."""
return None
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")

View file

@@ -238,9 +238,7 @@ async def format_messages_with_template(
# store the list of dicts rather than the ToolCallProcessor object.
message.tool_calls = ToolCallProcessor.dump(message.tool_calls)
special_tokens_dict = model.container.get_special_tokens(
add_bos_token, ban_eos_token
)
special_tokens_dict = model.container.get_special_tokens()
template_vars.update({"messages": messages, **special_tokens_dict})
@@ -285,12 +283,6 @@ async def apply_chat_template(
"add_generation_prompt is False"
)
# Removes the starting BOS token if present
# This is to prevent add_bos_token from adding multiple bos tokens
bos_token = template_vars.get("bos_token")
if bos_token and prompt.startswith(bos_token):
prompt = prompt.removeprefix(bos_token)
# Add template metadata
await _append_template_metadata(data, template_vars)