diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 2df3bd5..f911e5e 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -791,6 +791,7 @@ class ExllamaV2Container:
         )
 
         stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
+        banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
         add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
         ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
         logit_bias = kwargs.get("logit_bias")
@@ -905,6 +906,9 @@ class ExllamaV2Container:
             kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens
         )
 
+        # Set min_tokens to generate while keeping EOS banned
+        min_tokens = unwrap(kwargs.get("min_tokens"), 0)
+
         # This is an inverse of skip_special_tokens
         decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)
 
@@ -925,26 +929,40 @@ class ExllamaV2Container:
             }
         )
 
-        # Check if decode_special_tokens is supported
-        # TODO: Remove when a new version of ExllamaV2 is released
-        if decode_special_tokens:
+        # MARK: Function signature checks. Not needed in newer ExllamaV2 versions
+
+        # Check if temporary token bans are supported
+        if min_tokens:
+            stream_signature = signature(self.generator.stream_ex)
+
+            try:
+                _bound_vars = stream_signature.bind_partial(ban_tokens=[])
+            except TypeError:
+                logger.warning(
+                    "min_tokens is not supported by the currently "
+                    "installed ExLlamaV2 version."
+                )
+                min_tokens = 0
+
+        # Check if banned_strings is supported
+        if banned_strings:
             begin_stream_signature = signature(self.generator.begin_stream_ex)
 
             try:
-                _bound_vars = begin_stream_signature.bind_partial(
-                    decode_special_tokens=True
-                )
-                begin_stream_args["decode_special_tokens"] = decode_special_tokens
+                _bound_vars = begin_stream_signature.bind_partial(banned_strings=[])
+                begin_stream_args["banned_strings"] = banned_strings
             except TypeError:
                 logger.warning(
-                    "skip_special_tokens is not supported by the currently "
+                    "banned_strings is not supported by the currently "
                     "installed ExLlamaV2 version."
                 )
+                banned_strings = []
 
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
             max_tokens=max_tokens,
+            min_tokens=min_tokens,
             stream=kwargs.get("stream"),
             **gen_settings_log_dict,
             token_healing=token_healing,
@@ -959,6 +977,7 @@ class ExllamaV2Container:
             logprobs=request_logprobs,
             stop_conditions=stop_conditions,
             banned_tokens=banned_tokens,
+            banned_strings=banned_strings,
             logit_bias=logit_bias,
         )
 
@@ -997,7 +1016,10 @@ class ExllamaV2Container:
 
             # Run dict generation
             # Guarantees return of chunk, eos, and chunk_token_ids
-            raw_generation = self.generator.stream_ex()
+            if generated_tokens < min_tokens:
+                raw_generation = self.generator.stream_ex(ban_tokens=eos_tokens)
+            else:
+                raw_generation = self.generator.stream_ex()
 
             if token_healing:
                 # Extract healed token
diff --git a/common/sampling.py b/common/sampling.py
index 6c22bfd..10ef93c 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -18,6 +18,11 @@ class BaseSamplerRequest(BaseModel):
         examples=[150],
     )
 
+    min_tokens: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("min_tokens", 0),
+        examples=[0],
+    )
+
     generate_window: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("generate_window"),
         examples=[512],
@@ -27,6 +32,10 @@ class BaseSamplerRequest(BaseModel):
         default_factory=lambda: get_default_sampler_value("stop", [])
     )
 
+    banned_strings: Optional[Union[str, List[str]]] = Field(
+        default_factory=lambda: get_default_sampler_value("banned_strings", [])
+    )
+
     token_healing: Optional[bool] = Field(
         default_factory=lambda: get_default_sampler_value("token_healing", False)
     )
@@ -252,6 +261,10 @@ class BaseSamplerRequest(BaseModel):
         if self.stop and isinstance(self.stop, str):
             self.stop = [self.stop]
 
+        # Convert banned_strings to an array of strings
+        if self.banned_strings and isinstance(self.banned_strings, str):
+            self.banned_strings = [self.banned_strings]
+
         # Convert string banned tokens to an integer list
         if self.banned_tokens and isinstance(self.banned_tokens, str):
             self.banned_tokens = [
@@ -260,8 +273,10 @@ class BaseSamplerRequest(BaseModel):
 
         gen_params = {
             "max_tokens": self.max_tokens,
+            "min_tokens": self.min_tokens,
             "generate_window": self.generate_window,
             "stop": self.stop,
+            "banned_strings": self.banned_strings,
             "add_bos_token": self.add_bos_token,
             "ban_eos_token": self.ban_eos_token,
             "skip_special_tokens": self.skip_special_tokens,
diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml
index d91b1a1..f3dac71 100644
--- a/sampler_overrides/sample_preset.yml
+++ b/sampler_overrides/sample_preset.yml
@@ -11,10 +11,17 @@ max_tokens:
     override: 150
     force: false
+min_tokens:
+    override: 0
+    force: false
 stop:
     override: []
     force: false
     additive: false
+banned_strings:
+    override: []
+    force: false
+    additive: false
 token_healing:
     override: false
     force: false
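
Both new samplers in `model.py` rely on the same compatibility pattern: probe the installed ExLlamaV2 generator's method signature with `inspect.signature` before passing a keyword argument that older releases don't accept. A minimal standalone sketch of that pattern follows; `OldGenerator` and `supports_kwarg` are hypothetical names for illustration only, not part of ExLlamaV2 or this diff.

```python
# Sketch of the signature-probe pattern used in model.py above.
# `OldGenerator` is a hypothetical stand-in for an older
# ExLlamaV2StreamingGenerator that predates the `ban_tokens` kwarg.
from inspect import signature


class OldGenerator:
    def stream_ex(self):  # no `ban_tokens` parameter on older builds
        return {"chunk": "", "eos": True, "chunk_token_ids": []}


def supports_kwarg(func, name: str) -> bool:
    """Return True if `func` accepts a keyword argument called `name`."""
    try:
        # bind_partial raises TypeError for unknown keyword arguments
        signature(func).bind_partial(**{name: []})
        return True
    except TypeError:
        return False


min_tokens = 10
generator = OldGenerator()

# Mirrors the guard in the diff: downgrade the feature up front instead of
# crashing mid-generation on an older ExLlamaV2 install
if min_tokens and not supports_kwarg(generator.stream_ex, "ban_tokens"):
    print("min_tokens is not supported by the installed ExLlamaV2 version")
    min_tokens = 0
```

Probing once before the generation loop keeps the per-token hot path free of try/except overhead and produces a single warning rather than one per token.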
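For reference, a client-side sketch of how the new fields might be exercised once this lands. The endpoint path, port, and `x-api-key` header are assumptions about a default local tabbyAPI deployment, not verified against this branch.

```python
# Hypothetical request exercising min_tokens and banned_strings; the URL,
# port, and auth header are assumptions about a local tabbyAPI instance.
import requests

payload = {
    "prompt": "Write a short story about a lighthouse keeper.",
    "max_tokens": 200,
    "min_tokens": 50,  # keep EOS banned for the first 50 generated tokens
    "banned_strings": ["As an AI", "I cannot"],  # suppressed during streaming
}

response = requests.post(
    "http://127.0.0.1:5000/v1/completions",
    headers={"x-api-key": "<your-api-key>"},
    json=payload,
    timeout=120,
)
print(response.json())
```

Per the `common/sampling.py` validator, `banned_strings` may also be sent as a bare string, which is coerced to a one-item list, matching the existing behavior of `stop` and `banned_tokens`.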