From d7c18855e7727e307e16923e0638c1371c61ea32 Mon Sep 17 00:00:00 2001 From: Alexander Abushady <44341163+AAbushady@users.noreply.github.com> Date: Fri, 2 Feb 2024 22:12:59 -0500 Subject: [PATCH] added quadratic sampling (#56) * added quadratic sampling * Update sample_preset.yml * oops missed a spot * Sampling: Fix smoothing factor semantics --- backends/exllamav2/model.py | 9 +++++++++ common/sampling.py | 5 +++++ sampler_overrides/sample_preset.yml | 3 +++ 3 files changed, 17 insertions(+) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index ec2ec61..f908402 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -523,6 +523,14 @@ class ExllamaV2Container: "installed ExLlamaV2 version." ) + if (unwrap(kwargs.get("smoothing_factor"), 0.0)) > 0.0 and not hasattr( + ExLlamaV2Sampler.Settings, "smoothing_factor" + ): + logger.warning( + "Smoothing factor is not supported by the currently " + "installed ExLlamaV2 version." + ) + def generate(self, prompt: str, **kwargs): """Generate a response to a prompt""" generation = list(self.generate_gen(prompt, **kwargs)) @@ -593,6 +601,7 @@ class ExllamaV2Container: # Apply settings gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0) gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False) + gen_settings.smoothing_factor = unwrap(kwargs.get("smoothing_factor"), 0.0) gen_settings.top_k = unwrap(kwargs.get("top_k"), 0) gen_settings.top_p = unwrap(kwargs.get("top_p"), 1.0) gen_settings.top_a = unwrap(kwargs.get("top_a"), 0.0) diff --git a/common/sampling.py b/common/sampling.py index 0f84cb5..fac0678 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -56,6 +56,10 @@ class SamplerParams(BaseModel): examples=[1.0], ) + smoothing_factor: Optional[float] = Field( + default_factor=lambda: get_default_sampler_value("smoothing_factor", 0.0), + ) + top_k: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("top_k", 0) ) @@ -173,6 +177,7 @@ class SamplerParams(BaseModel): "min_temp": self.min_temp, "max_temp": self.max_temp, "temp_exponent": self.temp_exponent, + "smoothing_factor": self.smoothing_factor, "top_k": self.top_k, "top_p": self.top_p, "top_a": self.top_a, diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml index e6e2258..3d1c42a 100644 --- a/sampler_overrides/sample_preset.yml +++ b/sampler_overrides/sample_preset.yml @@ -39,6 +39,9 @@ max_temp: temp_exponent: override: 0.0 force: false +smoothing_factor: + override: 0.0 + force: false # MARK: Alphabet soup top_k: