API + Model: Add new parameters and clean up documentation
The example JSON fields broke with the new sampler default strategy; fix them by setting the example values manually. Also add support for fasttensors and expose generate_window to the API. It is recommended not to adjust generate_window, since it is dynamically scaled based on max_seq_len by default.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 90fb41a77a
commit fc4570220c
4 changed files with 45 additions and 10 deletions
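The recurring pattern in the diff below pairs a default_factory lookup (so the runtime default can come from the user's sampler overrides) with a static examples entry (so the generated JSON schema and API docs still show a concrete sample value). The following is a minimal, illustrative sketch of that pattern only, not the project file: get_default_sampler_value is stubbed here because its real implementation lives elsewhere in the project, and the override table is an assumption for illustration.

# Illustrative sketch -- shows how examples=[...] keeps the generated schema
# populated while default_factory supplies the actual runtime default.
from typing import Optional

from pydantic import BaseModel, Field

# Stand-in for the project's get_default_sampler_value helper (assumed to
# return a user override when present, otherwise the supplied fallback).
_SAMPLER_OVERRIDES = {}

def get_default_sampler_value(key, fallback=None):
    return _SAMPLER_OVERRIDES.get(key, fallback)

class SamplerParamsSketch(BaseModel):
    max_tokens: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("max_tokens", 150),
        examples=[150],
    )
    generate_window: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("generate_window"),
        examples=[512],
    )

# The examples land in each field's JSON schema even though the default is dynamic.
schema = SamplerParamsSketch.model_json_schema()
print(schema["properties"]["max_tokens"]["examples"])  # [150]
print(SamplerParamsSketch().generate_window)            # None unless overridden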
@@ -17,7 +17,13 @@ class SamplerParams(BaseModel):
     """Common class for sampler params that are used in APIs"""
 
     max_tokens: Optional[int] = Field(
-        default_factory=lambda: get_default_sampler_value("max_tokens", 150)
+        default_factory=lambda: get_default_sampler_value("max_tokens", 150),
+        examples=[150],
     )
 
+    generate_window: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("generate_window"),
+        examples=[512],
+    )
+
     stop: Optional[Union[str, List[str]]] = Field(
@@ -29,7 +35,8 @@ class SamplerParams(BaseModel):
     )
 
     temperature: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("temperature", 1.0)
+        default_factory=lambda: get_default_sampler_value("temperature", 1.0),
+        examples=[1.0],
     )
 
     temperature_last: Optional[bool] = Field(
@@ -41,7 +48,7 @@ class SamplerParams(BaseModel):
     )
 
     top_p: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("top_p", 1.0)
+        default_factory=lambda: get_default_sampler_value("top_p", 1.0), examples=[1.0]
     )
 
     top_a: Optional[float] = Field(
@@ -65,7 +72,8 @@ class SamplerParams(BaseModel):
     )
 
     repetition_penalty: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0)
+        default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
+        examples=[1.0],
     )
 
     repetition_decay: Optional[int] = Field(
@@ -77,11 +85,13 @@ class SamplerParams(BaseModel):
     )
 
     mirostat_tau: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5)
+        default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5),
+        examples=[1.5],
     )
 
     mirostat_eta: Optional[float] = Field(
-        default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3)
+        default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3),
+        examples=[0.3],
     )
 
     add_bos_token: Optional[bool] = Field(
@@ -89,7 +99,8 @@ class SamplerParams(BaseModel):
     )
 
     ban_eos_token: Optional[bool] = Field(
-        default_factory=lambda: get_default_sampler_value("ban_eos_token", False)
+        default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
+        examples=[False],
     )
 
     logit_bias: Optional[Dict[int, float]] = Field(
@@ -106,6 +117,7 @@ class SamplerParams(BaseModel):
         default_factory=lambda: get_default_sampler_value("typical", 1.0),
         validation_alias=AliasChoices("typical", "typical_p"),
         description="Aliases: typical_p",
+        examples=[1.0],
     )
 
     penalty_range: Optional[int] = Field(
@@ -122,6 +134,7 @@ class SamplerParams(BaseModel):
         default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
         validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
         description="Aliases: guidance_scale",
+        examples=[1.0],
     )
 
     def to_gen_params(self):
@@ -135,8 +148,9 @@ class SamplerParams(BaseModel):
             self.stop = [self.stop]
 
         return {
-            "stop": self.stop,
             "max_tokens": self.max_tokens,
+            "generate_window": self.generate_window,
+            "stop": self.stop,
             "add_bos_token": self.add_bos_token,
             "ban_eos_token": self.ban_eos_token,
             "token_healing": self.token_healing,
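For completeness, a hypothetical client-side call showing the newly exposed generate_window parameter. The route, port, and auth header are assumptions for illustration and are not taken from this diff; per the commit message, the parameter is normally best left unset so it can be scaled from max_seq_len automatically.

# Hypothetical request -- endpoint path, port, and auth are assumptions.
import requests

payload = {
    "prompt": "Hello",
    "max_tokens": 150,
    "generate_window": 512,  # normally omit this; it is derived from max_seq_len
}

resp = requests.post(
    "http://localhost:5000/v1/completions",  # assumed endpoint
    headers={"Authorization": "Bearer <api-key>"},
    json=payload,
    timeout=60,
)
print(resp.json())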