Sampling: Add universal validation system
Rather than maintaining yet another function to validate sampler ranges/values, embed them in fields which allows for less maintainence in the future. Also add validation for existing samplers that can corrupt the sampling stack if set improperly. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
9f1d891490
commit
7e730e3507
1 changed files with 48 additions and 3 deletions
|
|
@ -37,6 +37,8 @@ class BaseSamplerRequest(BaseModel):
|
|||
temperature: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("temperature", 1.0),
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value >= 0.0,
|
||||
validation_error="Temperature must be a non-negative value",
|
||||
)
|
||||
|
||||
temperature_last: Optional[bool] = Field(
|
||||
|
|
@ -45,14 +47,21 @@ class BaseSamplerRequest(BaseModel):
|
|||
|
||||
smoothing_factor: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
|
||||
sample_validator=lambda value: value >= 0.0,
|
||||
validation_error="Smoothing factor must be a non-negative value",
|
||||
)
|
||||
|
||||
top_k: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("top_k", 0)
|
||||
default_factory=lambda: get_default_sampler_value("top_k", 0),
|
||||
sample_validator=lambda value: value >= 0,
|
||||
validation_error="Top K must be a non-negative value",
|
||||
)
|
||||
|
||||
top_p: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("top_p", 1.0), examples=[1.0]
|
||||
default_factory=lambda: get_default_sampler_value("top_p", 1.0),
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value >= 0.0 and value <= 1.0,
|
||||
validation_error="Top P must be in [0, 1]",
|
||||
)
|
||||
|
||||
top_a: Optional[float] = Field(
|
||||
|
|
@ -64,7 +73,8 @@ class BaseSamplerRequest(BaseModel):
|
|||
)
|
||||
|
||||
tfs: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("tfs", 1.0)
|
||||
default_factory=lambda: get_default_sampler_value("tfs", 1.0),
|
||||
examples=[1.0],
|
||||
)
|
||||
|
||||
frequency_penalty: Optional[float] = Field(
|
||||
|
|
@ -78,6 +88,8 @@ class BaseSamplerRequest(BaseModel):
|
|||
repetition_penalty: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value > 0.0,
|
||||
validation_error="Repetition penalty must be a positive value",
|
||||
)
|
||||
|
||||
repetition_decay: Optional[int] = Field(
|
||||
|
|
@ -122,6 +134,8 @@ class BaseSamplerRequest(BaseModel):
|
|||
validation_alias=AliasChoices("typical", "typical_p"),
|
||||
description="Aliases: typical_p",
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value > 0.0 and value <= 1.0,
|
||||
validation_error="Typical must be in (0, 1]",
|
||||
)
|
||||
|
||||
penalty_range: Optional[int] = Field(
|
||||
|
|
@ -145,26 +159,57 @@ class BaseSamplerRequest(BaseModel):
|
|||
default_factory=lambda: get_default_sampler_value("max_temp", 1.0),
|
||||
validation_alias=AliasChoices("max_temp", "dynatemp_high"),
|
||||
description="Aliases: dynatemp_high",
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value >= 0.0,
|
||||
validation_error="Max temperature must be a non-negative value",
|
||||
)
|
||||
|
||||
min_temp: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("min_temp", 1.0),
|
||||
validation_alias=AliasChoices("min_temp", "dynatemp_low"),
|
||||
description="Aliases: dynatemp_low",
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value >= 0.0,
|
||||
validation_error="Min temperature must be a non-negative value",
|
||||
)
|
||||
|
||||
temp_exponent: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
|
||||
validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
|
||||
examples=[1.0],
|
||||
sample_validator=lambda value: value >= 0.0,
|
||||
validation_error="Temperature exponent must be a non-negative value",
|
||||
)
|
||||
|
||||
def validate_params(self):
|
||||
"""
|
||||
Validates if the class field satisfies a condition if present.
|
||||
|
||||
Validators are present in the extras section of a Pydantic field
|
||||
to make it easy for adding more samplers if needed.
|
||||
"""
|
||||
|
||||
for field_name, field_info in self.model_fields.items():
|
||||
extra_field_info = unwrap(field_info.json_schema_extra, {})
|
||||
if not extra_field_info:
|
||||
continue
|
||||
|
||||
sample_validator = extra_field_info.get("sample_validator")
|
||||
validation_error = unwrap(extra_field_info.get("validation_error"), "")
|
||||
|
||||
if sample_validator:
|
||||
value = getattr(self, field_name)
|
||||
if not sample_validator(value):
|
||||
raise ValueError(f"{validation_error}. Got {value}")
|
||||
|
||||
def to_gen_params(self, **kwargs):
|
||||
"""Converts samplers to internal generation params"""
|
||||
|
||||
# Add forced overrides if present
|
||||
apply_forced_sampler_overrides(self)
|
||||
|
||||
self.validate_params()
|
||||
|
||||
# Convert stop to an array of strings
|
||||
if isinstance(self.stop, str):
|
||||
self.stop = [self.stop]
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue