Adds aliases for min_temp and max_temp (#58)

* Adds aliases for min_temp and max_temp

* Sampling: Add dynatemp_exponent alias
This commit is contained in:
erinmaybe 2024-02-03 21:51:29 -05:00 committed by GitHub
parent a769d90bad
commit fa2acb2828
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -43,19 +43,6 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("temperature_last", False)
)
max_temp: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("max_temp", 0.0),
)
min_temp: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("min_temp", 0.0),
)
temp_exponent: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
examples=[1.0],
)
smoothing_factor: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
)
@ -154,6 +141,24 @@ class BaseSamplerRequest(BaseModel):
examples=[1.0],
)
max_temp: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("max_temp", 0.0),
validation_alias=AliasChoices("max_temp", "dynatemp_high"),
description="Aliases: dynatemp_high",
)
min_temp: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("min_temp", 0.0),
validation_alias=AliasChoices("min_temp", "dynatemp_low"),
description="Aliases: dynatemp_low",
)
temp_exponent: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
examples=[1.0],
)
def to_gen_params(self):
"""Converts samplers to internal generation params"""