API: Add allowed_tokens support
This is the opposite of banned tokens. Exllama-specific implementation of #181. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
10d9419f90
commit
21712578cf
3 changed files with 25 additions and 2 deletions
|
|
@ -1099,6 +1099,11 @@ class ExllamaV2Container:
|
||||||
if banned_tokens:
|
if banned_tokens:
|
||||||
gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
|
gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
|
||||||
|
|
||||||
|
# Set allowed tokens
|
||||||
|
allowed_tokens = unwrap(kwargs.get("allowed_tokens"), [])
|
||||||
|
if allowed_tokens:
|
||||||
|
gen_settings.allow_tokens(self.tokenizer, allowed_tokens)
|
||||||
|
|
||||||
# Set logit bias
|
# Set logit bias
|
||||||
if logit_bias:
|
if logit_bias:
|
||||||
# Create a vocab tensor if it doesn't exist for token biasing
|
# Create a vocab tensor if it doesn't exist for token biasing
|
||||||
|
|
@ -1167,7 +1172,7 @@ class ExllamaV2Container:
|
||||||
log_prompt(
|
log_prompt(
|
||||||
f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
|
f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
|
||||||
request_id,
|
request_id,
|
||||||
negative_prompt
|
negative_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create and add a new job
|
# Create and add a new job
|
||||||
|
|
@ -1313,6 +1318,7 @@ class ExllamaV2Container:
|
||||||
logprobs=request_logprobs,
|
logprobs=request_logprobs,
|
||||||
stop_conditions=stop_conditions,
|
stop_conditions=stop_conditions,
|
||||||
banned_tokens=banned_tokens,
|
banned_tokens=banned_tokens,
|
||||||
|
allowed_tokens=allowed_tokens,
|
||||||
banned_strings=banned_strings,
|
banned_strings=banned_strings,
|
||||||
logit_bias=logit_bias,
|
logit_bias=logit_bias,
|
||||||
filters=grammar_handler.filters,
|
filters=grammar_handler.filters,
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,13 @@ class BaseSamplerRequest(BaseModel):
|
||||||
examples=[[128, 330]],
|
examples=[[128, 330]],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
allowed_tokens: Optional[Union[List[int], str]] = Field(
|
||||||
|
default_factory=lambda: get_default_sampler_value("allowed_tokens", []),
|
||||||
|
validation_alias=AliasChoices("allowed_tokens", "allowed_token_ids"),
|
||||||
|
description="Aliases: allowed_token_ids",
|
||||||
|
examples=[[128, 330]],
|
||||||
|
)
|
||||||
|
|
||||||
token_healing: Optional[bool] = Field(
|
token_healing: Optional[bool] = Field(
|
||||||
default_factory=lambda: get_default_sampler_value("token_healing", False)
|
default_factory=lambda: get_default_sampler_value("token_healing", False)
|
||||||
)
|
)
|
||||||
|
|
@ -287,12 +294,17 @@ class BaseSamplerRequest(BaseModel):
|
||||||
if self.banned_strings and isinstance(self.banned_strings, str):
|
if self.banned_strings and isinstance(self.banned_strings, str):
|
||||||
self.banned_strings = [self.banned_strings]
|
self.banned_strings = [self.banned_strings]
|
||||||
|
|
||||||
# Convert string banned tokens to an integer list
|
# Convert string banned and allowed tokens to an integer list
|
||||||
if self.banned_tokens and isinstance(self.banned_tokens, str):
|
if self.banned_tokens and isinstance(self.banned_tokens, str):
|
||||||
self.banned_tokens = [
|
self.banned_tokens = [
|
||||||
int(x) for x in self.banned_tokens.split(",") if x.isdigit()
|
int(x) for x in self.banned_tokens.split(",") if x.isdigit()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if self.allowed_tokens and isinstance(self.allowed_tokens, str):
|
||||||
|
self.allowed_tokens = [
|
||||||
|
int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
|
||||||
|
]
|
||||||
|
|
||||||
gen_params = {
|
gen_params = {
|
||||||
"max_tokens": self.max_tokens,
|
"max_tokens": self.max_tokens,
|
||||||
"min_tokens": self.min_tokens,
|
"min_tokens": self.min_tokens,
|
||||||
|
|
@ -305,6 +317,7 @@ class BaseSamplerRequest(BaseModel):
|
||||||
"token_healing": self.token_healing,
|
"token_healing": self.token_healing,
|
||||||
"logit_bias": self.logit_bias,
|
"logit_bias": self.logit_bias,
|
||||||
"banned_tokens": self.banned_tokens,
|
"banned_tokens": self.banned_tokens,
|
||||||
|
"allowed_tokens": self.allowed_tokens,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"temperature_last": self.temperature_last,
|
"temperature_last": self.temperature_last,
|
||||||
"min_temp": self.min_temp,
|
"min_temp": self.min_temp,
|
||||||
|
|
|
||||||
|
|
@ -126,6 +126,10 @@ banned_tokens:
|
||||||
override: []
|
override: []
|
||||||
force: false
|
force: false
|
||||||
additive: false
|
additive: false
|
||||||
|
allowed_tokens:
|
||||||
|
override: []
|
||||||
|
force: false
|
||||||
|
additive: false
|
||||||
|
|
||||||
# MARK: CFG scale
|
# MARK: CFG scale
|
||||||
cfg_scale:
|
cfg_scale:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue