diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 9fe9519..a8bca56 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1099,6 +1099,11 @@ class ExllamaV2Container:
         if banned_tokens:
             gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
 
+        # Set allowed tokens
+        allowed_tokens = unwrap(kwargs.get("allowed_tokens"), [])
+        if allowed_tokens:
+            gen_settings.allow_tokens(self.tokenizer, allowed_tokens)
+
         # Set logit bias
         if logit_bias:
             # Create a vocab tensor if it doesn't exist for token biasing
@@ -1167,7 +1172,7 @@ class ExllamaV2Container:
             log_prompt(
                 f"{self.tokenizer.bos_token if add_bos_token else ''}{prompt}",
                 request_id,
-                negative_prompt
+                negative_prompt,
             )
 
         # Create and add a new job
@@ -1313,6 +1318,7 @@ class ExllamaV2Container:
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             banned_tokens=banned_tokens,
+            allowed_tokens=allowed_tokens,
             banned_strings=banned_strings,
             logit_bias=logit_bias,
             filters=grammar_handler.filters,
diff --git a/common/sampling.py b/common/sampling.py
index e0eb158..56c5b34 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -50,6 +50,13 @@ class BaseSamplerRequest(BaseModel):
         examples=[[128, 330]],
     )
 
+    allowed_tokens: Optional[Union[List[int], str]] = Field(
+        default_factory=lambda: get_default_sampler_value("allowed_tokens", []),
+        validation_alias=AliasChoices("allowed_tokens", "allowed_token_ids"),
+        description="Aliases: allowed_token_ids",
+        examples=[[128, 330]],
+    )
+
     token_healing: Optional[bool] = Field(
         default_factory=lambda: get_default_sampler_value("token_healing", False)
     )
@@ -287,12 +294,17 @@ class BaseSamplerRequest(BaseModel):
         if self.banned_strings and isinstance(self.banned_strings, str):
             self.banned_strings = [self.banned_strings]
 
-        # Convert string banned tokens to an integer list
+        # Convert string banned and allowed tokens to an integer list
         if self.banned_tokens and isinstance(self.banned_tokens, str):
             self.banned_tokens = [
                 int(x) for x in self.banned_tokens.split(",") if x.isdigit()
             ]
 
+        if self.allowed_tokens and isinstance(self.allowed_tokens, str):
+            self.allowed_tokens = [
+                int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
+            ]
+
         gen_params = {
             "max_tokens": self.max_tokens,
             "min_tokens": self.min_tokens,
@@ -305,6 +317,7 @@ class BaseSamplerRequest(BaseModel):
             "token_healing": self.token_healing,
             "logit_bias": self.logit_bias,
             "banned_tokens": self.banned_tokens,
+            "allowed_tokens": self.allowed_tokens,
             "temperature": self.temperature,
             "temperature_last": self.temperature_last,
             "min_temp": self.min_temp,
diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml
index e9367a3..b20b042 100644
--- a/sampler_overrides/sample_preset.yml
+++ b/sampler_overrides/sample_preset.yml
@@ -126,6 +126,10 @@ banned_tokens:
    override: []
    force: false
    additive: false
+allowed_tokens:
+    override: []
+    force: false
+    additive: false
 
# MARK: CFG scale
cfg_scale: