From c0b631ba929f6241480ac5a8029be5768e154f2e Mon Sep 17 00:00:00 2001 From: DocShotgun <126566557+DocShotgun@users.noreply.github.com> Date: Fri, 10 May 2024 13:53:55 -0700 Subject: [PATCH] API: Add banned_strings From exllamav2: List of strings that the generator will refuse to output. As soon as a partial match happens, a checkpoint is saved that the generator can rewind to if need be. Subsequent tokens are then held until the full string is resolved (match or no match) and either emitted or discarded, accordingly. --- backends/exllamav2/model.py | 18 ++++++++++++++++++ common/sampling.py | 9 +++++++++ sampler_overrides/sample_preset.yml | 4 ++++ 3 files changed, 31 insertions(+) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 89834f9..aa3c1af 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -791,6 +791,7 @@ class ExllamaV2Container: ) stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), []) + banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), []) add_bos_token = unwrap(kwargs.get("add_bos_token"), True) ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False) logit_bias = kwargs.get("logit_bias") @@ -960,6 +961,22 @@ class ExllamaV2Container: ) min_tokens = 0 + # Check if banned_strings is supported + # TODO: Remove when a new version of ExllamaV2 is released + if banned_strings: + begin_stream_signature = signature(self.generator.begin_stream_ex) + + try: + _bound_vars = begin_stream_signature.bind_partial( + banned_strings=[] + ) + begin_stream_args["banned_strings"] = banned_strings + except TypeError: + logger.warning( + "banned_strings is not supported by the currently " + "installed ExLlamaV2 version." + ) + # Log generation options to console # Some options are too large, so log the args instead log_generation_params( @@ -979,6 +996,7 @@ class ExllamaV2Container: logprobs=request_logprobs, stop_conditions=stop_conditions, banned_tokens=banned_tokens, + banned_strings=banned_strings, logit_bias=logit_bias, ) diff --git a/common/sampling.py b/common/sampling.py index 9e90a0f..10ef93c 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -32,6 +32,10 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("stop", []) ) + banned_strings: Optional[Union[str, List[str]]] = Field( + default_factory=lambda: get_default_sampler_value("banned_strings", []) + ) + token_healing: Optional[bool] = Field( default_factory=lambda: get_default_sampler_value("token_healing", False) ) @@ -257,6 +261,10 @@ class BaseSamplerRequest(BaseModel): if self.stop and isinstance(self.stop, str): self.stop = [self.stop] + # Convert banned_strings to an array of strings + if self.banned_strings and isinstance(self.banned_strings, str): + self.banned_strings = [self.banned_strings] + # Convert string banned tokens to an integer list if self.banned_tokens and isinstance(self.banned_tokens, str): self.banned_tokens = [ @@ -268,6 +276,7 @@ class BaseSamplerRequest(BaseModel): "min_tokens": self.min_tokens, "generate_window": self.generate_window, "stop": self.stop, + "banned_strings": self.banned_strings, "add_bos_token": self.add_bos_token, "ban_eos_token": self.ban_eos_token, "skip_special_tokens": self.skip_special_tokens, diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml index ee8c32e..f3dac71 100644 --- a/sampler_overrides/sample_preset.yml +++ b/sampler_overrides/sample_preset.yml @@ -18,6 +18,10 @@ stop: override: [] force: false additive: false +banned_strings: + override: [] + force: false + additive: false token_healing: override: false force: false