From f1ea15d77eb67093abfc469f216db798a8b72fa2 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Fri, 2 Feb 2024 23:40:10 -0500
Subject: [PATCH] Model: Remove backwards compatibility hacks

Now that exllamav2 is required to be the latest, don't add attribute
checks unless the feature is not in the release build.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py | 76 +++++++++++++------------------------
 1 file changed, 27 insertions(+), 49 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index a9b1549..84c6724 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -139,23 +139,10 @@ class ExllamaV2Container:
         )
 
         # Enable CFG if present
-        use_cfg = unwrap(kwargs.get("use_cfg"), False)
-        if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"):
-            self.use_cfg = use_cfg
-        elif use_cfg:
-            logger.warning(
-                "CFG is not supported by the currently installed ExLlamaV2 version."
-            )
+        self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
 
         # Enable fasttensors loading if present
-        use_fasttensors = unwrap(kwargs.get("fasttensors"), False)
-        if hasattr(ExLlamaV2Config, "fasttensors"):
-            self.config.fasttensors = use_fasttensors
-        elif use_fasttensors:
-            logger.warning(
-                "fasttensors is not supported by "
-                "the currently installed ExllamaV2 version."
-            )
+        self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
 
         # Turn off flash attention if CFG is on
         # Workaround until batched FA2 is fixed in exllamav2 upstream
@@ -189,13 +176,7 @@ class ExllamaV2Container:
         # Set num of experts per token if provided
         num_experts_override = kwargs.get("num_experts_per_token")
         if num_experts_override:
-            if hasattr(self.config, "num_experts_per_token"):
-                self.config.num_experts_per_token = num_experts_override
-            else:
-                logger.warning(
-                    "MoE experts per token override is not "
-                    "supported by the current ExLlamaV2 version."
-                )
+            self.config.num_experts_per_token = kwargs.get("num_experts_per_token")
 
         chunk_size = min(
             unwrap(kwargs.get("chunk_size"), 2048), self.config.max_seq_len
@@ -410,6 +391,10 @@ class ExllamaV2Container:
             self.draft_cache,
         )
 
+        # Always return logprobs and logits
+        self.generator.return_probabilities = True
+        self.generator.return_logits = True
+
         logger.info("Model successfully loaded.")
 
     def unload(self, loras_only: bool = False):
@@ -470,21 +455,7 @@ class ExllamaV2Container:
     def check_unsupported_settings(self, **kwargs):
         """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
 
-        if kwargs.get("max_temp") > 0.0 and not hasattr(
-            ExLlamaV2Sampler.Settings, "max_temp"
-        ):
-            logger.warning(
-                "DynaTemp parameters are not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
-        if (unwrap(kwargs.get("smoothing_factor"), 0.0)) > 0.0 and not hasattr(
-            ExLlamaV2Sampler.Settings, "smoothing_factor"
-        ):
-            logger.warning(
-                "Smoothing factor is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
+        pass
 
     def generate(self, prompt: str, **kwargs):
         """Generate a response to a prompt"""
@@ -566,20 +537,27 @@ class ExllamaV2Container:
         gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
 
         # DynaTemp settings
-        if hasattr(gen_settings, "max_temp"):
-            max_temp = unwrap(kwargs.get("max_temp"), 0.0)
-            min_temp = unwrap(kwargs.get("min_temp"), 0.0)
-
-            if max_temp < min_temp or (
-                0 not in {min_temp, max_temp} and max_temp == min_temp
-            ):
-                logger.warning(
-                    "Max temp is less than or equal to min temp, skipping DynaTemp."
-                )
+        max_temp = unwrap(kwargs.get("max_temp"), 0.0)
+        min_temp = unwrap(kwargs.get("min_temp"), 0.0)
 
+        if max_temp > min_temp:
             gen_settings.max_temp = max_temp
             gen_settings.min_temp = min_temp
-            gen_settings.temp_exponent = kwargs.get("temp_exponent")
+            gen_settings.temp_exponent = unwrap(kwargs.get("temp_exponent"), 1.0)
+        else:
+            # Force to default values
+            gen_settings.max_temp = 0.0
+            gen_settings.min_temp = 0.0
+            gen_settings.temp_exponent = 1.0
+
+        # Warn if max/min temp values are > 0
+        # and if they're less than or equal to each other
+        if max_temp < min_temp or (
+            0 not in {min_temp, max_temp} and max_temp == min_temp
+        ):
+            logger.warning(
+                "Max temp is less than or equal to min temp, skipping DynaTemp."
+            )
 
         # Default tau and eta fallbacks don't matter if mirostat is off
         gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
@@ -760,7 +738,7 @@ class ExllamaV2Container:
             gen_settings.token_repetition_range = generated_tokens
 
         # Generate
-        chunk, eos, tokens, _, *extra_parts = self.generator.stream()
+        chunk, eos, tokens, _, _ = self.generator.stream()
 
         if token_healing:
             # Extract healed token
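
For reference, below is a minimal, self-contained sketch of the hasattr-based
feature probing that this patch removes, next to the direct assignment that
replaces it. The Settings stand-in, the unwrap() helper, and the function names
are illustrative assumptions, not the actual code in backends/exllamav2/model.py.

import logging

logger = logging.getLogger(__name__)


def unwrap(value, default):
    # Mirror the unwrap() helper used throughout the patch: fall back on None.
    return default if value is None else value


class Settings:
    # Stand-in for ExLlamaV2Sampler.Settings; having cfg_scale marks CFG support.
    cfg_scale = 1.0
    use_cfg = False


def enable_cfg_compat(settings, **kwargs):
    # Old pattern removed by the patch: probe the installed library with hasattr
    # and warn instead of failing when the feature is missing.
    use_cfg = unwrap(kwargs.get("use_cfg"), False)
    if hasattr(settings, "cfg_scale"):
        settings.use_cfg = use_cfg
    elif use_cfg:
        logger.warning("CFG is not supported by the installed version.")


def enable_cfg_direct(settings, **kwargs):
    # New pattern: the latest library version is assumed, so assign directly.
    settings.use_cfg = unwrap(kwargs.get("use_cfg"), False)


if __name__ == "__main__":
    s = Settings()
    enable_cfg_compat(s, use_cfg=True)   # sets s.use_cfg via the hasattr probe
    enable_cfg_direct(s, use_cfg=True)   # sets s.use_cfg unconditionally
    print(s.use_cfg)                     # True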