diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index ae1d7e4..ec2ec61 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -515,6 +515,14 @@ class ExllamaV2Container: "installed ExLlamaV2 version." ) + if (unwrap(kwargs.get("max_temp"), 0.0)) > 0.0 and not hasattr( + ExLlamaV2Sampler.Settings, "max_temp" + ): + logger.warning( + "DynaTemp parameters are not supported by the currently " + "installed ExLlamaV2 version." + ) + def generate(self, prompt: str, **kwargs): """Generate a response to a prompt""" generation = list(self.generate_gen(prompt, **kwargs)) @@ -579,6 +587,7 @@ class ExllamaV2Container: # Sampler settings gen_settings = ExLlamaV2Sampler.Settings() + # TODO: Migrate settings validation to different function self.check_unsupported_settings(**kwargs) # Apply settings @@ -592,6 +601,22 @@ class ExllamaV2Container: gen_settings.typical = unwrap(kwargs.get("typical"), 1.0) gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False) + # DynaTemp settings + if hasattr(gen_settings, "max_temp"): + max_temp = unwrap(kwargs.get("max_temp"), 0.0) + min_temp = unwrap(kwargs.get("min_temp"), 0.0) + + if max_temp < min_temp or ( + 0 not in {min_temp, max_temp} and max_temp == min_temp + ): + logger.warning( + "Max temp is less than or equal to min temp, skipping DynaTemp." + ) + + gen_settings.max_temp = max_temp + gen_settings.min_temp = min_temp + gen_settings.temp_exponent = kwargs.get("temp_exponent") + # Default tau and eta fallbacks don't matter if mirostat is off gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5) gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1) diff --git a/common/sampling.py b/common/sampling.py index 8c28002..0f84cb5 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -43,6 +43,19 @@ class SamplerParams(BaseModel): default_factory=lambda: get_default_sampler_value("temperature_last", False) ) + max_temp: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("max_temp", 0.0), + ) + + min_temp: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("min_temp", 0.0), + ) + + temp_exponent: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0), + examples=[1.0], + ) + top_k: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("top_k", 0) ) @@ -157,6 +170,9 @@ class SamplerParams(BaseModel): "logit_bias": self.logit_bias, "temperature": self.temperature, "temperature_last": self.temperature_last, + "min_temp": self.min_temp, + "max_temp": self.max_temp, + "temp_exponent": self.temp_exponent, "top_k": self.top_k, "top_p": self.top_p, "top_a": self.top_a, diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml index eae17ab..e6e2258 100644 --- a/sampler_overrides/sample_preset.yml +++ b/sampler_overrides/sample_preset.yml @@ -30,6 +30,15 @@ temperature: temperature_last: override: false force: false +min_temp: + override: 0.0 + force: false +max_temp: + override: 0.0 + force: false +temp_exponent: + override: 0.0 + force: false # MARK: Alphabet soup top_k: