Model: Remove extra unwraps

The base sampler request already specifies the defaults, so there's no
need to unwrap them again here.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-04-15 23:38:46 -04:00
parent 11ed3cf5ee
commit dcb36e9ab2
2 changed files with 41 additions and 52 deletions
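
For context, unwrap is tabbyAPI's null-coalescing helper; a minimal sketch of
its assumed behavior (not necessarily the exact code in the repo):

    def unwrap(wrapped, default=None):
        # Return the wrapped value unless it is None, else fall back
        return wrapped if wrapped is not None else default

Because BaseSamplerRequest now resolves every sampler default before the kwargs
reach the model container, each unwrap(kwargs.get(...), fallback) call below
only re-applies a default that is already guaranteed to be present, so it can
collapse to a plain kwargs.get(...).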

View file

@@ -1059,9 +1059,9 @@ class ExllamaV2Container:
         prompts = [prompt]

-        token_healing = unwrap(kwargs.get("token_healing"), False)
+        token_healing = kwargs.get("token_healing")
         generate_window = max(
-            unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8
+            kwargs.get("generate_window"), self.config.max_seq_len // 8
         )

         # Sampler settings
@@ -1071,35 +1071,34 @@ class ExllamaV2Container:
         kwargs = self.check_unsupported_settings(**kwargs)

         # Apply settings
-        partial(gen_settings.temperature, 1.0)
-        gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
-        gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
-        gen_settings.smoothing_factor = unwrap(kwargs.get("smoothing_factor"), 0.0)
-        gen_settings.top_k = unwrap(kwargs.get("top_k"), 0)
-        gen_settings.top_p = unwrap(kwargs.get("top_p"), 1.0)
-        gen_settings.top_a = unwrap(kwargs.get("top_a"), 0.0)
-        gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0)
-        gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0)
-        gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
-        gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
-        gen_settings.skew = unwrap(kwargs.get("skew"), 0)
+        gen_settings.temperature = kwargs.get("temperature")
+        gen_settings.temperature_last = kwargs.get("temperature_last")
+        gen_settings.smoothing_factor = kwargs.get("smoothing_factor")
+        gen_settings.top_k = kwargs.get("top_k")
+        gen_settings.top_p = kwargs.get("top_p")
+        gen_settings.top_a = kwargs.get("top_a")
+        gen_settings.min_p = kwargs.get("min_p")
+        gen_settings.tfs = kwargs.get("tfs")
+        gen_settings.typical = kwargs.get("typical")
+        gen_settings.mirostat = kwargs.get("mirostat")
+        gen_settings.skew = kwargs.get("skew")

         # XTC
-        xtc_probability = unwrap(kwargs.get("xtc_probability"), 0.0)
+        xtc_probability = kwargs.get("xtc_probability")
         if xtc_probability > 0.0:
             gen_settings.xtc_probability = xtc_probability

             # 0.1 is the default for this value
-            gen_settings.xtc_threshold = unwrap(kwargs.get("xtc_threshold", 0.1))
+            gen_settings.xtc_threshold = kwargs.get("xtc_threshold")

         # DynaTemp settings
-        max_temp = unwrap(kwargs.get("max_temp"), 1.0)
-        min_temp = unwrap(kwargs.get("min_temp"), 1.0)
+        max_temp = kwargs.get("max_temp")
+        min_temp = kwargs.get("min_temp")

         if max_temp > min_temp:
             gen_settings.max_temp = max_temp
             gen_settings.min_temp = min_temp
-            gen_settings.temp_exponent = unwrap(kwargs.get("temp_exponent"), 1.0)
+            gen_settings.temp_exponent = kwargs.get("temp_exponent")
         else:
             # Force to default values
             gen_settings.max_temp = 1.0
@@ -1116,11 +1115,11 @@ class ExllamaV2Container:
             )

         # Default tau and eta fallbacks don't matter if mirostat is off
-        gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
-        gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
+        gen_settings.mirostat_tau = kwargs.get("mirostat_tau")
+        gen_settings.mirostat_eta = kwargs.get("mirostat_eta")

         # Set CFG scale and negative prompt
-        cfg_scale = unwrap(kwargs.get("cfg_scale"), 1.0)
+        cfg_scale = kwargs.get("cfg_scale")
         negative_prompt = None
         if cfg_scale not in [None, 1.0]:
             if self.paged:
@@ -1139,15 +1138,9 @@ class ExllamaV2Container:
             )

         # Penalties
-        gen_settings.token_repetition_penalty = unwrap(
-            kwargs.get("repetition_penalty"), 1.0
-        )
-        gen_settings.token_frequency_penalty = unwrap(
-            kwargs.get("frequency_penalty"), 0.0
-        )
-        gen_settings.token_presence_penalty = unwrap(
-            kwargs.get("presence_penalty"), 0.0
-        )
+        gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty")
+        gen_settings.token_frequency_penalty = kwargs.get("frequency_penalty")
+        gen_settings.token_presence_penalty = kwargs.get("presence_penalty")

         # Applies for all penalties despite being called token_repetition_range
         gen_settings.token_repetition_range = unwrap(
@@ -1175,16 +1168,14 @@ class ExllamaV2Container:
         )

         # DRY options
-        dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 0.0)
+        dry_multiplier = kwargs.get("dry_multiplier")

         # < 0 = disabled
         if dry_multiplier > 0:
             gen_settings.dry_multiplier = dry_multiplier
-            gen_settings.dry_allowed_length = unwrap(
-                kwargs.get("dry_allowed_length"), 0
-            )
-            gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 0.0)
+            gen_settings.dry_allowed_length = kwargs.get("dry_allowed_length")
+            gen_settings.dry_base = kwargs.get("dry_base")

             # Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
             # Use max_seq_len as the fallback to stay consistent
@ -1203,24 +1194,24 @@ class ExllamaV2Container:
grammar_handler = ExLlamaV2Grammar()
# Add JSON schema filter if it exists
json_schema = unwrap(kwargs.get("json_schema"))
json_schema = kwargs.get("json_schema")
if json_schema:
grammar_handler.add_json_schema_filter(
json_schema, self.model, self.tokenizer
)
# Add regex filter if it exists
regex_pattern = unwrap(kwargs.get("regex_pattern"))
regex_pattern = kwargs.get("regex_pattern")
if regex_pattern:
grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
# Add EBNF filter if it exists
grammar_string = unwrap(kwargs.get("grammar_string"))
grammar_string = kwargs.get("grammar_string")
if grammar_string:
grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)
# Set banned strings
banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
banned_strings = kwargs.get("banned_strings")
if banned_strings and len(grammar_handler.filters) > 0:
logger.warning(
"Disabling banned_strings because "
@@ -1229,18 +1220,16 @@ class ExllamaV2Container:
             banned_strings = []

-        stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
-        add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
-        ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
+        stop_conditions = kwargs.get("stop")
+        add_bos_token = kwargs.get("add_bos_token")
+        ban_eos_token = kwargs.get("ban_eos_token")
         logit_bias = kwargs.get("logit_bias")

         # Logprobs
-        request_logprobs = unwrap(kwargs.get("logprobs"), 0)
+        request_logprobs = kwargs.get("logprobs")

         # Speculative Ngram
-        self.generator.speculative_ngram = unwrap(
-            kwargs.get("speculative_ngram"), False
-        )
+        self.generator.speculative_ngram = kwargs.get("speculative_ngram")

         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
@@ -1255,12 +1244,12 @@ class ExllamaV2Container:
             )

         # Set banned tokens
-        banned_tokens = unwrap(kwargs.get("banned_tokens"), [])
+        banned_tokens = kwargs.get("banned_tokens")
         if banned_tokens:
             gen_settings.disallow_tokens(self.tokenizer, banned_tokens)

         # Set allowed tokens
-        allowed_tokens = unwrap(kwargs.get("allowed_tokens"), [])
+        allowed_tokens = kwargs.get("allowed_tokens")
         if allowed_tokens:
             gen_settings.allow_tokens(self.tokenizer, allowed_tokens)
@@ -1361,10 +1350,10 @@ class ExllamaV2Container:
         )

         # Set min_tokens to generate while keeping EOS banned
-        min_tokens = unwrap(kwargs.get("min_tokens"), 0)
+        min_tokens = kwargs.get("min_tokens")

         # This is an inverse of skip_special_tokens
-        decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)
+        decode_special_tokens = not kwargs.get("skip_special_tokens")

         # Log prompt to console. Add the BOS token if specified
         log_prompt(

View file

@@ -42,7 +42,7 @@ class BaseSamplerRequest(BaseModel):
     )

     generate_window: Optional[int] = Field(
-        default_factory=lambda: get_default_sampler_value("generate_window"),
+        default_factory=lambda: get_default_sampler_value("generate_window", 512),
         examples=[512],
         ge=0,
     )
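
The second file is what makes the removals safe: generate_window (and,
presumably, the other sampler fields following the same pattern) now bakes its
fallback into the Pydantic default_factory, so the value is no longer None by
the time the container reads it. A hypothetical sketch of
get_default_sampler_value under that assumption; the real helper lives in
tabbyAPI's sampling code and may differ:

    # sampler_overrides is a stand-in for the server's configured override table
    sampler_overrides: dict = {}

    def get_default_sampler_value(key, fallback=None):
        # Prefer a server-level override for the key, else use the fallback
        return sampler_overrides.get(key, fallback)

With 512 passed as the fallback, the Field's default_factory yields an int
whenever the client omits the field, which is why the
max(kwargs.get("generate_window"), ...) call in the first file no longer needs
unwrap.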