Sampling: Add rudimentary DRY support
Adds DRY support based on the current exl2 dev API. Only change for optimization is dry_max_ngram instead of using a closed range. Currently, DRY range is aliased to dry_max_ngram. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
d34756dc98
commit
05c3f1194f
2 changed files with 68 additions and 1 deletions
|
|
@ -7,6 +7,7 @@ import pathlib
|
|||
import traceback
|
||||
import torch
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from exllamav2 import (
|
||||
ExLlamaV2,
|
||||
ExLlamaV2Config,
|
||||
|
|
@ -944,6 +945,14 @@ class ExllamaV2Container:
|
|||
Meant for dev wheels!
|
||||
"""
|
||||
|
||||
if unwrap(kwargs.get("dry_allowed_length"), 0) > 0 and not hasattr(
|
||||
ExLlamaV2Sampler.Settings, "dry_multiplier"
|
||||
):
|
||||
logger.warning(
|
||||
"DRY sampling is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
return kwargs
|
||||
|
||||
async def generate_gen(
|
||||
|
|
@ -1035,6 +1044,7 @@ class ExllamaV2Container:
|
|||
"Please use an ampere (30 series) or higher GPU for CFG support."
|
||||
)
|
||||
|
||||
# Penalties
|
||||
gen_settings.token_repetition_penalty = unwrap(
|
||||
kwargs.get("repetition_penalty"), 1.0
|
||||
)
|
||||
|
|
@ -1070,6 +1080,23 @@ class ExllamaV2Container:
|
|||
kwargs.get("repetition_decay"), fallback_decay, 0
|
||||
)
|
||||
|
||||
# DRY options
|
||||
dry_allowed_length = unwrap(kwargs.get("dry_allowed_length"), 0)
|
||||
|
||||
# 0 = disabled
|
||||
if dry_allowed_length:
|
||||
gen_settings.dry_allowed_length = dry_allowed_length
|
||||
gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 2.0)
|
||||
gen_settings.dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 2.0)
|
||||
gen_settings.dry_max_ngram = unwrap(kwargs.get("dry_max_ngram"), 20)
|
||||
|
||||
# Tokenize sequence breakers
|
||||
dry_sequence_breakers_json = kwargs.get("dry_sequence_breakers")
|
||||
if dry_sequence_breakers_json:
|
||||
gen_settings.dry_sequence_breakers = {
|
||||
self.encode_tokens(s)[-1] for s in dry_sequence_breakers_json
|
||||
}
|
||||
|
||||
# Initialize grammar handler
|
||||
grammar_handler = ExLlamaV2Grammar()
|
||||
|
||||
|
|
@ -1130,7 +1157,8 @@ class ExllamaV2Container:
|
|||
)
|
||||
|
||||
# Store the gen settings for logging purposes
|
||||
gen_settings_log_dict = vars(gen_settings)
|
||||
# Deepcopy to save a snapshot of vars
|
||||
gen_settings_log_dict = deepcopy(vars(gen_settings))
|
||||
|
||||
# Set banned tokens
|
||||
banned_tokens = unwrap(kwargs.get("banned_tokens"), [])
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
"""Common functions for sampling parameters"""
|
||||
|
||||
import json
|
||||
import pathlib
|
||||
import yaml
|
||||
from copy import deepcopy
|
||||
|
|
@ -140,6 +141,28 @@ class BaseSamplerRequest(BaseModel):
|
|||
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
|
||||
)
|
||||
|
||||
dry_allowed_length: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0)
|
||||
)
|
||||
|
||||
dry_base: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("dry_base", 2.0)
|
||||
)
|
||||
|
||||
dry_multiplier: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("dry_multiplier", 2.0)
|
||||
)
|
||||
|
||||
# TODO: Remove these aliases
|
||||
dry_max_ngram: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("dry_max_ngram", 20),
|
||||
alias=AliasChoices("dry_max_ngram", "dry_penalty_last_n"),
|
||||
)
|
||||
|
||||
dry_sequence_breakers: Optional[str] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
|
||||
)
|
||||
|
||||
mirostat_mode: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
|
||||
)
|
||||
|
|
@ -305,6 +328,17 @@ class BaseSamplerRequest(BaseModel):
|
|||
int(x) for x in self.allowed_tokens.split(",") if x.isdigit()
|
||||
]
|
||||
|
||||
# Convert sequence breakers into an array of strings
|
||||
# NOTE: This sampler sucks to parse.
|
||||
if self.dry_sequence_breakers:
|
||||
if not self.dry_sequence_breakers.startswith("["):
|
||||
self.dry_sequence_breakers = f"[{self.dry_sequence_breakers}]"
|
||||
|
||||
try:
|
||||
self.dry_sequence_breakers = json.loads(self.dry_sequence_breakers)
|
||||
except Exception:
|
||||
self.dry_sequence_breakers = []
|
||||
|
||||
gen_params = {
|
||||
"max_tokens": self.max_tokens,
|
||||
"min_tokens": self.min_tokens,
|
||||
|
|
@ -335,6 +369,11 @@ class BaseSamplerRequest(BaseModel):
|
|||
"presence_penalty": self.presence_penalty,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"penalty_range": self.penalty_range,
|
||||
"dry_allowed_length": self.dry_allowed_length,
|
||||
"dry_base": self.dry_base,
|
||||
"dry_max_ngram": self.dry_max_ngram,
|
||||
"dry_multiplier": self.dry_multiplier,
|
||||
"dry_sequence_breakers": self.dry_sequence_breakers,
|
||||
"repetition_decay": self.repetition_decay,
|
||||
"mirostat": self.mirostat_mode == 2,
|
||||
"mirostat_tau": self.mirostat_tau,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue