API: Add banned_tokens
Appends the banned tokens to the generation. This is equivalent to setting logit bias to -100 on a specific set of tokens. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
5750826120
commit
6f9da97114
2 changed files with 17 additions and 0 deletions
|
|
@ -177,6 +177,13 @@ class BaseSamplerRequest(BaseModel):
|
|||
examples=[1.0],
|
||||
)
|
||||
|
||||
banned_tokens: Optional[Union[List[int], str]] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("banned_tokens", []),
|
||||
validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
|
||||
description="Aliases: custom_token_bans",
|
||||
examples=[[128, 330]],
|
||||
)
|
||||
|
||||
# TODO: Return to adaptable class-based validation, but that's just too much
|
||||
# abstraction compared to simple if statements at the moment
|
||||
def validate_params(self):
|
||||
|
|
@ -245,6 +252,9 @@ class BaseSamplerRequest(BaseModel):
|
|||
if isinstance(self.stop, str):
|
||||
self.stop = [self.stop]
|
||||
|
||||
if isinstance(self.banned_tokens, str):
|
||||
self.banned_tokens = list(map(int, self.banned_tokens.split(",")))
|
||||
|
||||
gen_params = {
|
||||
"max_tokens": self.max_tokens,
|
||||
"generate_window": self.generate_window,
|
||||
|
|
@ -254,6 +264,7 @@ class BaseSamplerRequest(BaseModel):
|
|||
"skip_special_tokens": self.skip_special_tokens,
|
||||
"token_healing": self.token_healing,
|
||||
"logit_bias": self.logit_bias,
|
||||
"banned_tokens": self.banned_tokens,
|
||||
"temperature": self.temperature,
|
||||
"temperature_last": self.temperature_last,
|
||||
"min_temp": self.min_temp,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue