Samplers: Add dynamic temperature
Does not work if max_temp is less than or equal to min_temp. Sampler validation will have to be refactored in the future, so the dynamic temperature check will also be changed. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
3605067898
commit
4a7b8b1b7a
3 changed files with 50 additions and 0 deletions
|
|
@ -515,6 +515,14 @@ class ExllamaV2Container:
|
||||||
"installed ExLlamaV2 version."
|
"installed ExLlamaV2 version."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (unwrap(kwargs.get("max_temp"), 0.0)) > 0.0 and not hasattr(
|
||||||
|
ExLlamaV2Sampler.Settings, "max_temp"
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"DynaTemp parameters are not supported by the currently "
|
||||||
|
"installed ExLlamaV2 version."
|
||||||
|
)
|
||||||
|
|
||||||
def generate(self, prompt: str, **kwargs):
|
def generate(self, prompt: str, **kwargs):
|
||||||
"""Generate a response to a prompt"""
|
"""Generate a response to a prompt"""
|
||||||
generation = list(self.generate_gen(prompt, **kwargs))
|
generation = list(self.generate_gen(prompt, **kwargs))
|
||||||
|
|
@ -579,6 +587,7 @@ class ExllamaV2Container:
|
||||||
# Sampler settings
|
# Sampler settings
|
||||||
gen_settings = ExLlamaV2Sampler.Settings()
|
gen_settings = ExLlamaV2Sampler.Settings()
|
||||||
|
|
||||||
|
# TODO: Migrate settings validation to different function
|
||||||
self.check_unsupported_settings(**kwargs)
|
self.check_unsupported_settings(**kwargs)
|
||||||
|
|
||||||
# Apply settings
|
# Apply settings
|
||||||
|
|
@ -592,6 +601,22 @@ class ExllamaV2Container:
|
||||||
gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
|
gen_settings.typical = unwrap(kwargs.get("typical"), 1.0)
|
||||||
gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
|
gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False)
|
||||||
|
|
||||||
|
# DynaTemp settings
|
||||||
|
if hasattr(gen_settings, "max_temp"):
|
||||||
|
max_temp = unwrap(kwargs.get("max_temp"), 0.0)
|
||||||
|
min_temp = unwrap(kwargs.get("min_temp"), 0.0)
|
||||||
|
|
||||||
|
if max_temp < min_temp or (
|
||||||
|
0 not in {min_temp, max_temp} and max_temp == min_temp
|
||||||
|
):
|
||||||
|
logger.warning(
|
||||||
|
"Max temp is less than or equal to min temp, skipping DynaTemp."
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_settings.max_temp = max_temp
|
||||||
|
gen_settings.min_temp = min_temp
|
||||||
|
gen_settings.temp_exponent = kwargs.get("temp_exponent")
|
||||||
|
|
||||||
# Default tau and eta fallbacks don't matter if mirostat is off
|
# Default tau and eta fallbacks don't matter if mirostat is off
|
||||||
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
|
gen_settings.mirostat_tau = unwrap(kwargs.get("mirostat_tau"), 1.5)
|
||||||
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
|
gen_settings.mirostat_eta = unwrap(kwargs.get("mirostat_eta"), 0.1)
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,19 @@ class SamplerParams(BaseModel):
|
||||||
default_factory=lambda: get_default_sampler_value("temperature_last", False)
|
default_factory=lambda: get_default_sampler_value("temperature_last", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
max_temp: Optional[float] = Field(
|
||||||
|
default_factory=lambda: get_default_sampler_value("max_temp", 0.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
min_temp: Optional[float] = Field(
|
||||||
|
default_factory=lambda: get_default_sampler_value("min_temp", 0.0),
|
||||||
|
)
|
||||||
|
|
||||||
|
temp_exponent: Optional[float] = Field(
|
||||||
|
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
|
||||||
|
examples=[1.0],
|
||||||
|
)
|
||||||
|
|
||||||
top_k: Optional[int] = Field(
|
top_k: Optional[int] = Field(
|
||||||
default_factory=lambda: get_default_sampler_value("top_k", 0)
|
default_factory=lambda: get_default_sampler_value("top_k", 0)
|
||||||
)
|
)
|
||||||
|
|
@ -157,6 +170,9 @@ class SamplerParams(BaseModel):
|
||||||
"logit_bias": self.logit_bias,
|
"logit_bias": self.logit_bias,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"temperature_last": self.temperature_last,
|
"temperature_last": self.temperature_last,
|
||||||
|
"min_temp": self.min_temp,
|
||||||
|
"max_temp": self.max_temp,
|
||||||
|
"temp_exponent": self.temp_exponent,
|
||||||
"top_k": self.top_k,
|
"top_k": self.top_k,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
"top_a": self.top_a,
|
"top_a": self.top_a,
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,15 @@ temperature:
|
||||||
temperature_last:
|
temperature_last:
|
||||||
override: false
|
override: false
|
||||||
force: false
|
force: false
|
||||||
|
min_temp:
|
||||||
|
override: 0.0
|
||||||
|
force: false
|
||||||
|
max_temp:
|
||||||
|
override: 0.0
|
||||||
|
force: false
|
||||||
|
temp_exponent:
|
||||||
|
override: 0.0
|
||||||
|
force: false
|
||||||
|
|
||||||
# MARK: Alphabet soup
|
# MARK: Alphabet soup
|
||||||
top_k:
|
top_k:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue