API: Remove unncessary Optional signatures

Optional isn't necessary if the function signature has a default
value.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-04-30 23:51:28 -04:00
parent ae75db1829
commit 7556dcf134
5 changed files with 47 additions and 51 deletions

View file

@ -23,90 +23,90 @@ class BaseSamplerRequest(BaseModel):
examples=[512],
)
stop: Optional[Union[str, List[str]]] = Field(
stop: Union[str, List[str]] = Field(
default_factory=lambda: get_default_sampler_value("stop", [])
)
token_healing: Optional[bool] = Field(
token_healing: bool = Field(
default_factory=lambda: get_default_sampler_value("token_healing", False)
)
temperature: Optional[float] = Field(
temperature: float = Field(
default_factory=lambda: get_default_sampler_value("temperature", 1.0),
examples=[1.0],
)
temperature_last: Optional[bool] = Field(
temperature_last: bool = Field(
default_factory=lambda: get_default_sampler_value("temperature_last", False)
)
smoothing_factor: Optional[float] = Field(
smoothing_factor: float = Field(
default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
)
top_k: Optional[int] = Field(
top_k: int = Field(
default_factory=lambda: get_default_sampler_value("top_k", 0),
)
top_p: Optional[float] = Field(
top_p: float = Field(
default_factory=lambda: get_default_sampler_value("top_p", 1.0),
examples=[1.0],
)
top_a: Optional[float] = Field(
top_a: float = Field(
default_factory=lambda: get_default_sampler_value("top_a", 0.0)
)
min_p: Optional[float] = Field(
min_p: float = Field(
default_factory=lambda: get_default_sampler_value("min_p", 0.0)
)
tfs: Optional[float] = Field(
tfs: float = Field(
default_factory=lambda: get_default_sampler_value("tfs", 1.0),
examples=[1.0],
)
frequency_penalty: Optional[float] = Field(
frequency_penalty: float = Field(
default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0)
)
presence_penalty: Optional[float] = Field(
presence_penalty: float = Field(
default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0)
)
repetition_penalty: Optional[float] = Field(
repetition_penalty: float = Field(
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
examples=[1.0],
)
repetition_decay: Optional[int] = Field(
repetition_decay: int = Field(
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
)
mirostat_mode: Optional[int] = Field(
mirostat_mode: int = Field(
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
)
mirostat_tau: Optional[float] = Field(
mirostat_tau: float = Field(
default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5),
examples=[1.5],
)
mirostat_eta: Optional[float] = Field(
mirostat_eta: float = Field(
default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3),
examples=[0.3],
)
add_bos_token: Optional[bool] = Field(
add_bos_token: bool = Field(
default_factory=lambda: get_default_sampler_value("add_bos_token", True)
)
ban_eos_token: Optional[bool] = Field(
ban_eos_token: bool = Field(
default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
examples=[False],
)
skip_special_tokens: Optional[bool] = Field(
skip_special_tokens: bool = Field(
default_factory=lambda: get_default_sampler_value("ban_eos_token", True),
examples=[True],
)
@ -133,14 +133,14 @@ class BaseSamplerRequest(BaseModel):
)
# Aliased variables
typical: Optional[float] = Field(
typical: float = Field(
default_factory=lambda: get_default_sampler_value("typical", 1.0),
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
examples=[1.0],
)
penalty_range: Optional[int] = Field(
penalty_range: int = Field(
default_factory=lambda: get_default_sampler_value("penalty_range", -1),
validation_alias=AliasChoices(
"penalty_range",
@ -150,34 +150,34 @@ class BaseSamplerRequest(BaseModel):
description="Aliases: repetition_range, repetition_penalty_range",
)
cfg_scale: Optional[float] = Field(
cfg_scale: float = Field(
default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
description="Aliases: guidance_scale",
examples=[1.0],
)
max_temp: Optional[float] = Field(
max_temp: float = Field(
default_factory=lambda: get_default_sampler_value("max_temp", 1.0),
validation_alias=AliasChoices("max_temp", "dynatemp_high"),
description="Aliases: dynatemp_high",
examples=[1.0],
)
min_temp: Optional[float] = Field(
min_temp: float = Field(
default_factory=lambda: get_default_sampler_value("min_temp", 1.0),
validation_alias=AliasChoices("min_temp", "dynatemp_low"),
description="Aliases: dynatemp_low",
examples=[1.0],
)
temp_exponent: Optional[float] = Field(
temp_exponent: float = Field(
default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
examples=[1.0],
)
banned_tokens: Optional[Union[List[int], str]] = Field(
banned_tokens: Union[List[int], str] = Field(
default_factory=lambda: get_default_sampler_value("banned_tokens", []),
validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
description="Aliases: custom_token_bans",