diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 89834f9..aa3c1af 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -791,6 +791,7 @@ class ExllamaV2Container: ) stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), []) + banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), []) add_bos_token = unwrap(kwargs.get("add_bos_token"), True) ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False) logit_bias = kwargs.get("logit_bias") @@ -960,6 +961,22 @@ class ExllamaV2Container: ) min_tokens = 0 + # Check if banned_strings is supported + # TODO: Remove when a new version of ExllamaV2 is released + if banned_strings: + begin_stream_signature = signature(self.generator.begin_stream_ex) + + try: + _bound_vars = begin_stream_signature.bind_partial( + banned_strings=[] + ) + begin_stream_args["banned_strings"] = banned_strings + except TypeError: + logger.warning( + "banned_strings is not supported by the currently " + "installed ExLlamaV2 version." + ) + # Log generation options to console # Some options are too large, so log the args instead log_generation_params( @@ -979,6 +996,7 @@ class ExllamaV2Container: logprobs=request_logprobs, stop_conditions=stop_conditions, banned_tokens=banned_tokens, + banned_strings=banned_strings, logit_bias=logit_bias, ) diff --git a/common/sampling.py b/common/sampling.py index 9e90a0f..10ef93c 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -32,6 +32,10 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("stop", []) ) + banned_strings: Optional[Union[str, List[str]]] = Field( + default_factory=lambda: get_default_sampler_value("banned_strings", []) + ) + token_healing: Optional[bool] = Field( default_factory=lambda: get_default_sampler_value("token_healing", False) ) @@ -257,6 +261,10 @@ class BaseSamplerRequest(BaseModel): if self.stop and isinstance(self.stop, str): self.stop = [self.stop] + # Convert banned_strings to an array of strings + if self.banned_strings and isinstance(self.banned_strings, str): + self.banned_strings = [self.banned_strings] + # Convert string banned tokens to an integer list if self.banned_tokens and isinstance(self.banned_tokens, str): self.banned_tokens = [ @@ -268,6 +276,7 @@ class BaseSamplerRequest(BaseModel): "min_tokens": self.min_tokens, "generate_window": self.generate_window, "stop": self.stop, + "banned_strings": self.banned_strings, "add_bos_token": self.add_bos_token, "ban_eos_token": self.ban_eos_token, "skip_special_tokens": self.skip_special_tokens, diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml index ee8c32e..f3dac71 100644 --- a/sampler_overrides/sample_preset.yml +++ b/sampler_overrides/sample_preset.yml @@ -18,6 +18,10 @@ stop: override: [] force: false additive: false +banned_strings: + override: [] + force: false + additive: false token_healing: override: false force: false