Model: Repetition penalty range -> penalty range
All penalties can have a sustain (range) applied to them in exl2, so clarify the parameter. However, the default behaviors change based on if freq OR pres pen is enabled. For the sanity of OAI users, have freq and pres pen only apply on the output tokens when range is -1 (default). But, repetition penalty still functions the same way where -1 means the range is the max seq len. Doing this prevents gibberish output when using the more modern freq and presence penalties similar to llamacpp. NOTE: This logic is still subject to change in the future, but I believe it hits the happy medium for users who want defaults and users who want to tinker around with the sampling knobs. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
c72d30918c
commit
5dc2df68be
2 changed files with 23 additions and 6 deletions
|
|
@ -77,9 +77,13 @@ class CommonCompletionRequest(BaseModel):
|
|||
logit_bias: Optional[Dict[int, float]] = None
|
||||
|
||||
# Aliased variables
|
||||
repetition_range: Optional[int] = Field(
|
||||
penalty_range: Optional[int] = Field(
|
||||
default=-1,
|
||||
validation_alias=AliasChoices("repetition_range", "repetition_penalty_range"),
|
||||
validation_alias=AliasChoices(
|
||||
"penalty_range",
|
||||
"repetition_range",
|
||||
"repetition_penalty_range",
|
||||
),
|
||||
)
|
||||
|
||||
def to_gen_params(self):
|
||||
|
|
@ -106,7 +110,7 @@ class CommonCompletionRequest(BaseModel):
|
|||
"frequency_penalty": self.frequency_penalty,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"repetition_range": self.repetition_range,
|
||||
"penalty_range": self.penalty_range,
|
||||
"repetition_decay": self.repetition_decay,
|
||||
"mirostat": self.mirostat_mode == 2,
|
||||
"mirostat_tau": self.mirostat_tau,
|
||||
|
|
|
|||
19
model.py
19
model.py
|
|
@ -521,7 +521,7 @@ class ModelContainer:
|
|||
'presence_penalty' (float): Token presence penalty (default: 0.0)
|
||||
'repetition_penalty' (float): Token repetition penalty
|
||||
(default: 1.15)
|
||||
'repetition_range' (int): Repetition penalty range
|
||||
'penalty_range' (int): Penalty range
|
||||
(default: whole context)
|
||||
'repetition_decay' (int): Repetition penalty range
|
||||
(default: same as range)
|
||||
|
|
@ -575,15 +575,24 @@ class ModelContainer:
|
|||
gen_settings.token_repetition_penalty = unwrap(
|
||||
kwargs.get("repetition_penalty"), 1.0
|
||||
)
|
||||
|
||||
# Applies for all penalties despite being called token_repetition_range
|
||||
gen_settings.token_repetition_range = unwrap(
|
||||
kwargs.get("repetition_range"), self.config.max_seq_len
|
||||
kwargs.get("penalty_range"), self.config.max_seq_len
|
||||
)
|
||||
|
||||
# Dynamically scale penalty range to output tokens
|
||||
# Only do this if freq/pres pen is enabled and the repetition range is -1
|
||||
auto_scale_penalty_range = (
|
||||
gen_settings.token_frequency_penalty != 0
|
||||
or gen_settings.token_presence_penalty != 0
|
||||
) and gen_settings.token_repetition_range == -1
|
||||
|
||||
# Always make sure the fallback is 0 if range < 0
|
||||
# It's technically fine to use -1, but this just validates the passed
|
||||
# fallback
|
||||
# Always default to 0 if something goes wrong
|
||||
if gen_settings.token_repetition_range <= 0:
|
||||
if gen_settings.token_repetition_range < 0:
|
||||
fallback_decay = 0
|
||||
else:
|
||||
fallback_decay = gen_settings.token_repetition_range
|
||||
|
|
@ -609,6 +618,7 @@ class ModelContainer:
|
|||
max_tokens=max_tokens,
|
||||
**vars(gen_settings),
|
||||
token_healing=token_healing,
|
||||
auto_scale_penalty_range=auto_scale_penalty_range,
|
||||
add_bos_token=add_bos_token,
|
||||
ban_eos_token=ban_eos_token,
|
||||
stop_conditions=stop_conditions,
|
||||
|
|
@ -684,6 +694,9 @@ class ModelContainer:
|
|||
loras=self.active_loras,
|
||||
)
|
||||
|
||||
if auto_scale_penalty_range:
|
||||
gen_settings.token_repetition_range = generated_tokens
|
||||
|
||||
# Generate
|
||||
chunk, eos, tokens = self.generator.stream()
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue