Sampling: Add rudimentary DRY support

Adds DRY support based on the current exl2 dev API. The only change, made
for optimization, is that dry_max_ngram is used instead of a closed range.

Currently, the DRY range parameter (dry_penalty_last_n) is accepted as an alias for dry_max_ngram.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-09-07 00:48:42 -04:00
parent d34756dc98
commit 05c3f1194f
2 changed files with 68 additions and 1 deletions

View file

@ -1,5 +1,6 @@
"""Common functions for sampling parameters"""
import json
import pathlib
import yaml
from copy import deepcopy
@ -140,6 +141,28 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
)
# DRY sampler settings. Every default is resolved lazily through
# get_default_sampler_value(name, fallback), so the literal below is only
# the fallback passed to that helper.
dry_allowed_length: Optional[int] = Field(
    default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0)
)

dry_base: Optional[float] = Field(
    default_factory=lambda: get_default_sampler_value("dry_base", 2.0)
)

# NOTE(review): a non-zero fallback multiplier presumably turns DRY on for
# requests that never ask for it -- confirm this is intended.
dry_multiplier: Optional[float] = Field(
    default_factory=lambda: get_default_sampler_value("dry_multiplier", 2.0)
)

# TODO: Remove these aliases
# AliasChoices must be passed through validation_alias; pydantic documents
# the plain `alias` parameter as accepting a string only.
dry_max_ngram: Optional[int] = Field(
    default_factory=lambda: get_default_sampler_value("dry_max_ngram", 20),
    validation_alias=AliasChoices("dry_max_ngram", "dry_penalty_last_n"),
)

# Received as a string and parsed into a list when the generation params
# are built.
# NOTE(review): the annotation is Optional[str] but the fallback default is
# a list; request bodies that send a JSON array will presumably be rejected
# by validation -- consider widening the annotation.
dry_sequence_breakers: Optional[str] = Field(
    default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
)
# Mirostat mode selector; the default is looked up through
# get_default_sampler_value with a fallback of 0.
mirostat_mode: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
)
@ -305,6 +328,17 @@ class BaseSamplerRequest(BaseModel):
int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
]
# Convert sequence breakers into an array of strings.
# Accepts either a JSON array string or a bare comma-separated list of JSON
# values, which is wrapped in brackets before parsing.
# Only string input is parsed: a value that is already a list (for example
# a list default, or the result of an earlier conversion) is left untouched
# instead of failing on `.startswith`.
breakers = self.dry_sequence_breakers
if breakers and isinstance(breakers, str):
    if not breakers.startswith("["):
        breakers = f"[{breakers}]"

    try:
        self.dry_sequence_breakers = json.loads(breakers)
    except (ValueError, RecursionError):
        # Best-effort parsing: malformed JSON (JSONDecodeError is a
        # ValueError) or absurdly nested input disables sequence breakers
        # rather than failing the request.
        self.dry_sequence_breakers = []
gen_params = {
"max_tokens": self.max_tokens,
"min_tokens": self.min_tokens,
@ -335,6 +369,11 @@ class BaseSamplerRequest(BaseModel):
"presence_penalty": self.presence_penalty,
"repetition_penalty": self.repetition_penalty,
"penalty_range": self.penalty_range,
"dry_allowed_length": self.dry_allowed_length,
"dry_base": self.dry_base,
"dry_max_ngram": self.dry_max_ngram,
"dry_multiplier": self.dry_multiplier,
"dry_sequence_breakers": self.dry_sequence_breakers,
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,