API: Remove unncessary Optional signatures
Optional isn't necessary if the function signature has a default value. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
ae75db1829
commit
7556dcf134
5 changed files with 47 additions and 51 deletions
|
|
@ -23,90 +23,90 @@ class BaseSamplerRequest(BaseModel):
|
|||
examples=[512],
|
||||
)
|
||||
|
||||
stop: Optional[Union[str, List[str]]] = Field(
|
||||
stop: Union[str, List[str]] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("stop", [])
|
||||
)
|
||||
|
||||
token_healing: Optional[bool] = Field(
|
||||
token_healing: bool = Field(
|
||||
default_factory=lambda: get_default_sampler_value("token_healing", False)
|
||||
)
|
||||
|
||||
temperature: Optional[float] = Field(
|
||||
temperature: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("temperature", 1.0),
|
||||
examples=[1.0],
|
||||
)
|
||||
|
||||
temperature_last: Optional[bool] = Field(
|
||||
temperature_last: bool = Field(
|
||||
default_factory=lambda: get_default_sampler_value("temperature_last", False)
|
||||
)
|
||||
|
||||
smoothing_factor: Optional[float] = Field(
|
||||
smoothing_factor: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
|
||||
)
|
||||
|
||||
top_k: Optional[int] = Field(
|
||||
top_k: int = Field(
|
||||
default_factory=lambda: get_default_sampler_value("top_k", 0),
|
||||
)
|
||||
|
||||
top_p: Optional[float] = Field(
|
||||
top_p: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("top_p", 1.0),
|
||||
examples=[1.0],
|
||||
)
|
||||
|
||||
top_a: Optional[float] = Field(
|
||||
top_a: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("top_a", 0.0)
|
||||
)
|
||||
|
||||
min_p: Optional[float] = Field(
|
||||
min_p: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("min_p", 0.0)
|
||||
)
|
||||
|
||||
tfs: Optional[float] = Field(
|
||||
tfs: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("tfs", 1.0),
|
||||
examples=[1.0],
|
||||
)
|
||||
|
||||
frequency_penalty: Optional[float] = Field(
|
||||
frequency_penalty: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0)
|
||||
)
|
||||
|
||||
presence_penalty: Optional[float] = Field(
|
||||
presence_penalty: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0)
|
||||
)
|
||||
|
||||
repetition_penalty: Optional[float] = Field(
|
||||
repetition_penalty: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
|
||||
examples=[1.0],
|
||||
)
|
||||
|
||||
repetition_decay: Optional[int] = Field(
|
||||
repetition_decay: int = Field(
|
||||
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
|
||||
)
|
||||
|
||||
mirostat_mode: Optional[int] = Field(
|
||||
mirostat_mode: int = Field(
|
||||
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
|
||||
)
|
||||
|
||||
mirostat_tau: Optional[float] = Field(
|
||||
mirostat_tau: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5),
|
||||
examples=[1.5],
|
||||
)
|
||||
|
||||
mirostat_eta: Optional[float] = Field(
|
||||
mirostat_eta: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3),
|
||||
examples=[0.3],
|
||||
)
|
||||
|
||||
add_bos_token: Optional[bool] = Field(
|
||||
add_bos_token: bool = Field(
|
||||
default_factory=lambda: get_default_sampler_value("add_bos_token", True)
|
||||
)
|
||||
|
||||
ban_eos_token: Optional[bool] = Field(
|
||||
ban_eos_token: bool = Field(
|
||||
default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
|
||||
examples=[False],
|
||||
)
|
||||
|
||||
skip_special_tokens: Optional[bool] = Field(
|
||||
skip_special_tokens: bool = Field(
|
||||
default_factory=lambda: get_default_sampler_value("ban_eos_token", True),
|
||||
examples=[True],
|
||||
)
|
||||
|
|
@ -133,14 +133,14 @@ class BaseSamplerRequest(BaseModel):
|
|||
)
|
||||
|
||||
# Aliased variables
|
||||
typical: Optional[float] = Field(
|
||||
typical: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("typical", 1.0),
|
||||
validation_alias=AliasChoices("typical", "typical_p"),
|
||||
description="Aliases: typical_p",
|
||||
examples=[1.0],
|
||||
)
|
||||
|
||||
penalty_range: Optional[int] = Field(
|
||||
penalty_range: int = Field(
|
||||
default_factory=lambda: get_default_sampler_value("penalty_range", -1),
|
||||
validation_alias=AliasChoices(
|
||||
"penalty_range",
|
||||
|
|
@ -150,34 +150,34 @@ class BaseSamplerRequest(BaseModel):
|
|||
description="Aliases: repetition_range, repetition_penalty_range",
|
||||
)
|
||||
|
||||
cfg_scale: Optional[float] = Field(
|
||||
cfg_scale: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
|
||||
validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
|
||||
description="Aliases: guidance_scale",
|
||||
examples=[1.0],
|
||||
)
|
||||
|
||||
max_temp: Optional[float] = Field(
|
||||
max_temp: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("max_temp", 1.0),
|
||||
validation_alias=AliasChoices("max_temp", "dynatemp_high"),
|
||||
description="Aliases: dynatemp_high",
|
||||
examples=[1.0],
|
||||
)
|
||||
|
||||
min_temp: Optional[float] = Field(
|
||||
min_temp: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("min_temp", 1.0),
|
||||
validation_alias=AliasChoices("min_temp", "dynatemp_low"),
|
||||
description="Aliases: dynatemp_low",
|
||||
examples=[1.0],
|
||||
)
|
||||
|
||||
temp_exponent: Optional[float] = Field(
|
||||
temp_exponent: float = Field(
|
||||
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
|
||||
validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
|
||||
examples=[1.0],
|
||||
)
|
||||
|
||||
banned_tokens: Optional[Union[List[int], str]] = Field(
|
||||
banned_tokens: Union[List[int], str] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("banned_tokens", []),
|
||||
validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
|
||||
description="Aliases: custom_token_bans",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue