Sampling: Make validators simpler
Injecting into Pydantic fields caused issues with serialization for documentation rendering. Rather than reinvent the wheel again, switch to a chain of if statements for now. This may change in the future if subclasses from the base sampler request need to be validated as well. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
f627485534
commit
a79c42ff4c
1 changed files with 49 additions and 32 deletions
|
|
@ -37,8 +37,6 @@ class BaseSamplerRequest(BaseModel):
|
|||
temperature: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("temperature", 1.0),
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value >= 0.0,
|
||||
validation_error="Temperature must be a non-negative value",
|
||||
)
|
||||
|
||||
temperature_last: Optional[bool] = Field(
|
||||
|
|
@ -47,21 +45,15 @@ class BaseSamplerRequest(BaseModel):
|
|||
|
||||
smoothing_factor: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
|
||||
sample_validator=lambda value: value >= 0.0,
|
||||
validation_error="Smoothing factor must be a non-negative value",
|
||||
)
|
||||
|
||||
top_k: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("top_k", 0),
|
||||
sample_validator=lambda value: value >= 0,
|
||||
validation_error="Top K must be a non-negative value",
|
||||
)
|
||||
|
||||
top_p: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("top_p", 1.0),
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value >= 0.0 and value <= 1.0,
|
||||
validation_error="Top P must be in [0, 1]",
|
||||
)
|
||||
|
||||
top_a: Optional[float] = Field(
|
||||
|
|
@ -88,8 +80,6 @@ class BaseSamplerRequest(BaseModel):
|
|||
repetition_penalty: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value > 0.0,
|
||||
validation_error="Repetition penalty must be a positive value",
|
||||
)
|
||||
|
||||
repetition_decay: Optional[int] = Field(
|
||||
|
|
@ -134,8 +124,6 @@ class BaseSamplerRequest(BaseModel):
|
|||
validation_alias=AliasChoices("typical", "typical_p"),
|
||||
description="Aliases: typical_p",
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value > 0.0 and value <= 1.0,
|
||||
validation_error="Typical must be in (0, 1]",
|
||||
)
|
||||
|
||||
penalty_range: Optional[int] = Field(
|
||||
|
|
@ -160,8 +148,6 @@ class BaseSamplerRequest(BaseModel):
|
|||
validation_alias=AliasChoices("max_temp", "dynatemp_high"),
|
||||
description="Aliases: dynatemp_high",
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value >= 0.0,
|
||||
validation_error="Max temperature must be a non-negative value",
|
||||
)
|
||||
|
||||
min_temp: Optional[float] = Field(
|
||||
|
|
@ -169,38 +155,69 @@ class BaseSamplerRequest(BaseModel):
|
|||
validation_alias=AliasChoices("min_temp", "dynatemp_low"),
|
||||
description="Aliases: dynatemp_low",
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value >= 0.0,
|
||||
validation_error="Min temperature must be a non-negative value",
|
||||
)
|
||||
|
||||
temp_exponent: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
|
||||
validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value >= 0.0,
|
||||
validation_error="Temperature exponent must be a non-negative value",
|
||||
)
|
||||
|
||||
# TODO: Return back to adaptable class-based validation But that's just too much
|
||||
# abstraction compared to simple if statements at the moment
|
||||
def validate_params(self):
|
||||
"""
|
||||
Validates if the class field satisfies a condition if present.
|
||||
|
||||
Validators are present in the extras section of a Pydantic field
|
||||
to make it easy for adding more samplers if needed.
|
||||
Validates sampler parameters to be within sane ranges.
|
||||
"""
|
||||
|
||||
for field_name, field_info in self.model_fields.items():
|
||||
extra_field_info = unwrap(field_info.json_schema_extra, {})
|
||||
if not extra_field_info:
|
||||
continue
|
||||
# Temperature
|
||||
if self.temperature < 0.0:
|
||||
raise ValueError(
|
||||
"Temperature must be a non-negative value. " f"Got {self.temperature}"
|
||||
)
|
||||
|
||||
sample_validator = extra_field_info.get("sample_validator")
|
||||
validation_error = unwrap(extra_field_info.get("validation_error"), "")
|
||||
# Smoothing factor
|
||||
if self.smoothing_factor < 0.0:
|
||||
raise ValueError(
|
||||
"Smoothing factor must be a non-negative value. "
|
||||
f"Got {self.smoothing_factor}"
|
||||
)
|
||||
|
||||
if sample_validator:
|
||||
value = getattr(self, field_name)
|
||||
if not sample_validator(value):
|
||||
raise ValueError(f"{validation_error}. Got {value}")
|
||||
# Top K
|
||||
if self.top_k < 0:
|
||||
raise ValueError("Top K must be a non-negative value. " f"Got {self.top_k}")
|
||||
|
||||
# Top P
|
||||
if self.top_p < 0.0 or self.top_p > 1.0:
|
||||
raise ValueError("Top P must be in [0, 1]. " f"Got {self.top_p}")
|
||||
|
||||
# Repetition Penalty
|
||||
if self.repetition_penalty <= 0.0:
|
||||
raise ValueError(
|
||||
"Repetition penalty must be a positive value. "
|
||||
f"Got {self.repetition_penalty}"
|
||||
)
|
||||
|
||||
# Typical
|
||||
if self.typical <= 0 and self.typical > 1:
|
||||
raise ValueError("Typical must be in (0, 1]. " f"Got {self.typical}")
|
||||
|
||||
# Dynatemp values
|
||||
if self.max_temp < 0.0:
|
||||
raise ValueError(
|
||||
"Max temp must be a non-negative value. ", f"Got {self.max_temp}"
|
||||
)
|
||||
|
||||
if self.min_temp < 0.0:
|
||||
raise ValueError(
|
||||
"Min temp must be a non-negative value. ", f"Got {self.min_temp}"
|
||||
)
|
||||
|
||||
if self.temp_exponent < 0.0:
|
||||
raise ValueError(
|
||||
"Temp exponent must be a non-negative value. ",
|
||||
f"Got {self.temp_exponent}",
|
||||
)
|
||||
|
||||
def to_gen_params(self, **kwargs):
|
||||
"""Converts samplers to internal generation params"""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue