API: Add allowed_tokens support
This is the opposite of banned tokens: instead of blocking the listed token IDs, generation is restricted to them. Exllama-specific implementation of #181.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent 10d9419f90
commit 21712578cf
3 changed files with 25 additions and 2 deletions
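As a usage sketch (not part of this commit), a client could exercise the new field by sending either alias in a completion request; the host, port, and endpoint below are assumptions for illustration only, not taken from the diff.

import requests

# Hypothetical request; server URL and endpoint are illustrative assumptions.
payload = {
    "prompt": "The answer is",
    "max_tokens": 1,
    # "allowed_tokens" or its alias "allowed_token_ids" should both be accepted.
    "allowed_tokens": [128, 330],
}
response = requests.post("http://localhost:5000/v1/completions", json=payload)
print(response.json())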
@@ -1099,6 +1099,11 @@ class ExllamaV2Container:
         if banned_tokens:
             gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
 
+        # Set allowed tokens
+        allowed_tokens = unwrap(kwargs.get("allowed_tokens"), [])
+        if allowed_tokens:
+            gen_settings.allow_tokens(self.tokenizer, allowed_tokens)
+
         # Set logit bias
         if logit_bias:
             # Create a vocab tensor if it doesn't exist for token biasing
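An aside on semantics: the hunk above mirrors the existing banned-token path with an allow-list call. A toy sketch of the intended effect (not the exllamav2 API, whose masking details may differ) is:

from typing import List, Optional, Set

def effective_candidates(
    vocab: Set[int],
    banned_tokens: Optional[List[int]] = None,
    allowed_tokens: Optional[List[int]] = None,
) -> Set[int]:
    # Toy illustration only: banned tokens subtract from the candidate set,
    # allowed tokens restrict it to a whitelist.
    candidates = set(vocab)
    if banned_tokens:
        candidates -= set(banned_tokens)
    if allowed_tokens:
        candidates &= set(allowed_tokens)
    return candidates

print(effective_candidates({1, 2, 3, 4}, banned_tokens=[2], allowed_tokens=[2, 3]))
# -> {3}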
@@ -1167,7 +1172,7 @@ class ExllamaV2Container:
         log_prompt(
             f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
             request_id,
-            negative_prompt
+            negative_prompt,
         )
 
         # Create and add a new job
@@ -1313,6 +1318,7 @@ class ExllamaV2Container:
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             banned_tokens=banned_tokens,
+            allowed_tokens=allowed_tokens,
             banned_strings=banned_strings,
             logit_bias=logit_bias,
             filters=grammar_handler.filters,
@@ -50,6 +50,13 @@ class BaseSamplerRequest(BaseModel):
         examples=[[128, 330]],
     )
 
+    allowed_tokens: Optional[Union[List[int], str]] = Field(
+        default_factory=lambda: get_default_sampler_value("allowed_tokens", []),
+        validation_alias=AliasChoices("allowed_tokens", "allowed_token_ids"),
+        description="Aliases: allowed_token_ids",
+        examples=[[128, 330]],
+    )
+
     token_healing: Optional[bool] = Field(
         default_factory=lambda: get_default_sampler_value("token_healing", False)
     )
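For context on the field definition above: with Pydantic v2, listing the field name itself inside AliasChoices keeps both "allowed_tokens" and "allowed_token_ids" accepted on input. A minimal, self-contained sketch (the project's get_default_sampler_value helper is replaced by a plain empty-list default) is:

from typing import List, Optional, Union
from pydantic import AliasChoices, BaseModel, Field

class SamplerSketch(BaseModel):
    # Sketch of the new field, assuming Pydantic v2 alias handling.
    allowed_tokens: Optional[Union[List[int], str]] = Field(
        default_factory=list,
        validation_alias=AliasChoices("allowed_tokens", "allowed_token_ids"),
    )

print(SamplerSketch(allowed_tokens=[128, 330]).allowed_tokens)     # [128, 330]
print(SamplerSketch(allowed_token_ids=[128, 330]).allowed_tokens)  # [128, 330]
print(SamplerSketch(allowed_tokens="128,330").allowed_tokens)      # "128,330" (converted by the validator below)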
@@ -287,12 +294,17 @@ class BaseSamplerRequest(BaseModel):
         if self.banned_strings and isinstance(self.banned_strings, str):
             self.banned_strings = [self.banned_strings]
 
-        # Convert string banned tokens to an integer list
+        # Convert string banned and allowed tokens to an integer list
         if self.banned_tokens and isinstance(self.banned_tokens, str):
             self.banned_tokens = [
                 int(x) for x in self.banned_tokens.split(",") if x.isdigit()
             ]
 
+        if self.allowed_tokens and isinstance(self.allowed_tokens, str):
+            self.allowed_tokens = [
+                int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
+            ]
+
         gen_params = {
             "max_tokens": self.max_tokens,
             "min_tokens": self.min_tokens,
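A standalone repro of the comma-separated fallback added above, using the same split/isdigit filtering:

def parse_token_list(value):
    # Mirrors the validator: only applies to non-empty strings.
    if value and isinstance(value, str):
        return [int(x) for x in value.split(",") if x.isdigit()]
    return value

print(parse_token_list("128,330"))   # [128, 330]
print(parse_token_list([128, 330]))  # unchanged: [128, 330]
print(parse_token_list("128, 330"))  # [128] -- " 330" fails isdigit() because of the space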
@@ -305,6 +317,7 @@ class BaseSamplerRequest(BaseModel):
             "token_healing": self.token_healing,
             "logit_bias": self.logit_bias,
             "banned_tokens": self.banned_tokens,
+            "allowed_tokens": self.allowed_tokens,
             "temperature": self.temperature,
             "temperature_last": self.temperature_last,
             "min_temp": self.min_temp,
@@ -126,6 +126,10 @@ banned_tokens:
   override: []
   force: false
   additive: false
+allowed_tokens:
+  override: []
+  force: false
+  additive: false
 
 # MARK: CFG scale
 cfg_scale:
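The new YAML block parses like any other sampler-override entry; a quick sketch of its shape (the exact meaning of override/force/additive is defined by the project's override loader, not by this snippet):

import yaml  # requires PyYAML

fragment = """
allowed_tokens:
  override: []
  force: false
  additive: false
"""
overrides = yaml.safe_load(fragment)
print(overrides["allowed_tokens"])
# {'override': [], 'force': False, 'additive': False}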