From fc4570220cf0e5b42d5f5c999a7c0895228c750a Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 25 Jan 2024 00:11:30 -0500 Subject: [PATCH] API + Model: Add new parameters and clean up documentation The example JSON fields were changed because of the new sampler default strategy. Fix these by manually changing the values. Also add support for fasttensors and expose generate_window to the API. It's recommended to not adjust generate_window as it's dynamically scaled based on max_seq_len by default. Signed-off-by: kingbri --- backends/exllamav2/model.py | 17 ++++++++++++++-- common/sampling.py | 30 +++++++++++++++++++++-------- config_sample.yml | 3 +++ sampler_overrides/sample_preset.yml | 5 +++++ 4 files changed, 45 insertions(+), 10 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index ac939d4..52764e2 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -138,13 +138,25 @@ class ExllamaV2Container: kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len) ) + # Enable CFG if present + use_cfg = unwrap(kwargs.get("use_cfg"), False) if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"): - self.use_cfg = unwrap(kwargs.get("use_cfg"), False) - else: + self.use_cfg = use_cfg + elif use_cfg: logger.warning( "CFG is not supported by the currently installed ExLlamaV2 version." ) + # Enable fasttensors loading if present + use_fasttensors = unwrap(kwargs.get("fasttensors"), False) + if hasattr(ExLlamaV2Config, "fasttensors"): + self.config.fasttensors = use_fasttensors + elif use_fasttensors: + logger.warning( + "fasttensors is not supported by " + "the currently installed ExllamaV2 version." + ) + # Turn off flash attention if CFG is on # Workaround until batched FA2 is fixed in exllamav2 upstream self.config.no_flash_attn = ( @@ -668,6 +680,7 @@ class ExllamaV2Container: **vars(gen_settings), token_healing=token_healing, auto_scale_penalty_range=auto_scale_penalty_range, + generate_window=generate_window, add_bos_token=add_bos_token, ban_eos_token=ban_eos_token, stop_conditions=stop_conditions, diff --git a/common/sampling.py b/common/sampling.py index 53defcc..8c28002 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -17,7 +17,13 @@ class SamplerParams(BaseModel): """Common class for sampler params that are used in APIs""" max_tokens: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("max_tokens", 150) + default_factory=lambda: get_default_sampler_value("max_tokens", 150), + examples=[150], + ) + + generate_window: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("generate_window"), + examples=[512], ) stop: Optional[Union[str, List[str]]] = Field( @@ -29,7 +35,8 @@ class SamplerParams(BaseModel): ) temperature: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("temperature", 1.0) + default_factory=lambda: get_default_sampler_value("temperature", 1.0), + examples=[1.0], ) temperature_last: Optional[bool] = Field( @@ -41,7 +48,7 @@ class SamplerParams(BaseModel): ) top_p: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("top_p", 1.0) + default_factory=lambda: get_default_sampler_value("top_p", 1.0), examples=[1.0] ) top_a: Optional[float] = Field( @@ -65,7 +72,8 @@ class SamplerParams(BaseModel): ) repetition_penalty: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0) + default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0), + examples=[1.0], ) repetition_decay: Optional[int] = Field( @@ -77,11 +85,13 @@ class SamplerParams(BaseModel): ) mirostat_tau: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5) + default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5), + examples=[1.5], ) mirostat_eta: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3) + default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3), + examples=[0.3], ) add_bos_token: Optional[bool] = Field( @@ -89,7 +99,8 @@ class SamplerParams(BaseModel): ) ban_eos_token: Optional[bool] = Field( - default_factory=lambda: get_default_sampler_value("ban_eos_token", False) + default_factory=lambda: get_default_sampler_value("ban_eos_token", False), + examples=[False], ) logit_bias: Optional[Dict[int, float]] = Field( @@ -106,6 +117,7 @@ class SamplerParams(BaseModel): default_factory=lambda: get_default_sampler_value("typical", 1.0), validation_alias=AliasChoices("typical", "typical_p"), description="Aliases: typical_p", + examples=[1.0], ) penalty_range: Optional[int] = Field( @@ -122,6 +134,7 @@ class SamplerParams(BaseModel): default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0), validation_alias=AliasChoices("cfg_scale", "guidance_scale"), description="Aliases: guidance_scale", + examples=[1.0], ) def to_gen_params(self): @@ -135,8 +148,9 @@ class SamplerParams(BaseModel): self.stop = [self.stop] return { - "stop": self.stop, "max_tokens": self.max_tokens, + "generate_window": self.generate_window, + "stop": self.stop, "add_bos_token": self.add_bos_token, "ban_eos_token": self.ban_eos_token, "token_healing": self.token_healing, diff --git a/config_sample.yml b/config_sample.yml index 89368ac..cf1ddb5 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -97,6 +97,9 @@ model: # WARNING: This flag disables Flash Attention! (a stopgap fix until it's fixed in upstream) #use_cfg: False + # Enables fasttensors to possibly increase model loading speeds (default: False) + #fasttensors: true + # Options for draft models (speculative decoding). This will use more VRAM! #draft: # Overrides the directory to look for draft (default: models) diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml index 9c661a1..eae17ab 100644 --- a/sampler_overrides/sample_preset.yml +++ b/sampler_overrides/sample_preset.yml @@ -18,6 +18,11 @@ token_healing: override: false force: false +# Commented out because the default is dynamically scaled +#generate_window: + #override: 512 + #force: false + # MARK: Temperature temperature: override: 1.0