Sampling: Add universal validation system

Rather than maintaining yet another function to validate sampler
ranges/values, embed them in fields which allows for less
maintainence in the future.

Also add validation for existing samplers that can corrupt
the sampling stack if set improperly.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-02-10 01:47:13 -05:00
parent 9f1d891490
commit 7e730e3507

View file

@ -37,6 +37,8 @@ class BaseSamplerRequest(BaseModel):
temperature: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temperature", 1.0),
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Temperature must be a non-negative value",
)
temperature_last: Optional[bool] = Field(
@ -45,14 +47,21 @@ class BaseSamplerRequest(BaseModel):
smoothing_factor: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
sample_validator=lambda value: value >= 0.0,
validation_error="Smoothing factor must be a non-negative value",
)
top_k: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("top_k", 0)
default_factory=lambda: get_default_sampler_value("top_k", 0),
sample_validator=lambda value: value >= 0,
validation_error="Top K must be a non-negative value",
)
top_p: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("top_p", 1.0), examples=[1.0]
default_factory=lambda: get_default_sampler_value("top_p", 1.0),
examples=[1.0],
sample_validator=lambda value: value >= 0.0 and value <= 1.0,
validation_error="Top P must be in [0, 1]",
)
top_a: Optional[float] = Field(
@ -64,7 +73,8 @@ class BaseSamplerRequest(BaseModel):
)
tfs: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("tfs", 1.0)
default_factory=lambda: get_default_sampler_value("tfs", 1.0),
examples=[1.0],
)
frequency_penalty: Optional[float] = Field(
@ -78,6 +88,8 @@ class BaseSamplerRequest(BaseModel):
repetition_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
examples=[1.0],
sample_validator=lambda value: value > 0.0,
validation_error="Repetition penalty must be a positive value",
)
repetition_decay: Optional[int] = Field(
@ -122,6 +134,8 @@ class BaseSamplerRequest(BaseModel):
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
examples=[1.0],
sample_validator=lambda value: value > 0.0 and value <= 1.0,
validation_error="Typical must be in (0, 1]",
)
penalty_range: Optional[int] = Field(
@ -145,26 +159,57 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("max_temp", 1.0),
validation_alias=AliasChoices("max_temp", "dynatemp_high"),
description="Aliases: dynatemp_high",
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Max temperature must be a non-negative value",
)
min_temp: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("min_temp", 1.0),
validation_alias=AliasChoices("min_temp", "dynatemp_low"),
description="Aliases: dynatemp_low",
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Min temperature must be a non-negative value",
)
temp_exponent: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
examples=[1.0],
sample_validator=lambda value: value >= 0.0,
validation_error="Temperature exponent must be a non-negative value",
)
def validate_params(self):
"""
Validates if the class field satisfies a condition if present.
Validators are present in the extras section of a Pydantic field
to make it easy for adding more samplers if needed.
"""
for field_name, field_info in self.model_fields.items():
extra_field_info = unwrap(field_info.json_schema_extra, {})
if not extra_field_info:
continue
sample_validator = extra_field_info.get("sample_validator")
validation_error = unwrap(extra_field_info.get("validation_error"), "")
if sample_validator:
value = getattr(self, field_name)
if not sample_validator(value):
raise ValueError(f"{validation_error}. Got {value}")
def to_gen_params(self, **kwargs):
"""Converts samplers to internal generation params"""
# Add forced overrides if present
apply_forced_sampler_overrides(self)
self.validate_params()
# Convert stop to an array of strings
if isinstance(self.stop, str):
self.stop = [self.stop]