Model: Remove extra unwraps
The base sampler request already specifies the defaults, so don't unwrap in this way.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
parent 11ed3cf5ee
commit dcb36e9ab2
2 changed files with 41 additions and 52 deletions
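For context, here is a minimal sketch of the pattern this commit removes. It assumes unwrap is a plain null-coalescing helper and that defaults are now filled in upstream by the Pydantic request model (see the second changed file below); the SamplerStub model and the exact unwrap body are hypothetical stand-ins for illustration, not the project's actual code.

from typing import Optional
from pydantic import BaseModel, Field


def unwrap(wrapped, default=None):
    # Assumed shape of the helper: return the value unless it is None.
    return wrapped if wrapped is not None else default


class SamplerStub(BaseModel):
    # Hypothetical stand-in for BaseSamplerRequest: the default lives
    # on the model field itself.
    top_k: Optional[int] = Field(default=0, ge=0)


kwargs = SamplerStub().model_dump()  # client sent no top_k; Pydantic fills 0
# Before: the fallback was duplicated at every call site.
assert unwrap(kwargs.get("top_k"), 0) == 0
# After: the model already guarantees a value, so .get() alone suffices.
assert kwargs.get("top_k") == 0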
@@ -1059,9 +1059,9 @@ class ExllamaV2Container:
         prompts = [prompt]
 
-        token_healing = unwrap(kwargs.get("token_healing"), False)
+        token_healing = kwargs.get("token_healing")
         generate_window = max(
-            unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8
+            kwargs.get("generate_window"), self.config.max_seq_len // 8
         )
 
         # Sampler settings
 
@@ -1071,35 +1071,34 @@ class ExllamaV2Container:
         kwargs = self.check_unsupported_settings(**kwargs)
 
         # Apply settings
-        partial(gen_settings.temperature, 1.0)
-        gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
-        gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
-        gen_settings.smoothing_factor = unwrap(kwargs.get("smoothing_factor"), 0.0)
-        gen_settings.top_k = unwrap(kwargs.get("top_k"), 0)
-        gen_settings.top_p = unwrap(kwargs.get("top_p"), 1.0)
-        gen_settings.top_a = unwrap(kwargs.get("top_a"), 0.0)
-        gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0)
-        gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0)
-        gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
-        gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
-        gen_settings.skew = unwrap(kwargs.get("skew"), 0)
+        gen_settings.temperature = kwargs.get("temperature")
+        gen_settings.temperature_last = kwargs.get("temperature_last")
+        gen_settings.smoothing_factor = kwargs.get("smoothing_factor")
+        gen_settings.top_k = kwargs.get("top_k")
+        gen_settings.top_p = kwargs.get("top_p")
+        gen_settings.top_a = kwargs.get("top_a")
+        gen_settings.min_p = kwargs.get("min_p")
+        gen_settings.tfs = kwargs.get("tfs")
+        gen_settings.typical = kwargs.get("typical")
+        gen_settings.mirostat = kwargs.get("mirostat")
+        gen_settings.skew = kwargs.get("skew")
 
         # XTC
-        xtc_probability = unwrap(kwargs.get("xtc_probability"), 0.0)
+        xtc_probability = kwargs.get("xtc_probability")
         if xtc_probability > 0.0:
             gen_settings.xtc_probability = xtc_probability
 
             # 0.1 is the default for this value
-            gen_settings.xtc_threshold = unwrap(kwargs.get("xtc_threshold", 0.1))
+            gen_settings.xtc_threshold = kwargs.get("xtc_threshold")
 
         # DynaTemp settings
-        max_temp = unwrap(kwargs.get("max_temp"), 1.0)
-        min_temp = unwrap(kwargs.get("min_temp"), 1.0)
+        max_temp = kwargs.get("max_temp")
+        min_temp = kwargs.get("min_temp")
 
         if max_temp > min_temp:
             gen_settings.max_temp = max_temp
             gen_settings.min_temp = min_temp
-            gen_settings.temp_exponent = unwrap(kwargs.get("temp_exponent"), 1.0)
+            gen_settings.temp_exponent = kwargs.get("temp_exponent")
         else:
             # Force to default values
             gen_settings.max_temp = 1.0
@@ -1116,11 +1115,11 @@ class ExllamaV2Container:
         )
 
         # Default tau and eta fallbacks don't matter if mirostat is off
-        gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
-        gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
+        gen_settings.mirostat_tau = kwargs.get("mirostat_tau")
+        gen_settings.mirostat_eta = kwargs.get("mirostat_eta")
 
         # Set CFG scale and negative prompt
-        cfg_scale = unwrap(kwargs.get("cfg_scale"), 1.0)
+        cfg_scale = kwargs.get("cfg_scale")
         negative_prompt = None
         if cfg_scale not in [None, 1.0]:
             if self.paged:
@@ -1139,15 +1138,9 @@ class ExllamaV2Container:
         )
 
         # Penalties
-        gen_settings.token_repetition_penalty = unwrap(
-            kwargs.get("repetition_penalty"), 1.0
-        )
-        gen_settings.token_frequency_penalty = unwrap(
-            kwargs.get("frequency_penalty"), 0.0
-        )
-        gen_settings.token_presence_penalty = unwrap(
-            kwargs.get("presence_penalty"), 0.0
-        )
+        gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty")
+        gen_settings.token_frequency_penalty = kwargs.get("frequency_penalty")
+        gen_settings.token_presence_penalty = kwargs.get("presence_penalty")
 
         # Applies for all penalties despite being called token_repetition_range
         gen_settings.token_repetition_range = unwrap(
@@ -1175,16 +1168,14 @@ class ExllamaV2Container:
         )
 
         # DRY options
-        dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 0.0)
+        dry_multiplier = kwargs.get("dry_multiplier")
 
         # < 0 = disabled
         if dry_multiplier > 0:
             gen_settings.dry_multiplier = dry_multiplier
 
-            gen_settings.dry_allowed_length = unwrap(
-                kwargs.get("dry_allowed_length"), 0
-            )
-            gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 0.0)
+            gen_settings.dry_allowed_length = kwargs.get("dry_allowed_length")
+            gen_settings.dry_base = kwargs.get("dry_base")
 
             # Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range
             # Use max_seq_len as the fallback to stay consistent
@@ -1203,24 +1194,24 @@ class ExllamaV2Container:
         grammar_handler = ExLlamaV2Grammar()
 
         # Add JSON schema filter if it exists
-        json_schema = unwrap(kwargs.get("json_schema"))
+        json_schema = kwargs.get("json_schema")
         if json_schema:
             grammar_handler.add_json_schema_filter(
                 json_schema, self.model, self.tokenizer
             )
 
         # Add regex filter if it exists
-        regex_pattern = unwrap(kwargs.get("regex_pattern"))
+        regex_pattern = kwargs.get("regex_pattern")
         if regex_pattern:
             grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
 
         # Add EBNF filter if it exists
-        grammar_string = unwrap(kwargs.get("grammar_string"))
+        grammar_string = kwargs.get("grammar_string")
         if grammar_string:
             grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)
 
         # Set banned strings
-        banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
+        banned_strings = kwargs.get("banned_strings")
         if banned_strings and len(grammar_handler.filters) > 0:
             logger.warning(
                 "Disabling banned_strings because "
@@ -1229,18 +1220,16 @@ class ExllamaV2Container:
 
             banned_strings = []
 
-        stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
-        add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
-        ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
+        stop_conditions = kwargs.get("stop")
+        add_bos_token = kwargs.get("add_bos_token")
+        ban_eos_token = kwargs.get("ban_eos_token")
         logit_bias = kwargs.get("logit_bias")
 
         # Logprobs
-        request_logprobs = unwrap(kwargs.get("logprobs"), 0)
+        request_logprobs = kwargs.get("logprobs")
 
         # Speculative Ngram
-        self.generator.speculative_ngram = unwrap(
-            kwargs.get("speculative_ngram"), False
-        )
+        self.generator.speculative_ngram = kwargs.get("speculative_ngram")
 
         # Override sampler settings for temp = 0
         if gen_settings.temperature == 0:
@@ -1255,12 +1244,12 @@ class ExllamaV2Container:
         )
 
         # Set banned tokens
-        banned_tokens = unwrap(kwargs.get("banned_tokens"), [])
+        banned_tokens = kwargs.get("banned_tokens")
         if banned_tokens:
             gen_settings.disallow_tokens(self.tokenizer, banned_tokens)
 
         # Set allowed tokens
-        allowed_tokens = unwrap(kwargs.get("allowed_tokens"), [])
+        allowed_tokens = kwargs.get("allowed_tokens")
         if allowed_tokens:
             gen_settings.allow_tokens(self.tokenizer, allowed_tokens)
 
@@ -1361,10 +1350,10 @@ class ExllamaV2Container:
         )
 
         # Set min_tokens to generate while keeping EOS banned
-        min_tokens = unwrap(kwargs.get("min_tokens"), 0)
+        min_tokens = kwargs.get("min_tokens")
 
         # This is an inverse of skip_special_tokens
-        decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False)
+        decode_special_tokens = not kwargs.get("skip_special_tokens")
 
         # Log prompt to console. Add the BOS token if specified
         log_prompt(
@@ -42,7 +42,7 @@ class BaseSamplerRequest(BaseModel):
     )
 
     generate_window: Optional[int] = Field(
-        default_factory=lambda: get_default_sampler_value("generate_window"),
+        default_factory=lambda: get_default_sampler_value("generate_window", 512),
         examples=[512],
        ge=0,
     )
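The second file's hunk passes an explicit 512 fallback into get_default_sampler_value, which implies the helper accepts a default argument. A rough sketch under that assumption; the override table and the function body here are hypothetical, not the project's actual code.

SAMPLER_OVERRIDES = {}  # hypothetical stand-in for a loaded sampler override file


def get_default_sampler_value(key, fallback=None):
    # Assumed behavior: prefer a user-configured override, else the fallback.
    return SAMPLER_OVERRIDES.get(key, {}).get("override", fallback)


# With the new fallback, generate_window resolves to 512 when nothing overrides
# it, so the container no longer needs unwrap(kwargs.get("generate_window"), 512).
assert get_default_sampler_value("generate_window", 512) == 512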