From 6bb5f8f599d617f94af85e0818c8a841fc0ed806 Mon Sep 17 00:00:00 2001 From: kingbri <8082010+kingbri1@users.noreply.github.com> Date: Wed, 16 Apr 2025 02:13:55 -0400 Subject: [PATCH] Sampling: Rewrite mirostat_mode parameter Apparently the "mirostat" parameter has been updated by frontends to pass a number. ExllamaV2 expects a boolean, but most pass a number anyway, so just alias mirostat_mode and mirostat together. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com> --- backends/exllamav2/model.py | 2 +- common/sampling.py | 14 ++------------ 2 files changed, 3 insertions(+), 13 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 4ffded7..3455455 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1091,7 +1091,7 @@ class ExllamaV2Container: gen_settings.min_p = unwrap(kwargs.get("min_p"), 0.0) gen_settings.tfs = unwrap(kwargs.get("tfs"), 1.0) gen_settings.typical = unwrap(kwargs.get("typical"), 1.0) - gen_settings.mirostat = unwrap(kwargs.get("mirostat"), False) + gen_settings.mirostat = unwrap(kwargs.get("mirostat_mode"), 0) == 2 gen_settings.skew = unwrap(kwargs.get("skew"), 0) # XTC diff --git a/common/sampling.py b/common/sampling.py index 7e5ded4..456c815 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -195,10 +195,9 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", []) ) - mirostat: Optional[bool] = False - mirostat_mode: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("mirostat_mode", 0) + default_factory=lambda: get_default_sampler_value("mirostat_mode", 0), + alias=AliasChoices("mirostat_mode", "mirostat"), ) mirostat_tau: Optional[float] = Field( @@ -325,15 +324,6 @@ class BaseSamplerRequest(BaseModel): ) return [] # Return empty list if parsing fails - @field_validator("mirostat_mode", mode="before") - def convert_mirostat(cls, v, field_info): - """Mirostat is enabled if mirostat_mode == 2.""" - - if v == 2: - field_info.data["mirostat"] = True - - return v - @model_validator(mode="after") def after_validate(self): # FIXME: find a better way to register this