diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 6d69f63..f69149a 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1083,7 +1083,7 @@ class ExllamaV2Container: gen_settings.min_p = params.min_p gen_settings.tfs = params.tfs gen_settings.typical = params.typical - gen_settings.mirostat = params.mirostat + gen_settings.mirostat = params.mirostat_mode == 2 gen_settings.skew = params.skew # XTC diff --git a/common/sampling.py b/common/sampling.py index 1b4bc69..5eb2dc8 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -195,10 +195,9 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", []) ) - mirostat: Optional[bool] = False - mirostat_mode: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("mirostat_mode", 0) + default_factory=lambda: get_default_sampler_value("mirostat_mode", 0), + alias=AliasChoices("mirostat_mode", "mirostat"), ) mirostat_tau: Optional[float] = Field( @@ -330,15 +329,6 @@ class BaseSamplerRequest(BaseModel): ) return [] # Return empty list if parsing fails - @field_validator("mirostat_mode", mode="before") - def convert_mirostat(cls, v, field_info): - """Mirostat is enabled if mirostat_mode == 2.""" - - if v == 2: - field_info.data["mirostat"] = True - - return v - @model_validator(mode="after") def after_validate(self): # FIXME: find a better way to register this