added quadratic sampling (#56)

* added quadratic sampling

* Update sample_preset.yml

* oops missed a spot

* Sampling: Fix smoothing factor semantics
This commit is contained in:
Alexander Abushady 2024-02-02 22:12:59 -05:00 committed by GitHub
parent 4a7b8b1b7a
commit d7c18855e7
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 17 additions and 0 deletions

View file

@ -523,6 +523,14 @@ class ExllamaV2Container:
"installed ExLlamaV2 version."
)
# Warn when a positive smoothing factor is requested but the installed
# ExLlamaV2Sampler.Settings class exposes no `smoothing_factor` attribute.
if (unwrap(kwargs.get("smoothing_factor"), 0.0)) > 0.0 and not hasattr(
ExLlamaV2Sampler.Settings, "smoothing_factor"
):
logger.warning(
"Smoothing factor is not supported by the currently "
"installed ExLlamaV2 version."
)
def generate(self, prompt: str, **kwargs):
"""Generate a response to a prompt"""
generation = list(self.generate_gen(prompt, **kwargs))
@ -593,6 +601,7 @@ class ExllamaV2Container:
# Apply settings
gen_settings.temperature = unwrap(kwargs.get("temperature"), 1.0)
gen_settings.temperature_last = unwrap(kwargs.get("temperature_last"), False)
gen_settings.smoothing_factor = unwrap(kwargs.get("smoothing_factor"), 0.0)
gen_settings.top_k = unwrap(kwargs.get("top_k"), 0)
gen_settings.top_p = unwrap(kwargs.get("top_p"), 1.0)
gen_settings.top_a = unwrap(kwargs.get("top_a"), 0.0)

View file

@ -56,6 +56,10 @@ class SamplerParams(BaseModel):
examples=[1.0],
)
# Smoothing factor for quadratic sampling. The Field keyword must be
# `default_factory` (as on the sibling sampler fields); the misspelled
# `default_factor` is not a recognized default, so the callable was never
# used to produce the field's default value.
smoothing_factor: Optional[float] = Field(
    default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
)
top_k: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("top_k", 0)
)
@ -173,6 +177,7 @@ class SamplerParams(BaseModel):
"min_temp": self.min_temp,
"max_temp": self.max_temp,
"temp_exponent": self.temp_exponent,
"smoothing_factor": self.smoothing_factor,
"top_k": self.top_k,
"top_p": self.top_p,
"top_a": self.top_a,

View file

@ -39,6 +39,9 @@ max_temp:
temp_exponent:
override: 0.0
force: false
smoothing_factor:
override: 0.0
force: false
# MARK: Alphabet soup
top_k: