API: Remove unnecessary Optional signatures

Optional isn't necessary if the function signature has a default
value.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-04-30 23:51:28 -04:00
parent ae75db1829
commit 7556dcf134
5 changed files with 47 additions and 51 deletions
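
Since the commit message above is the entire rationale, here is a minimal sketch of the pattern the commit applies. It is not taken from this repository: the DemoSamplerRequest model and its fields are invented for illustration. In Pydantic, a field with a default (or default_factory) is already optional for callers; Optional[...] only widens the annotation so that None is accepted as a value, which these sampler fields never want.

from typing import Optional
from pydantic import BaseModel, Field

class DemoSamplerRequest(BaseModel):
    # A default (or default_factory) already makes the field optional to callers,
    # so the plain annotation is enough when None is never a legal value.
    temperature: float = Field(default_factory=lambda: 1.0)
    top_k: int = 0

    # Keep Optional only where None is meaningful ("not provided"), as the diff
    # below still does for fields like prompt_template.
    prompt_template: Optional[str] = None

print(DemoSamplerRequest())           # temperature=1.0 top_k=0 prompt_template=None
print(DemoSamplerRequest(top_k=50))   # omitted fields still fall back to their defaults

One side effect worth noting: with the plain annotations, explicitly sending null for one of these fields is rejected by validation instead of silently accepted, which fits the intent that every sampler value is concrete.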

@@ -23,90 +23,90 @@ class BaseSamplerRequest(BaseModel):
         examples=[512],
     )
-    stop: Optional[Union[str, List[str]]] = Field(
+    stop: Union[str, List[str]] = Field(
         default_factory=lambda: get_default_sampler_value("stop", [])
     )
-    token_healing: Optional[bool] = Field(
+    token_healing: bool = Field(
         default_factory=lambda: get_default_sampler_value("token_healing", False)
     )
-    temperature: Optional[float] = Field(
+    temperature: float = Field(
         default_factory=lambda: get_default_sampler_value("temperature", 1.0),
         examples=[1.0],
     )
-    temperature_last: Optional[bool] = Field(
+    temperature_last: bool = Field(
         default_factory=lambda: get_default_sampler_value("temperature_last", False)
     )
-    smoothing_factor: Optional[float] = Field(
+    smoothing_factor: float = Field(
         default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
     )
-    top_k: Optional[int] = Field(
+    top_k: int = Field(
         default_factory=lambda: get_default_sampler_value("top_k", 0),
     )
-    top_p: Optional[float] = Field(
+    top_p: float = Field(
         default_factory=lambda: get_default_sampler_value("top_p", 1.0),
         examples=[1.0],
     )
-    top_a: Optional[float] = Field(
+    top_a: float = Field(
         default_factory=lambda: get_default_sampler_value("top_a", 0.0)
     )
-    min_p: Optional[float] = Field(
+    min_p: float = Field(
         default_factory=lambda: get_default_sampler_value("min_p", 0.0)
     )
-    tfs: Optional[float] = Field(
+    tfs: float = Field(
         default_factory=lambda: get_default_sampler_value("tfs", 1.0),
         examples=[1.0],
     )
-    frequency_penalty: Optional[float] = Field(
+    frequency_penalty: float = Field(
         default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0)
     )
-    presence_penalty: Optional[float] = Field(
+    presence_penalty: float = Field(
         default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0)
     )
-    repetition_penalty: Optional[float] = Field(
+    repetition_penalty: float = Field(
         default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
         examples=[1.0],
     )
-    repetition_decay: Optional[int] = Field(
+    repetition_decay: int = Field(
         default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
     )
-    mirostat_mode: Optional[int] = Field(
+    mirostat_mode: int = Field(
         default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
     )
-    mirostat_tau: Optional[float] = Field(
+    mirostat_tau: float = Field(
         default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5),
         examples=[1.5],
     )
-    mirostat_eta: Optional[float] = Field(
+    mirostat_eta: float = Field(
         default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3),
         examples=[0.3],
     )
-    add_bos_token: Optional[bool] = Field(
+    add_bos_token: bool = Field(
         default_factory=lambda: get_default_sampler_value("add_bos_token", True)
     )
-    ban_eos_token: Optional[bool] = Field(
+    ban_eos_token: bool = Field(
         default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
         examples=[False],
     )
-    skip_special_tokens: Optional[bool] = Field(
+    skip_special_tokens: bool = Field(
         default_factory=lambda: get_default_sampler_value("ban_eos_token", True),
         examples=[True],
     )
@@ -133,14 +133,14 @@ class BaseSamplerRequest(BaseModel):
     )
     # Aliased variables
-    typical: Optional[float] = Field(
+    typical: float = Field(
         default_factory=lambda: get_default_sampler_value("typical", 1.0),
         validation_alias=AliasChoices("typical", "typical_p"),
         description="Aliases: typical_p",
         examples=[1.0],
     )
-    penalty_range: Optional[int] = Field(
+    penalty_range: int = Field(
         default_factory=lambda: get_default_sampler_value("penalty_range", -1),
         validation_alias=AliasChoices(
             "penalty_range",
@@ -150,34 +150,34 @@ class BaseSamplerRequest(BaseModel):
         description="Aliases: repetition_range, repetition_penalty_range",
     )
-    cfg_scale: Optional[float] = Field(
+    cfg_scale: float = Field(
         default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
         validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
         description="Aliases: guidance_scale",
         examples=[1.0],
     )
-    max_temp: Optional[float] = Field(
+    max_temp: float = Field(
         default_factory=lambda: get_default_sampler_value("max_temp", 1.0),
         validation_alias=AliasChoices("max_temp", "dynatemp_high"),
         description="Aliases: dynatemp_high",
         examples=[1.0],
     )
-    min_temp: Optional[float] = Field(
+    min_temp: float = Field(
         default_factory=lambda: get_default_sampler_value("min_temp", 1.0),
         validation_alias=AliasChoices("min_temp", "dynatemp_low"),
         description="Aliases: dynatemp_low",
         examples=[1.0],
     )
-    temp_exponent: Optional[float] = Field(
+    temp_exponent: float = Field(
         default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
         validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
         examples=[1.0],
     )
-    banned_tokens: Optional[Union[List[int], str]] = Field(
+    banned_tokens: Union[List[int], str] = Field(
         default_factory=lambda: get_default_sampler_value("banned_tokens", []),
         validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
         description="Aliases: custom_token_bans",

@@ -43,8 +43,8 @@ class ChatCompletionRequest(CommonCompletionRequest):
     # Take in a string as well even though it's not part of the OAI spec
     messages: Union[str, List[Dict[str, str]]]
     prompt_template: Optional[str] = None
-    add_generation_prompt: Optional[bool] = True
-    template_vars: Optional[dict] = {}
+    add_generation_prompt: bool = True
+    template_vars: dict = {}
     response_prefix: Optional[str] = None

@@ -26,8 +26,8 @@ class CommonCompletionRequest(BaseSamplerRequest):
     model: Optional[str] = None
     # Generation info (remainder is in BaseSamplerRequest superclass)
-    stream: Optional[bool] = False
-    logprobs: Optional[int] = 0
+    stream: bool = False
+    logprobs: int = 0
     response_format: Optional[CompletionResponseFormat] = Field(
         default_factory=CompletionResponseFormat
     )
@@ -36,12 +36,10 @@ class CommonCompletionRequest(BaseSamplerRequest):
     best_of: Optional[int] = Field(
         description="Not parsed. Only used for OAI compliance.", default=None
     )
-    echo: Optional[bool] = Field(
+    echo: bool = Field(
         description="Not parsed. Only used for OAI compliance.", default=False
     )
-    n: Optional[int] = Field(
-        description="Not parsed. Only used for OAI compliance.", default=1
-    )
+    n: int = Field(description="Not parsed. Only used for OAI compliance.", default=1)
     suffix: Optional[str] = Field(
         description="Not parsed. Only used for OAI compliance.", default=None
     )

@@ -26,7 +26,7 @@ class LoraLoadInfo(BaseModel):
     """Represents a single Lora load info."""
     name: str
-    scaling: Optional[float] = 1.0
+    scaling: float = 1.0
 class LoraLoadRequest(BaseModel):

@@ -13,10 +13,10 @@ class ModelCardParameters(BaseModel):
     # Safe to do this since it's guaranteed to fetch a max seq len
     # from model_container
     max_seq_len: Optional[int] = None
-    rope_scale: Optional[float] = 1.0
-    rope_alpha: Optional[float] = 1.0
-    cache_mode: Optional[str] = "FP16"
-    chunk_size: Optional[int] = 2048
+    rope_scale: float = 1.0
+    rope_alpha: float = 1.0
+    cache_mode: str = "FP16"
+    chunk_size: int = 2048
     prompt_template: Optional[str] = None
     num_experts_per_token: Optional[int] = None
     use_cfg: Optional[bool] = None
@@ -47,7 +47,7 @@ class DraftModelLoadRequest(BaseModel):
     """Represents a draft model load request."""
     draft_model_name: str
-    draft_rope_scale: Optional[float] = 1.0
+    draft_rope_scale: float = 1.0
     draft_rope_alpha: Optional[float] = Field(
         description="Automatically calculated if not present",
         default=None,
@@ -73,11 +73,9 @@ class ModelLoadRequest(BaseModel):
         default=None,
         examples=[4096],
     )
-    gpu_split_auto: Optional[bool] = True
-    autosplit_reserve: Optional[List[float]] = [96]
-    gpu_split: Optional[List[float]] = Field(
-        default_factory=list, examples=[[24.0, 20.0]]
-    )
+    gpu_split_auto: bool = True
+    autosplit_reserve: List[float] = [96]
+    gpu_split: List[float] = Field(default_factory=list, examples=[[24.0, 20.0]])
     rope_scale: Optional[float] = Field(
         description="Automatically pulled from the model's config if not present",
         default=None,
@@ -88,16 +86,16 @@ class ModelLoadRequest(BaseModel):
         default=None,
         examples=[1.0],
     )
-    no_flash_attention: Optional[bool] = False
+    no_flash_attention: bool = False
     # low_mem: Optional[bool] = False
-    cache_mode: Optional[str] = "FP16"
-    chunk_size: Optional[int] = 2048
+    cache_mode: str = "FP16"
+    chunk_size: int = 2048
     prompt_template: Optional[str] = None
     num_experts_per_token: Optional[int] = None
     use_cfg: Optional[bool] = None
-    fasttensors: Optional[bool] = False
+    fasttensors: bool = False
     draft: Optional[DraftModelLoadRequest] = None
-    skip_queue: Optional[bool] = False
+    skip_queue: bool = False
 class ModelLoadResponse(BaseModel):