From ea80b62e307b3668553b3fdf9d1b3dfccecaf84f Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 18:32:33 -0400 Subject: [PATCH] Sampling: Reorder aliased params and add kobold aliases Also add dynatemp range which is an alternative way of calculating min and max temp. Signed-off-by: kingbri --- common/sampling.py | 61 ++++++++++++++-------------- endpoints/Kobold/types/generation.py | 9 +++- 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/common/sampling.py b/common/sampling.py index 2b3850e..8851f00 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -34,13 +34,22 @@ class BaseSamplerRequest(BaseModel): ) stop: Optional[Union[str, List[str]]] = Field( - default_factory=lambda: get_default_sampler_value("stop", []) + default_factory=lambda: get_default_sampler_value("stop", []), + validation_alias=AliasChoices("stop", "stop_sequence"), + description="Aliases: stop_sequence", ) banned_strings: Optional[Union[str, List[str]]] = Field( default_factory=lambda: get_default_sampler_value("banned_strings", []) ) + banned_tokens: Optional[Union[List[int], str]] = Field( + default_factory=lambda: get_default_sampler_value("banned_tokens", []), + validation_alias=AliasChoices("banned_tokens", "custom_token_bans"), + description="Aliases: custom_token_bans", + examples=[[128, 330]], + ) + token_healing: Optional[bool] = Field( default_factory=lambda: get_default_sampler_value("token_healing", False) ) @@ -80,6 +89,13 @@ class BaseSamplerRequest(BaseModel): examples=[1.0], ) + typical: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("typical", 1.0), + validation_alias=AliasChoices("typical", "typical_p"), + description="Aliases: typical_p", + examples=[1.0], + ) + skew: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("skew", 0.0), examples=[0.0], @@ -100,6 +116,20 @@ class BaseSamplerRequest(BaseModel): examples=[1.0], ) + penalty_range: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("penalty_range", -1), + validation_alias=AliasChoices( + "penalty_range", + "repetition_range", + "repetition_penalty_range", + "rep_pen_range", + ), + description=( + "Aliases: repetition_range, repetition_penalty_range, " + "rep_pen_range" + ), + ) + repetition_decay: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("repetition_decay", 0) ) @@ -159,28 +189,6 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("speculative_ngram"), ) - # Aliased variables - typical: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("typical", 1.0), - validation_alias=AliasChoices("typical", "typical_p"), - description="Aliases: typical_p", - examples=[1.0], - ) - - penalty_range: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("penalty_range", -1), - validation_alias=AliasChoices( - "penalty_range", - "repetition_range", - "repetition_penalty_range", - "rep_pen_range", - ), - description=( - "Aliases: repetition_range, repetition_penalty_range, " - "rep_pen_range" - ), - ) - cfg_scale: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0), validation_alias=AliasChoices("cfg_scale", "guidance_scale"), @@ -208,13 +216,6 @@ class BaseSamplerRequest(BaseModel): examples=[1.0], ) - banned_tokens: Optional[Union[List[int], str]] = Field( - default_factory=lambda: get_default_sampler_value("banned_tokens", []), - validation_alias=AliasChoices("banned_tokens", "custom_token_bans"), - description="Aliases: custom_token_bans", - examples=[[128, 330]], - ) - # TODO: Return back to adaptable class-based validation But that's just too much # abstraction compared to simple if statements at the moment def validate_params(self): diff --git a/endpoints/Kobold/types/generation.py b/endpoints/Kobold/types/generation.py index 0ee5489..210a914 100644 --- a/endpoints/Kobold/types/generation.py +++ b/endpoints/Kobold/types/generation.py @@ -2,7 +2,7 @@ from typing import List, Optional from pydantic import BaseModel, Field from common import model -from common.sampling import BaseSamplerRequest +from common.sampling import BaseSamplerRequest, get_default_sampler_value from common.utils import flat_map, unwrap @@ -10,12 +10,19 @@ class GenerateRequest(BaseSamplerRequest): prompt: str genkey: Optional[str] = None use_default_badwordsids: Optional[bool] = False + dynatemp_range: Optional[float] = Field( + default_factory=get_default_sampler_value("dynatemp_range") + ) def to_gen_params(self, **kwargs): # Exl2 uses -1 to include all tokens in repetition penalty if self.penalty_range == 0: self.penalty_range = -1 + if self.dynatemp_range: + self.min_temp = self.temperature - self.dynatemp_range + self.max_temp = self.temperature + self.dynatemp_range + # Move badwordsids into banned tokens for generation if self.use_default_badwordsids: bad_words_ids = unwrap(