API: Add banned_tokens

Appends the banned tokens to the generation parameters. This is equivalent to
setting a logit bias of -100 on a specific set of tokens.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-04-28 00:40:34 -04:00
parent 5750826120
commit 6f9da97114
2 changed files with 17 additions and 0 deletions

View file

@ -177,6 +177,13 @@ class BaseSamplerRequest(BaseModel):
examples=[1.0],
)
# Token IDs to ban from generation. Accepts either a list of ints or a single
# comma-separated string of IDs; the string form is split on "," and converted
# to ints before the generation parameters are built.
# The request key "custom_token_bans" is accepted as an alias.
# NOTE(review): the default presumably comes from the configured sampler
# defaults, falling back to an empty list; confirm against
# get_default_sampler_value.
banned_tokens: Optional[Union[List[int], str]] = Field(
default_factory=lambda: get_default_sampler_value("banned_tokens", []),
validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
description="Aliases: custom_token_bans",
examples=[[128, 330]],
)
# TODO: Return to adaptable class-based validation. But that's just too much
# abstraction compared to simple if statements at the moment.
def validate_params(self):
@ -245,6 +252,9 @@ class BaseSamplerRequest(BaseModel):
if isinstance(self.stop, str):
self.stop = [self.stop]
if isinstance(self.banned_tokens, str):
self.banned_tokens = list(map(int, self.banned_tokens.split(",")))
gen_params = {
"max_tokens": self.max_tokens,
"generate_window": self.generate_window,
@ -254,6 +264,7 @@ class BaseSamplerRequest(BaseModel):
"skip_special_tokens": self.skip_special_tokens,
"token_healing": self.token_healing,
"logit_bias": self.logit_bias,
"banned_tokens": self.banned_tokens,
"temperature": self.temperature,
"temperature_last": self.temperature_last,
"min_temp": self.min_temp,