API: Add banned_tokens
Appends the banned tokens to the generation. This is equivalent to setting logit bias to -100 on a specific set of tokens. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
5750826120
commit
6f9da97114
2 changed files with 17 additions and 0 deletions
|
|
@ -813,6 +813,11 @@ class ExllamaV2Container:
|
||||||
# Store the gen settings for logging purposes
|
# Store the gen settings for logging purposes
|
||||||
gen_settings_log_dict = vars(gen_settings)
|
gen_settings_log_dict = vars(gen_settings)
|
||||||
|
|
||||||
|
# Set banned tokens
|
||||||
|
banned_tokens = unwrap(kwargs.get("banned_tokens"), [])
|
||||||
|
if banned_tokens:
|
||||||
|
gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
|
||||||
|
|
||||||
# Set logit bias
|
# Set logit bias
|
||||||
if logit_bias:
|
if logit_bias:
|
||||||
# Create a vocab tensor if it doesn't exist for token biasing
|
# Create a vocab tensor if it doesn't exist for token biasing
|
||||||
|
|
@ -953,6 +958,7 @@ class ExllamaV2Container:
|
||||||
speculative_ngram=self.generator.speculative_ngram,
|
speculative_ngram=self.generator.speculative_ngram,
|
||||||
logprobs=request_logprobs,
|
logprobs=request_logprobs,
|
||||||
stop_conditions=stop_conditions,
|
stop_conditions=stop_conditions,
|
||||||
|
banned_tokens=banned_tokens,
|
||||||
logit_bias=logit_bias,
|
logit_bias=logit_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -177,6 +177,13 @@ class BaseSamplerRequest(BaseModel):
|
||||||
examples=[1.0],
|
examples=[1.0],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
banned_tokens: Optional[Union[List[int], str]] = Field(
|
||||||
|
default_factory=lambda: get_default_sampler_value("banned_tokens", []),
|
||||||
|
validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
|
||||||
|
description="Aliases: custom_token_bans",
|
||||||
|
examples=[[128, 330]],
|
||||||
|
)
|
||||||
|
|
||||||
# TODO: Return back to adaptable class-based validation But that's just too much
|
# TODO: Return back to adaptable class-based validation But that's just too much
|
||||||
# abstraction compared to simple if statements at the moment
|
# abstraction compared to simple if statements at the moment
|
||||||
def validate_params(self):
|
def validate_params(self):
|
||||||
|
|
@ -245,6 +252,9 @@ class BaseSamplerRequest(BaseModel):
|
||||||
if isinstance(self.stop, str):
|
if isinstance(self.stop, str):
|
||||||
self.stop = [self.stop]
|
self.stop = [self.stop]
|
||||||
|
|
||||||
|
if isinstance(self.banned_tokens, str):
|
||||||
|
self.banned_tokens = list(map(int, self.banned_tokens.split(",")))
|
||||||
|
|
||||||
gen_params = {
|
gen_params = {
|
||||||
"max_tokens": self.max_tokens,
|
"max_tokens": self.max_tokens,
|
||||||
"generate_window": self.generate_window,
|
"generate_window": self.generate_window,
|
||||||
|
|
@ -254,6 +264,7 @@ class BaseSamplerRequest(BaseModel):
|
||||||
"skip_special_tokens": self.skip_special_tokens,
|
"skip_special_tokens": self.skip_special_tokens,
|
||||||
"token_healing": self.token_healing,
|
"token_healing": self.token_healing,
|
||||||
"logit_bias": self.logit_bias,
|
"logit_bias": self.logit_bias,
|
||||||
|
"banned_tokens": self.banned_tokens,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"temperature_last": self.temperature_last,
|
"temperature_last": self.temperature_last,
|
||||||
"min_temp": self.min_temp,
|
"min_temp": self.min_temp,
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue