Sampling: Reorder aliased params and add kobold aliases
Also add dynatemp range which is an alternative way of calculating min and max temp. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
7522b1447b
commit
ea80b62e30
2 changed files with 39 additions and 31 deletions
|
|
@ -34,13 +34,22 @@ class BaseSamplerRequest(BaseModel):
|
|||
)
|
||||
|
||||
stop: Optional[Union[str, List[str]]] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("stop", [])
|
||||
default_factory=lambda: get_default_sampler_value("stop", []),
|
||||
validation_alias=AliasChoices("stop", "stop_sequence"),
|
||||
description="Aliases: stop_sequence",
|
||||
)
|
||||
|
||||
banned_strings: Optional[Union[str, List[str]]] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("banned_strings", [])
|
||||
)
|
||||
|
||||
banned_tokens: Optional[Union[List[int], str]] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("banned_tokens", []),
|
||||
validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
|
||||
description="Aliases: custom_token_bans",
|
||||
examples=[[128, 330]],
|
||||
)
|
||||
|
||||
token_healing: Optional[bool] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("token_healing", False)
|
||||
)
|
||||
|
|
@ -80,6 +89,13 @@ class BaseSamplerRequest(BaseModel):
|
|||
examples=[1.0],
|
||||
)
|
||||
|
||||
typical: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("typical", 1.0),
|
||||
validation_alias=AliasChoices("typical", "typical_p"),
|
||||
description="Aliases: typical_p",
|
||||
examples=[1.0],
|
||||
)
|
||||
|
||||
skew: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("skew", 0.0),
|
||||
examples=[0.0],
|
||||
|
|
@ -100,6 +116,20 @@ class BaseSamplerRequest(BaseModel):
|
|||
examples=[1.0],
|
||||
)
|
||||
|
||||
penalty_range: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("penalty_range", -1),
|
||||
validation_alias=AliasChoices(
|
||||
"penalty_range",
|
||||
"repetition_range",
|
||||
"repetition_penalty_range",
|
||||
"rep_pen_range",
|
||||
),
|
||||
description=(
|
||||
"Aliases: repetition_range, repetition_penalty_range, "
|
||||
"rep_pen_range"
|
||||
),
|
||||
)
|
||||
|
||||
repetition_decay: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
|
||||
)
|
||||
|
|
@ -159,28 +189,6 @@ class BaseSamplerRequest(BaseModel):
|
|||
default_factory=lambda: get_default_sampler_value("speculative_ngram"),
|
||||
)
|
||||
|
||||
# Aliased variables
|
||||
typical: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("typical", 1.0),
|
||||
validation_alias=AliasChoices("typical", "typical_p"),
|
||||
description="Aliases: typical_p",
|
||||
examples=[1.0],
|
||||
)
|
||||
|
||||
penalty_range: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("penalty_range", -1),
|
||||
validation_alias=AliasChoices(
|
||||
"penalty_range",
|
||||
"repetition_range",
|
||||
"repetition_penalty_range",
|
||||
"rep_pen_range",
|
||||
),
|
||||
description=(
|
||||
"Aliases: repetition_range, repetition_penalty_range, "
|
||||
"rep_pen_range"
|
||||
),
|
||||
)
|
||||
|
||||
cfg_scale: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
|
||||
validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
|
||||
|
|
@ -208,13 +216,6 @@ class BaseSamplerRequest(BaseModel):
|
|||
examples=[1.0],
|
||||
)
|
||||
|
||||
banned_tokens: Optional[Union[List[int], str]] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("banned_tokens", []),
|
||||
validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
|
||||
description="Aliases: custom_token_bans",
|
||||
examples=[[128, 330]],
|
||||
)
|
||||
|
||||
# TODO: Return back to adaptable class-based validation But that's just too much
|
||||
# abstraction compared to simple if statements at the moment
|
||||
def validate_params(self):
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from typing import List, Optional
|
|||
|
||||
from pydantic import BaseModel, Field
|
||||
from common import model
|
||||
from common.sampling import BaseSamplerRequest
|
||||
from common.sampling import BaseSamplerRequest, get_default_sampler_value
|
||||
from common.utils import flat_map, unwrap
|
||||
|
||||
|
||||
|
|
@ -10,12 +10,19 @@ class GenerateRequest(BaseSamplerRequest):
|
|||
prompt: str
|
||||
genkey: Optional[str] = None
|
||||
use_default_badwordsids: Optional[bool] = False
|
||||
dynatemp_range: Optional[float] = Field(
|
||||
default_factory=get_default_sampler_value("dynatemp_range")
|
||||
)
|
||||
|
||||
def to_gen_params(self, **kwargs):
|
||||
# Exl2 uses -1 to include all tokens in repetition penalty
|
||||
if self.penalty_range == 0:
|
||||
self.penalty_range = -1
|
||||
|
||||
if self.dynatemp_range:
|
||||
self.min_temp = self.temperature - self.dynatemp_range
|
||||
self.max_temp = self.temperature + self.dynatemp_range
|
||||
|
||||
# Move badwordsids into banned tokens for generation
|
||||
if self.use_default_badwordsids:
|
||||
bad_words_ids = unwrap(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue