API: Add min_tokens

Bans the EOS token until the generation reaches a minimum length. This does not prevent the generation from ending early through other stop conditions, such as stop strings.
This commit is contained in:
DocShotgun 2024-05-10 12:30:17 -07:00
parent 643b53e347
commit a1df22668b
3 changed files with 33 additions and 1 deletion
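
For illustration, a request using the new field could pair min_tokens with max_tokens. The payload below is a sketch based on the sampler schema in this diff, not a verbatim API example:

payload = {
    "prompt": "Once upon a time",
    "max_tokens": 150,  # hard cap on generated tokens
    "min_tokens": 32,   # EOS stays banned until 32 tokens have been generated
}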

View file

@@ -905,6 +905,9 @@ class ExllamaV2Container:
             kwargs.get("max_tokens"), self.config.max_seq_len - prompt_tokens
         )
 
+        # Set min_tokens to generate while keeping EOS banned
+        min_tokens = unwrap(kwargs.get("min_tokens"), 0)
+
         # This is an inverse of skip_special_tokens
         decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)
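
unwrap is tabbyAPI's None-coalescing helper; a minimal sketch of the behavior the hunk above relies on (the real implementation lives in the repo's utilities):

def unwrap(value, default=None):
    # Fall back to the default only when the value is absent (None),
    # so an explicit 0 from a request is preserved.
    return value if value is not None else default

unwrap(None, 0)  # -> 0, field omitted from the request
unwrap(16, 0)    # -> 16
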
@@ -941,10 +944,27 @@
                 "installed ExLlamaV2 version."
             )
 
+        # Check if temporary token bans are supported
+        # TODO: Remove when a new version of ExllamaV2 is released
+        if min_tokens:
+            stream_signature = signature(self.generator.stream_ex)
+            try:
+                _bound_vars = stream_signature.bind_partial(
+                    ban_tokens=[]
+                )
+            except TypeError:
+                logger.warning(
+                    "min_tokens is not supported by the currently "
+                    "installed ExLlamaV2 version."
+                )
+
+                min_tokens = 0
+
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
             max_tokens=max_tokens,
+            min_tokens=min_tokens,
             stream=kwargs.get("stream"),
             **gen_settings_log_dict,
             token_healing=token_healing,
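
The compatibility probe leans on inspect.signature: bind_partial raises TypeError when a keyword argument has no matching parameter, which makes it a cheap feature check that never calls the function. A self-contained sketch of the same pattern, with stand-ins for the real stream_ex:

from inspect import signature

def stream_ex_old():  # stand-in for an older ExLlamaV2 build
    pass

def stream_ex_new(ban_tokens=None):  # stand-in for a build with token bans
    pass

def supports_ban_tokens(fn):
    # bind_partial raises TypeError if fn accepts neither a
    # ban_tokens parameter nor **kwargs.
    try:
        signature(fn).bind_partial(ban_tokens=[])
        return True
    except TypeError:
        return False

assert supports_ban_tokens(stream_ex_new)
assert not supports_ban_tokens(stream_ex_old)
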
@@ -997,7 +1017,10 @@ class ExllamaV2Container:
             # Run dict generation
             # Guarantees return of chunk, eos, and chunk_token_ids
-            raw_generation = self.generator.stream_ex()
+            if generated_tokens < min_tokens:
+                raw_generation = self.generator.stream_ex(ban_tokens=eos_tokens)
+            else:
+                raw_generation = self.generator.stream_ex()
 
             if token_healing:
                 # Extract healed token
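
Read as a small state machine, the dispatch re-applies the EOS ban on every step below the threshold and drops it once the counter catches up. A toy version of that decision (names are stand-ins, not the real ExLlamaV2 objects):

def pick_stream_kwargs(generated_tokens, min_tokens, eos_tokens):
    # Bans are temporary: they are only passed while the running
    # token count is still below min_tokens.
    if generated_tokens < min_tokens:
        return {"ban_tokens": eos_tokens}
    return {}

eos_tokens = [2]  # hypothetical EOS token id
pick_stream_kwargs(0, 3, eos_tokens)  # {'ban_tokens': [2]}
pick_stream_kwargs(3, 3, eos_tokens)  # {}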

View file

@@ -18,6 +18,11 @@ class BaseSamplerRequest(BaseModel):
         examples=[150],
     )
 
+    min_tokens: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("min_tokens", 0),
+        examples=[0],
+    )
+
     generate_window: Optional[int] = Field(
         default_factory=lambda: get_default_sampler_value("generate_window"),
         examples=[512],
@@ -260,6 +265,7 @@ class BaseSamplerRequest(BaseModel):
         gen_params = {
             "max_tokens": self.max_tokens,
+            "min_tokens": self.min_tokens,
             "generate_window": self.generate_window,
             "stop": self.stop,
             "add_bos_token": self.add_bos_token,

View file

@@ -11,6 +11,9 @@
 max_tokens:
   override: 150
   force: false
+min_tokens:
+  override: 0
+  force: false
 stop:
   override: []
   force: false
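
In these override entries, override supplies the default when a request omits the field, while force: true pins the value even when the request sets one. A hypothetical resolver illustrating that reading (not tabbyAPI's actual code):

def resolve(request_value, override, force):
    if force:
        return override  # forced: the request value is ignored
    return request_value if request_value is not None else override

resolve(None, 0, False)  # -> 0, default applied
resolve(24, 0, False)    # -> 24, request respected
resolve(24, 0, True)     # -> 0, forced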