diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 4ffded7..3455455 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1091,7 +1091,7 @@ class ExllamaV2Container: gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0) gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0) gen_settings.typical = unwrap(kwargs.get("typical"), 1.0) - gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False) + gen_settings.mirostat = unwrap(kwargs.get("mirostat_mode"), 0) == 2 gen_settings.skew = unwrap(kwargs.get("skew"), 0) # XTC diff --git a/common/sampling.py b/common/sampling.py index 7e5ded4..456c815 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -195,10 +195,9 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", []) ) - mirostat: Optional[bool] = False - mirostat_mode: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("mirostat_mode", 0) + default_factory=lambda: get_default_sampler_value("mirostat_mode", 0), + alias=AliasChoices("mirostat_mode", "mirostat"), ) mirostat_tau: Optional[float] = Field( @@ -325,15 +324,6 @@ class BaseSamplerRequest(BaseModel): ) return [] # Return empty list if parsing fails - @field_validator("mirostat_mode", mode="before") - def convert_mirostat(cls, v, field_info): - """Mirostat is enabled if mirostat_mode == 2.""" - - if v == 2: - field_info.data["mirostat"] = True - - return v - @model_validator(mode="after") def after_validate(self): # FIXME: find a better way to register this