Sampling: Reorder aliased params and add kobold aliases

Also add dynatemp range, which is an alternative way of calculating
min and max temp.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-07-26 18:32:33 -04:00
parent 7522b1447b
commit ea80b62e30
2 changed files with 39 additions and 31 deletions

View file

@ -34,13 +34,22 @@ class BaseSamplerRequest(BaseModel):
)
stop: Optional[Union[str, List[str]]] = Field(
default_factory=lambda: get_default_sampler_value("stop", [])
default_factory=lambda: get_default_sampler_value("stop", []),
validation_alias=AliasChoices("stop", "stop_sequence"),
description="Aliases: stop_sequence",
)
banned_strings: Optional[Union[str, List[str]]] = Field(
default_factory=lambda: get_default_sampler_value("banned_strings", [])
)
banned_tokens: Optional[Union[List[int], str]] = Field(
default_factory=lambda: get_default_sampler_value("banned_tokens", []),
validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
description="Aliases: custom_token_bans",
examples=[[128, 330]],
)
token_healing: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("token_healing", False)
)
@ -80,6 +89,13 @@ class BaseSamplerRequest(BaseModel):
examples=[1.0],
)
typical: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("typical", 1.0),
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
examples=[1.0],
)
skew: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("skew", 0.0),
examples=[0.0],
@ -100,6 +116,20 @@ class BaseSamplerRequest(BaseModel):
examples=[1.0],
)
penalty_range: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("penalty_range", -1),
validation_alias=AliasChoices(
"penalty_range",
"repetition_range",
"repetition_penalty_range",
"rep_pen_range",
),
description=(
"Aliases: repetition_range, repetition_penalty_range, "
"rep_pen_range"
),
)
repetition_decay: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
)
@ -159,28 +189,6 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("speculative_ngram"),
)
# Aliased variables
typical: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("typical", 1.0),
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
examples=[1.0],
)
penalty_range: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("penalty_range", -1),
validation_alias=AliasChoices(
"penalty_range",
"repetition_range",
"repetition_penalty_range",
"rep_pen_range",
),
description=(
"Aliases: repetition_range, repetition_penalty_range, "
"rep_pen_range"
),
)
cfg_scale: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
@ -208,13 +216,6 @@ class BaseSamplerRequest(BaseModel):
examples=[1.0],
)
banned_tokens: Optional[Union[List[int], str]] = Field(
default_factory=lambda: get_default_sampler_value("banned_tokens", []),
validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
description="Aliases: custom_token_bans",
examples=[[128, 330]],
)
# TODO: Return to adaptable class-based validation. For now, that's too much
# abstraction compared to simple if statements.
def validate_params(self):

View file

@ -2,7 +2,7 @@ from typing import List, Optional
from pydantic import BaseModel, Field
from common import model
from common.sampling import BaseSamplerRequest
from common.sampling import BaseSamplerRequest, get_default_sampler_value
from common.utils import flat_map, unwrap
@ -10,12 +10,19 @@ class GenerateRequest(BaseSamplerRequest):
prompt: str
genkey: Optional[str] = None
use_default_badwordsids: Optional[bool] = False
# Offset applied around `temperature`: to_gen_params sets
# min_temp = temperature - dynatemp_range and
# max_temp = temperature + dynatemp_range when this value is truthy.
# default_factory must be a zero-argument callable, so the sampler-override
# lookup is wrapped in a lambda (as the other sampler fields do) and is
# evaluated per request rather than once at class-definition time.
dynatemp_range: Optional[float] = Field(
    default_factory=lambda: get_default_sampler_value("dynatemp_range")
)
def to_gen_params(self, **kwargs):
# Exl2 uses -1 to include all tokens in repetition penalty
if self.penalty_range == 0:
self.penalty_range = -1
if self.dynatemp_range:
self.min_temp = self.temperature - self.dynatemp_range
self.max_temp = self.temperature + self.dynatemp_range
# Move badwordsids into banned tokens for generation
if self.use_default_badwordsids:
bad_words_ids = unwrap(