API + Model: Add new parameters and clean up documentation
The example JSON fields changed because of the new sampler default strategy; fix them by setting each field's example values manually. Also add support for fasttensors and expose generate_window to the API. It's recommended not to adjust generate_window, as it's dynamically scaled based on max_seq_len by default.

Signed-off-by: kingbri <bdashore3@proton.me>
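The scaling rule itself is not part of this diff, so as a rough illustration only: one plausible way a generate_window default could be derived from max_seq_len. The divisor and cap below are assumptions, not tabbyAPI's actual formula.

# Hypothetical sketch only: tabbyAPI's real scaling rule is not shown in
# this commit. Illustrates deriving generate_window from max_seq_len.
def default_generate_window(max_seq_len: int) -> int:
    # Assumption: reserve a quarter of the context, capped at 4096 tokens.
    return min(max_seq_len // 4, 4096)

print(default_generate_window(4096))   # 1024
print(default_generate_window(32768))  # 4096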
parent 90fb41a77a
commit fc4570220c

4 changed files with 45 additions and 10 deletions
@@ -138,13 +138,25 @@ class ExllamaV2Container:
             kwargs.get("rope_alpha"), self.calculate_rope_alpha(base_seq_len)
         )
 
         # Enable CFG if present
+        use_cfg = unwrap(kwargs.get("use_cfg"), False)
         if hasattr(ExLlamaV2Sampler.Settings, "cfg_scale"):
-            self.use_cfg = unwrap(kwargs.get("use_cfg"), False)
-        else:
+            self.use_cfg = use_cfg
+        elif use_cfg:
             logger.warning(
                 "CFG is not supported by the currently installed ExLlamaV2 version."
             )
+
+        # Enable fasttensors loading if present
+        use_fasttensors = unwrap(kwargs.get("fasttensors"), False)
+        if hasattr(ExLlamaV2Config, "fasttensors"):
+            self.config.fasttensors = use_fasttensors
+        elif use_fasttensors:
+            logger.warning(
+                "fasttensors is not supported by "
+                "the currently installed ExllamaV2 version."
+            )
+
         # Turn off flash attention if CFG is on
         # Workaround until batched FA2 is fixed in exllamav2 upstream
         self.config.no_flash_attn = (
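Both new blocks use the same version-compatibility pattern: probe the installed library with hasattr and only assign the attribute when it exists, warning the user otherwise instead of failing. A minimal self-contained sketch of that pattern, using a stand-in class rather than the real ExLlamaV2Config:

import logging

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)

class OldConfig:
    """Stand-in for a library config class from a release without the flag."""

config = OldConfig()
use_fasttensors = True

# Probe the class for the optional attribute before assigning it.
if hasattr(OldConfig, "fasttensors"):
    config.fasttensors = use_fasttensors
elif use_fasttensors:
    logger.warning("fasttensors is not supported by the installed version.")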
@@ -668,6 +680,7 @@ class ExllamaV2Container:
             **vars(gen_settings),
             token_healing=token_healing,
             auto_scale_penalty_range=auto_scale_penalty_range,
+            generate_window=generate_window,
             add_bos_token=add_bos_token,
             ban_eos_token=ban_eos_token,
             stop_conditions=stop_conditions,
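With generate_window now threaded through to the generator call, a client can send it as a regular sampler parameter. A hedged request sketch: the host, port, and OpenAI-style route are assumptions about a typical deployment, and per the commit message generate_window is usually best left unset so the dynamic default applies.

import requests

payload = {
    "prompt": "Once upon a time",
    "max_tokens": 150,
    "generate_window": 512,  # normally omit; the default scales with max_seq_len
}
# Assumed host/port and endpoint path; adjust for your deployment.
response = requests.post("http://localhost:5000/v1/completions", json=payload)
print(response.json())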
@@ -17,7 +17,13 @@ class SamplerParams(BaseModel):
     """Common class for sampler params that are used in APIs"""
 
     max_tokens: Optional[int] = Field(
-        default_factory=lambda: get_default_sampler_value("max_tokens", 150)
+        default_factory=lambda: get_default_sampler_value("max_tokens", 150),
+        examples=[150],
     )
 
+    generate_window: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("generate_window"),
+        examples=[512],
+    )
+
     stop: Optional[Union[str, List[str]]] = Field(
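The examples=[...] arguments exist because a default supplied via default_factory does not surface in the generated JSON schema on its own; Pydantic v2 emits Field examples into the schema instead, which is what the API docs pick up. A small sketch of that behavior (the field name here is just illustrative):

from typing import Optional

from pydantic import BaseModel, Field

class Demo(BaseModel):
    # Mirrors the pattern above: the default comes from a factory, so the
    # schema example is declared explicitly via examples=[...].
    max_tokens: Optional[int] = Field(default_factory=lambda: 150, examples=[150])

print(Demo.model_json_schema()["properties"]["max_tokens"])
# The printed property schema includes "examples": [150]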
@@ -29,7 +35,8 @@ class SamplerParams(BaseModel):
     )
 
     temperature: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("temperature", 1.0)
+        default_factory=lambda: get_default_sampler_value("temperature", 1.0),
+        examples=[1.0],
     )
 
     temperature_last: Optional[bool] = Field(
@@ -41,7 +48,7 @@ class SamplerParams(BaseModel):
     )
 
     top_p: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("top_p", 1.0)
+        default_factory=lambda: get_default_sampler_value("top_p", 1.0), examples=[1.0]
     )
 
     top_a: Optional[float] = Field(
@@ -65,7 +72,8 @@ class SamplerParams(BaseModel):
     )
 
     repetition_penalty: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0)
+        default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
+        examples=[1.0],
     )
 
     repetition_decay: Optional[int] = Field(
@@ -77,11 +85,13 @@ class SamplerParams(BaseModel):
     )
 
     mirostat_tau: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5)
+        default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5),
+        examples=[1.5],
     )
 
     mirostat_eta: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3)
+        default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3),
+        examples=[0.3],
     )
 
     add_bos_token: Optional[bool] = Field(
@@ -89,7 +99,8 @@ class SamplerParams(BaseModel):
     )
 
     ban_eos_token: Optional[bool] = Field(
-        default_factory=lambda: get_default_sampler_value("ban_eos_token", False)
+        default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
+        examples=[False],
     )
 
     logit_bias: Optional[Dict[int, float]] = Field(
@@ -106,6 +117,7 @@ class SamplerParams(BaseModel):
         default_factory=lambda: get_default_sampler_value("typical", 1.0),
         validation_alias=AliasChoices("typical", "typical_p"),
         description="Aliases: typical_p",
+        examples=[1.0],
     )
 
     penalty_range: Optional[int] = Field(
@@ -122,6 +134,7 @@ class SamplerParams(BaseModel):
         default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
         validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
         description="Aliases: guidance_scale",
+        examples=[1.0],
     )
 
     def to_gen_params(self):
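For the aliased fields above, AliasChoices lets either key validate into the same field, which is why the example and the "Aliases:" description sit on the canonical name only. A compact sketch of that Pydantic v2 behavior:

from typing import Optional

from pydantic import AliasChoices, BaseModel, Field

class Demo(BaseModel):
    cfg_scale: Optional[float] = Field(
        default=1.0,
        validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
    )

print(Demo.model_validate({"guidance_scale": 2.0}).cfg_scale)  # 2.0
print(Demo.model_validate({"cfg_scale": 1.5}).cfg_scale)       # 1.5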
@@ -135,8 +148,9 @@ class SamplerParams(BaseModel):
             self.stop = [self.stop]
 
         return {
-            "stop": self.stop,
             "max_tokens": self.max_tokens,
+            "generate_window": self.generate_window,
+            "stop": self.stop,
             "add_bos_token": self.add_bos_token,
             "ban_eos_token": self.ban_eos_token,
             "token_healing": self.token_healing,
@@ -97,6 +97,9 @@ model:
   # WARNING: This flag disables Flash Attention! (a stopgap fix until it's fixed in upstream)
   #use_cfg: False
 
+  # Enables fasttensors to possibly increase model loading speeds (default: False)
+  #fasttensors: true
+
   # Options for draft models (speculative decoding). This will use more VRAM!
   #draft:
     # Overrides the directory to look for draft (default: models)
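How the new config key reaches the loader isn't shown in this diff; a hedged sketch of the plumbing, where the YAML parsing and the unwrap helper are assumptions mirroring the unwrap(kwargs.get(...), default) pattern visible in the model code above:

import yaml

def unwrap(value, default):
    # Assumed helper: mirrors the unwrap(value, default) pattern in the loader.
    return value if value is not None else default

raw = yaml.safe_load("""
model:
  fasttensors: true
""")

model_kwargs = raw.get("model", {})
print(unwrap(model_kwargs.get("fasttensors"), False))  # True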
@@ -18,6 +18,11 @@ token_healing:
   override: false
   force: false
 
+# Commented out because the default is dynamically scaled
+#generate_window:
+  #override: 512
+  #force: false
+
 # MARK: Temperature
 temperature:
   override: 1.0
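The override/force semantics themselves are outside this diff; a hypothetical sketch of how such a preset entry could resolve against a user-supplied value (the function name and resolution logic are illustrative assumptions, not tabbyAPI code):

def resolve_sampler_value(user_value, override, force=False):
    # Hypothetical: 'force' always wins; otherwise the override only
    # fills in a missing user value.
    if force or user_value is None:
        return override
    return user_value

print(resolve_sampler_value(None, 512))       # 512 (fills the missing value)
print(resolve_sampler_value(256, 512))        # 256 (user value kept)
print(resolve_sampler_value(256, 512, True))  # 512 (forced)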