Model: Repetition penalty range -> penalty range

All penalties can have a sustain (range) applied to them in exl2,
so clarify the parameter.

However, the default behavior changes based on whether the frequency OR
presence penalty is enabled. For the sanity of OAI users, the frequency and
presence penalties apply only to the output tokens when range is -1 (default).

But, repetition penalty still functions the same way where -1 means
the range is the max seq len.

Doing this prevents gibberish output when using the more modern frequency
and presence penalties, similar to llama.cpp.

NOTE: This logic is still subject to change in the future, but I believe
it hits the happy medium for users who want defaults and users who want
to tinker around with the sampling knobs.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2023-12-28 18:10:19 -05:00
parent c72d30918c
commit 5dc2df68be
2 changed files with 23 additions and 6 deletions

View file

@ -77,9 +77,13 @@ class CommonCompletionRequest(BaseModel):
logit_bias: Optional[Dict[int, float]] = None
# Aliased variables
repetition_range: Optional[int] = Field(
penalty_range: Optional[int] = Field(
default=-1,
validation_alias=AliasChoices("repetition_range", "repetition_penalty_range"),
validation_alias=AliasChoices(
"penalty_range",
"repetition_range",
"repetition_penalty_range",
),
)
def to_gen_params(self):
@ -106,7 +110,7 @@ class CommonCompletionRequest(BaseModel):
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"repetition_penalty": self.repetition_penalty,
"repetition_range": self.repetition_range,
"penalty_range": self.penalty_range,
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,

View file

@ -521,7 +521,7 @@ class ModelContainer:
'presence_penalty' (float): Token presence penalty (default: 0.0)
'repetition_penalty' (float): Token repetition penalty
(default: 1.15)
'repetition_range' (int): Repetition penalty range
'penalty_range' (int): Penalty range
(default: whole context)
'repetition_decay' (int): Repetition penalty range
(default: same as range)
@ -575,15 +575,24 @@ class ModelContainer:
gen_settings.token_repetition_penalty = unwrap(
kwargs.get("repetition_penalty"), 1.0
)
# Applies for all penalties despite being called token_repetition_range
gen_settings.token_repetition_range = unwrap(
kwargs.get("repetition_range"), self.config.max_seq_len
kwargs.get("penalty_range"), self.config.max_seq_len
)
# Dynamically scale penalty range to output tokens
# Only do this if freq/pres pen is enabled and the repetition range is -1
auto_scale_penalty_range = (
gen_settings.token_frequency_penalty != 0
or gen_settings.token_presence_penalty != 0
) and gen_settings.token_repetition_range == -1
# Always make sure the fallback is 0 if range < 0
# It's technically fine to use -1, but this just validates the passed
# fallback
# Always default to 0 if something goes wrong
if gen_settings.token_repetition_range <= 0:
if gen_settings.token_repetition_range < 0:
fallback_decay = 0
else:
fallback_decay = gen_settings.token_repetition_range
@ -609,6 +618,7 @@ class ModelContainer:
max_tokens=max_tokens,
**vars(gen_settings),
token_healing=token_healing,
auto_scale_penalty_range=auto_scale_penalty_range,
add_bos_token=add_bos_token,
ban_eos_token=ban_eos_token,
stop_conditions=stop_conditions,
@ -684,6 +694,9 @@ class ModelContainer:
loras=self.active_loras,
)
if auto_scale_penalty_range:
gen_settings.token_repetition_range = generated_tokens
# Generate
chunk, eos, tokens = self.generator.stream()