API: Ignore add_bos_token in chat completions
When fetching special tokens from the model, don't treat the add_bos_token and ban_eos_token parameters as switches. In addition, change the internal handling of add_bos_token to an optional boolean. This allows us to fall back to the model when selecting whether or not to add the BOS token, especially for chat completions. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
3960612d38
commit
aa657fa6e9
4 changed files with 13 additions and 17 deletions
|
|
@ -1,4 +1,4 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic.json_schema import SkipJsonSchema
|
||||
from time import time
|
||||
from typing import Literal, Union, List, Optional, Dict
|
||||
|
|
@ -82,6 +82,12 @@ class ChatCompletionRequest(CommonCompletionRequest):
|
|||
tool_call_end: SkipJsonSchema[Optional[str]] = None
|
||||
tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
|
||||
|
||||
@field_validator("add_bos_token", mode="after")
@classmethod
def force_bos_token(cls, v):
    """Reset add_bos_token to None for every chat completion request.

    Whatever value the client supplied is discarded. None leaves the
    option unset rather than forcing it on or off, so the choice of
    whether to add a BOS token can fall back to the model.

    Args:
        v: The already-validated add_bos_token value; ignored.

    Returns:
        Always None.
    """

    return None
|
||||
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
|
||||
|
|
|
|||
|
|
@ -238,9 +238,7 @@ async def format_messages_with_template(
|
|||
# store the list of dicts rather than the ToolCallProcessor object.
|
||||
message.tool_calls = ToolCallProcessor.dump(message.tool_calls)
|
||||
|
||||
special_tokens_dict = model.container.get_special_tokens(
|
||||
add_bos_token, ban_eos_token
|
||||
)
|
||||
special_tokens_dict = model.container.get_special_tokens()
|
||||
|
||||
template_vars.update({"messages": messages, **special_tokens_dict})
|
||||
|
||||
|
|
@ -285,12 +283,6 @@ async def apply_chat_template(
|
|||
"add_generation_prompt is False"
|
||||
)
|
||||
|
||||
# Removes the starting BOS token if present
|
||||
# This is to prevent add_bos_token from adding multiple bos tokens
|
||||
bos_token = template_vars.get("bos_token")
|
||||
if bos_token and prompt.startswith(bos_token):
|
||||
prompt = prompt.removeprefix(bos_token)
|
||||
|
||||
# Add template metadata
|
||||
await _append_template_metadata(data, template_vars)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue