From a1df22668b2cf5e81fa393680851e84b6058f51c Mon Sep 17 00:00:00 2001
From: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
Date: Fri, 10 May 2024 12:30:17 -0700
Subject: [PATCH] API: Add min_tokens

Bans the EOS token until the generation reaches a minimum length. This does
not prevent the model from ending the generation early by hitting other stop
conditions.
---
 backends/exllamav2/model.py         | 25 ++++++++++++++++++++++++-
 common/sampling.py                  |  6 ++++++
 sampler_overrides/sample_preset.yml |  3 +++
 3 files changed, 33 insertions(+), 1 deletion(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 2df3bd5..89834f9 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -905,6 +905,9 @@ class ExllamaV2Container:
             kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens
         )
 
+        # Set min_tokens to generate while keeping EOS banned
+        min_tokens = unwrap(kwargs.get("min_tokens"), 0)
+
         # This is an inverse of skip_special_tokens
         decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)
 
@@ -941,10 +944,27 @@ class ExllamaV2Container:
                     "installed ExLlamaV2 version."
                 )
 
+        # Check if temporary token bans are supported
+        # TODO: Remove when a new version of ExllamaV2 is released
+        if min_tokens:
+            stream_signature = signature(self.generator.stream_ex)
+
+            try:
+                _bound_vars = stream_signature.bind_partial(
+                    ban_tokens=[]
+                )
+            except TypeError:
+                logger.warning(
+                    "min_tokens is not supported by the currently "
+                    "installed ExLlamaV2 version."
+                )
+                min_tokens = 0
+
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
             max_tokens=max_tokens,
+            min_tokens=min_tokens,
             stream=kwargs.get("stream"),
             **gen_settings_log_dict,
             token_healing=token_healing,
@@ -997,7 +1017,10 @@ class ExllamaV2Container:
 
             # Run dict generation
             # Guarantees return of chunk, eos, and chunk_token_ids
-            raw_generation = self.generator.stream_ex()
+            if generated_tokens < min_tokens:
+                raw_generation = self.generator.stream_ex(ban_tokens=eos_tokens)
+            else:
+                raw_generation = self.generator.stream_ex()
 
             if token_healing:
                 # Extract healed token
diff --git a/common/sampling.py b/common/sampling.py
index 6c22bfd..9e90a0f 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -18,6 +18,11 @@ class BaseSamplerRequest(BaseModel):
         examples=[150],
     )
 
+    min_tokens: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("min_tokens", 0),
+        examples=[0],
+    )
+
     generate_window: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("generate_window"),
         examples=[512],
@@ -260,6 +265,7 @@ class BaseSamplerRequest(BaseModel):
 
         gen_params = {
             "max_tokens": self.max_tokens,
+            "min_tokens": self.min_tokens,
             "generate_window": self.generate_window,
             "stop": self.stop,
             "add_bos_token": self.add_bos_token,
diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml
index d91b1a1..ee8c32e 100644
--- a/sampler_overrides/sample_preset.yml
+++ b/sampler_overrides/sample_preset.yml
@@ -11,6 +11,9 @@
 max_tokens:
     override: 150
     force: false
+min_tokens:
+    override: 0
+    force: false
 stop:
     override: []
     force: false
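
Usage note (not part of the patch): a minimal sketch of how a client could
exercise the new parameter, assuming an OpenAI-style /v1/completions endpoint
on a local deployment; the server URL, endpoint path, and auth header below
are assumptions for illustration, not values defined by this change.

    import requests

    payload = {
        "prompt": "Write a short story about a lighthouse keeper.",
        "max_tokens": 300,
        # EOS stays banned until at least 50 tokens have been generated;
        # stop strings can still end the generation earlier.
        "min_tokens": 50,
        "stop": ["\n\n"],
    }

    # URL and header are assumed values for a local deployment
    response = requests.post(
        "http://localhost:5000/v1/completions",
        headers={"x-api-key": "<api-key>"},
        json=payload,
    )
    print(response.json())

If min_tokens is omitted or 0, behavior is unchanged; on older ExLlamaV2
versions without temporary token bans, the server logs a warning and falls
back to 0.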