From dcb36e9ab2bae39a0c46baecbfadb3a4da5cb74c Mon Sep 17 00:00:00 2001 From: kingbri <8082010+kingbri1@users.noreply.github.com> Date: Tue, 15 Apr 2025 23:38:46 -0400 Subject: [PATCH] Model: Remove extra unwraps The base sampler request already specifies the defaults, so don't unwrap in this way. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com> --- backends/exllamav2/model.py | 91 ++++++++++++++++--------------------- common/sampling.py | 2 +- 2 files changed, 41 insertions(+), 52 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index e0c7cce..6accc88 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1059,9 +1059,9 @@ class ExllamaV2Container: prompts = [prompt] - token_healing = unwrap(kwargs.get("token_healing"), False) + token_healing = kwargs.get("token_healing") generate_window = max( - unwrap(kwargs.get("generate_window"), 512), self.config.max_seq_len // 8 + kwargs.get("generate_window"), self.config.max_seq_len // 8 ) # Sampler settings @@ -1071,35 +1071,34 @@ class ExllamaV2Container: kwargs = self.check_unsupported_settings(**kwargs) # Apply settings - partial(gen_settings.temperature, 1.0) - gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0) - gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False) - gen_settings.smoothing_factor = unwrap(kwargs.get("smoothing_factor"), 0.0) - gen_settings.top_k = unwrap(kwargs.get("top_k"), 0) - gen_settings.top_p = unwrap(kwargs.get("top_p"), 1.0) - gen_settings.top_a = unwrap(kwargs.get("top_a"), 0.0) - gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0) - gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0) - gen_settings.typical = unwrap(kwargs.get("typical"), 1.0) - gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False) - gen_settings.skew = unwrap(kwargs.get("skew"), 0) + gen_settings.temperature = kwargs.get("temperature") + gen_settings.temperature_last = kwargs.get("temperature_last") + 
gen_settings.smoothing_factor = kwargs.get("smoothing_factor") + gen_settings.top_k = kwargs.get("top_k") + gen_settings.top_p = kwargs.get("top_p") + gen_settings.top_a = kwargs.get("top_a") + gen_settings.min_p = kwargs.get("min_p") + gen_settings.tfs = kwargs.get("tfs") + gen_settings.typical = kwargs.get("typical") + gen_settings.mirostat = kwargs.get("mirostat") + gen_settings.skew = kwargs.get("skew") # XTC - xtc_probability = unwrap(kwargs.get("xtc_probability"), 0.0) + xtc_probability = kwargs.get("xtc_probability") if xtc_probability > 0.0: gen_settings.xtc_probability = xtc_probability # 0.1 is the default for this value - gen_settings.xtc_threshold = unwrap(kwargs.get("xtc_threshold", 0.1)) + gen_settings.xtc_threshold = kwargs.get("xtc_threshold") # DynaTemp settings - max_temp = unwrap(kwargs.get("max_temp"), 1.0) - min_temp = unwrap(kwargs.get("min_temp"), 1.0) + max_temp = kwargs.get("max_temp") + min_temp = kwargs.get("min_temp") if max_temp > min_temp: gen_settings.max_temp = max_temp gen_settings.min_temp = min_temp - gen_settings.temp_exponent = unwrap(kwargs.get("temp_exponent"), 1.0) + gen_settings.temp_exponent = kwargs.get("temp_exponent") else: # Force to default values gen_settings.max_temp = 1.0 @@ -1116,11 +1115,11 @@ class ExllamaV2Container: ) # Default tau and eta fallbacks don't matter if mirostat is off - gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5) - gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1) + gen_settings.mirostat_tau = kwargs.get("mirostat_tau") + gen_settings.mirostat_eta = kwargs.get("mirostat_eta") # Set CFG scale and negative prompt - cfg_scale = unwrap(kwargs.get("cfg_scale"), 1.0) + cfg_scale = kwargs.get("cfg_scale") negative_prompt = None if cfg_scale not in [None, 1.0]: if self.paged: @@ -1139,15 +1138,9 @@ class ExllamaV2Container: ) # Penalties - gen_settings.token_repetition_penalty = unwrap( - kwargs.get("repetition_penalty"), 1.0 - ) - 
gen_settings.token_frequency_penalty = unwrap( - kwargs.get("frequency_penalty"), 0.0 - ) - gen_settings.token_presence_penalty = unwrap( - kwargs.get("presence_penalty"), 0.0 - ) + gen_settings.token_repetition_penalty = kwargs.get("repetition_penalty") + gen_settings.token_frequency_penalty = kwargs.get("frequency_penalty") + gen_settings.token_presence_penalty = kwargs.get("presence_penalty") # Applies for all penalties despite being called token_repetition_range gen_settings.token_repetition_range = unwrap( @@ -1175,16 +1168,14 @@ class ExllamaV2Container: ) # DRY options - dry_multiplier = unwrap(kwargs.get("dry_multiplier"), 0.0) + dry_multiplier = kwargs.get("dry_multiplier") # < 0 = disabled if dry_multiplier > 0: gen_settings.dry_multiplier = dry_multiplier - gen_settings.dry_allowed_length = unwrap( - kwargs.get("dry_allowed_length"), 0 - ) - gen_settings.dry_base = unwrap(kwargs.get("dry_base"), 0.0) + gen_settings.dry_allowed_length = kwargs.get("dry_allowed_length") + gen_settings.dry_base = kwargs.get("dry_base") # Exl2 has dry_range as 0 for unlimited unlike -1 for penalty_range # Use max_seq_len as the fallback to stay consistent @@ -1203,24 +1194,24 @@ class ExllamaV2Container: grammar_handler = ExLlamaV2Grammar() # Add JSON schema filter if it exists - json_schema = unwrap(kwargs.get("json_schema")) + json_schema = kwargs.get("json_schema") if json_schema: grammar_handler.add_json_schema_filter( json_schema, self.model, self.tokenizer ) # Add regex filter if it exists - regex_pattern = unwrap(kwargs.get("regex_pattern")) + regex_pattern = kwargs.get("regex_pattern") if regex_pattern: grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer) # Add EBNF filter if it exists - grammar_string = unwrap(kwargs.get("grammar_string")) + grammar_string = kwargs.get("grammar_string") if grammar_string: grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer) # Set banned strings - banned_strings: List[str] = 
unwrap(kwargs.get("banned_strings"), []) + banned_strings = kwargs.get("banned_strings") if banned_strings and len(grammar_handler.filters) > 0: logger.warning( "Disabling banned_strings because " @@ -1229,18 +1220,16 @@ class ExllamaV2Container: banned_strings = [] - stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), []) - add_bos_token = unwrap(kwargs.get("add_bos_token"), True) - ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False) + stop_conditions = kwargs.get("stop") + add_bos_token = kwargs.get("add_bos_token") + ban_eos_token = kwargs.get("ban_eos_token") logit_bias = kwargs.get("logit_bias") # Logprobs - request_logprobs = unwrap(kwargs.get("logprobs"), 0) + request_logprobs = kwargs.get("logprobs") # Speculative Ngram - self.generator.speculative_ngram = unwrap( - kwargs.get("speculative_ngram"), False - ) + self.generator.speculative_ngram = kwargs.get("speculative_ngram") # Override sampler settings for temp = 0 if gen_settings.temperature == 0: @@ -1255,12 +1244,12 @@ class ExllamaV2Container: ) # Set banned tokens - banned_tokens = unwrap(kwargs.get("banned_tokens"), []) + banned_tokens = kwargs.get("banned_tokens") if banned_tokens: gen_settings.disallow_tokens(self.tokenizer, banned_tokens) # Set allowed tokens - allowed_tokens = unwrap(kwargs.get("allowed_tokens"), []) + allowed_tokens = kwargs.get("allowed_tokens") if allowed_tokens: gen_settings.allow_tokens(self.tokenizer, allowed_tokens) @@ -1361,10 +1350,10 @@ class ExllamaV2Container: ) # Set min_tokens to generate while keeping EOS banned - min_tokens = unwrap(kwargs.get("min_tokens"), 0) + min_tokens = kwargs.get("min_tokens") # This is an inverse of skip_special_tokens - decode_special_tokens = unwrap(not kwargs.get("skip_special_tokens"), False) + decode_special_tokens = not kwargs.get("skip_special_tokens") # Log prompt to console.
Add the BOS token if specified log_prompt( diff --git a/common/sampling.py b/common/sampling.py index d2c230c..c7ef934 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -42,7 +42,7 @@ class BaseSamplerRequest(BaseModel): ) generate_window: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("generate_window"), + default_factory=lambda: get_default_sampler_value("generate_window", 512), examples=[512], ge=0, )