Sampling: Make validators simpler

Injecting into Pydantic fields caused issues with serialization for
documentation rendering. Rather than reinvent the wheel again,
switch to a chain of if statements for now. This may change in the
future if subclasses from the base sampler request need to be
validated as well.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-02-11 15:22:43 -05:00
parent f627485534
commit a79c42ff4c

View file

@ -37,8 +37,6 @@ class BaseSamplerRequest(BaseModel):
temperature: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temperature", 1.0),
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Temperature must be a non-negative value",
)
temperature_last: Optional[bool] = Field(
@ -47,21 +45,15 @@ class BaseSamplerRequest(BaseModel):
smoothing_factor: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
sample_validator=lambda value: value >= 0.0,
validation_error="Smoothing factor must be a non-negative value",
)
top_k: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("top_k", 0),
sample_validator=lambda value: value >= 0,
validation_error="Top K must be a non-negative value",
)
top_p: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("top_p", 1.0),
examples=[1.0],
sample_validator=lambda value: value >= 0.0 and value <= 1.0,
validation_error="Top P must be in [0, 1]",
)
top_a: Optional[float] = Field(
@ -88,8 +80,6 @@ class BaseSamplerRequest(BaseModel):
repetition_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
examples=[1.0],
sample_validator=lambda value: value > 0.0,
validation_error="Repetition penalty must be a positive value",
)
repetition_decay: Optional[int] = Field(
@ -134,8 +124,6 @@ class BaseSamplerRequest(BaseModel):
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
examples=[1.0],
sample_validator=lambda value: value > 0.0 and value <= 1.0,
validation_error="Typical must be in (0, 1]",
)
penalty_range: Optional[int] = Field(
@ -160,8 +148,6 @@ class BaseSamplerRequest(BaseModel):
validation_alias=AliasChoices("max_temp", "dynatemp_high"),
description="Aliases: dynatemp_high",
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Max temperature must be a non-negative value",
)
min_temp: Optional[float] = Field(
@ -169,38 +155,69 @@ class BaseSamplerRequest(BaseModel):
validation_alias=AliasChoices("min_temp", "dynatemp_low"),
description="Aliases: dynatemp_low",
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Min temperature must be a non-negative value",
)
temp_exponent: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Temperature exponent must be a non-negative value",
)
# TODO: Return back to adaptable class-based validation But that's just too much
# abstraction compared to simple if statements at the moment
def validate_params(self):
"""
Validates if the class field satisfies a condition if present.
Validators are present in the extras section of a Pydantic field
to make it easy for adding more samplers if needed.
Validates sampler parameters to be within sane ranges.
"""
for field_name, field_info in self.model_fields.items():
extra_field_info = unwrap(field_info.json_schema_extra, {})
if not extra_field_info:
continue
# Temperature
if self.temperature < 0.0:
raise ValueError(
"Temperature must be a non-negative value. " f"Got {self.temperature}"
)
sample_validator = extra_field_info.get("sample_validator")
validation_error = unwrap(extra_field_info.get("validation_error"), "")
# Smoothing factor
if self.smoothing_factor < 0.0:
raise ValueError(
"Smoothing factor must be a non-negative value. "
f"Got {self.smoothing_factor}"
)
if sample_validator:
value = getattr(self, field_name)
if not sample_validator(value):
raise ValueError(f"{validation_error}. Got {value}")
# Top K
if self.top_k < 0:
raise ValueError("Top K must be a non-negative value. " f"Got {self.top_k}")
# Top P
if self.top_p < 0.0 or self.top_p > 1.0:
raise ValueError("Top P must be in [0, 1]. " f"Got {self.top_p}")
# Repetition Penalty
if self.repetition_penalty <= 0.0:
raise ValueError(
"Repetition penalty must be a positive value. "
f"Got {self.repetition_penalty}"
)
# Typical
if self.typical <= 0 and self.typical > 1:
raise ValueError("Typical must be in (0, 1]. " f"Got {self.typical}")
# Dynatemp values
if self.max_temp < 0.0:
raise ValueError(
"Max temp must be a non-negative value. ", f"Got {self.max_temp}"
)
if self.min_temp < 0.0:
raise ValueError(
"Min temp must be a non-negative value. ", f"Got {self.min_temp}"
)
if self.temp_exponent < 0.0:
raise ValueError(
"Temp exponent must be a non-negative value. ",
f"Got {self.temp_exponent}",
)
def to_gen_params(self, **kwargs):
"""Converts samplers to internal generation params"""