From aa657fa6e913b0f43af845bb3d35b06f8e889816 Mon Sep 17 00:00:00 2001
From: kingbri <8082010+kingbri1@users.noreply.github.com>
Date: Thu, 1 May 2025 22:51:15 -0400
Subject: [PATCH] API: Ignore add_bos_token in chat completions

When fetching special tokens from the model, no longer treat the
add_bos_token and ban_eos_token parameters as switches. In addition,
change the internal handling of add_bos_token to an optional boolean.
This allows us to fall back to the model when deciding whether or not
to add the BOS token, especially for chat completions.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
---
 backends/base_model_container.py       |  2 +-
 backends/exllamav2/model.py            | 10 ++++------
 endpoints/OAI/types/chat_completion.py |  8 +++++++-
 endpoints/OAI/utils/chat_completion.py | 10 +---------
 4 files changed, 13 insertions(+), 17 deletions(-)

diff --git a/backends/base_model_container.py b/backends/base_model_container.py
index 6336d4d..631bfbc 100644
--- a/backends/base_model_container.py
+++ b/backends/base_model_container.py
@@ -123,7 +123,7 @@ class BaseModelContainer(abc.ABC):
         pass
 
     @abc.abstractmethod
-    def get_special_tokens(self, **kwargs) -> Dict[str, Any]:
+    def get_special_tokens(self) -> Dict[str, Any]:
         """
         Gets special tokens used by the model/tokenizer.
 
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index b9552b2..7554ae3 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -843,12 +843,10 @@ class ExllamaV2Container(BaseModelContainer):
             decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
         )[0]
 
-    def get_special_tokens(
-        self, add_bos_token: bool = True, ban_eos_token: bool = False
-    ):
+    def get_special_tokens(self):
         return {
-            "bos_token": self.tokenizer.bos_token if add_bos_token else "",
-            "eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
+            "bos_token": self.tokenizer.bos_token,
+            "eos_token": self.tokenizer.eos_token,
             "pad_token": self.tokenizer.pad_token,
             "unk_token": self.tokenizer.unk_token,
         }
@@ -1242,7 +1240,7 @@ class ExllamaV2Container(BaseModelContainer):
         ) and gen_settings.token_repetition_range == -1
 
         stop_conditions = params.stop
-        add_bos_token = params.add_bos_token
+        add_bos_token = unwrap(params.add_bos_token, True)
         ban_eos_token = params.ban_eos_token
 
         # Fetch EOS tokens from generation_config if they exist
diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 86a2247..d1209a6 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -1,4 +1,4 @@
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 from pydantic.json_schema import SkipJsonSchema
 from time import time
 from typing import Literal, Union, List, Optional, Dict
@@ -82,6 +82,12 @@ class ChatCompletionRequest(CommonCompletionRequest):
     tool_call_end: SkipJsonSchema[Optional[str]] = None
     tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
 
+    @field_validator("add_bos_token", mode="after")
+    def force_bos_token(cls, v):
+        """Ignore add_bos_token for chat completions and defer to the model."""
+
+        return None
+
 
 class ChatCompletionResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 25de949..0c89f7e 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -238,9 +238,7 @@ async def format_messages_with_template(
             # store the list of dicts rather than the ToolCallProcessor object.
             message.tool_calls = ToolCallProcessor.dump(message.tool_calls)
 
-    special_tokens_dict = model.container.get_special_tokens(
-        add_bos_token, ban_eos_token
-    )
+    special_tokens_dict = model.container.get_special_tokens()
 
     template_vars.update({"messages": messages, **special_tokens_dict})
 
@@ -285,12 +283,6 @@ async def apply_chat_template(
             "add_generation_prompt is False"
         )
 
-    # Removes the starting BOS token if present
-    # This is to prevent add_bos_token from adding multiple bos tokens
-    bos_token = template_vars.get("bos_token")
-    if bos_token and prompt.startswith(bos_token):
-        prompt = prompt.removeprefix(bos_token)
-
     # Add template metadata
     await _append_template_metadata(data, template_vars)
 
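
Behavior sketch (illustrative only, not part of the patch): `unwrap`
below mirrors tabbyAPI's helper of the same name, and
`resolve_add_bos_token` is a hypothetical name for the resolution the
backend now performs in backends/exllamav2/model.py.

    from typing import Optional

    def unwrap(value, default):
        """Return `value`, or `default` when `value` is None."""
        return value if value is not None else default

    def resolve_add_bos_token(request_value: Optional[bool]) -> bool:
        # Chat completions null the field through the validator, so the
        # default (True) wins here and the prompt template keeps control
        # of the BOS token; other endpoints may still pass an explicit
        # boolean to override.
        return unwrap(request_value, True)

    assert resolve_add_bos_token(None) is True    # chat completions
    assert resolve_add_bos_token(False) is False  # explicit override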