From 545e26608f8dcfa9600e7060d53daf3efa112fc3 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 26 Jul 2024 16:45:29 -0400 Subject: [PATCH] Kobold: Move params to aliases Some of the parameters the API provides are aliases for their OAI equivalents. It makes more sense to move them to the common file. Signed-off-by: kingbri --- common/sampling.py | 14 +++++++++++++- endpoints/Kobold/types/generation.py | 22 +++++----------------- 2 files changed, 18 insertions(+), 18 deletions(-) diff --git a/common/sampling.py b/common/sampling.py index bbeddb8..2b3850e 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -16,11 +16,15 @@ class BaseSamplerRequest(BaseModel): max_tokens: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("max_tokens"), + validation_alias=AliasChoices("max_tokens", "max_length"), + description="Aliases: max_length", examples=[150], ) min_tokens: Optional[int] = Field( default_factory=lambda: get_default_sampler_value("min_tokens", 0), + validation_alias=AliasChoices("min_tokens", "min_length"), + description="Aliases: min_length", examples=[0], ) @@ -91,6 +95,8 @@ class BaseSamplerRequest(BaseModel): repetition_penalty: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0), + validation_alias=AliasChoices("repetition_penalty", "rep_pen"), + description="Aliases: rep_pen", examples=[1.0], ) @@ -118,6 +124,8 @@ class BaseSamplerRequest(BaseModel): ban_eos_token: Optional[bool] = Field( default_factory=lambda: get_default_sampler_value("ban_eos_token", False), + validation_alias=AliasChoices("ban_eos_token", "ignore_eos"), + description="Aliases: ignore_eos", examples=[False], ) @@ -165,8 +173,12 @@ class BaseSamplerRequest(BaseModel): "penalty_range", "repetition_range", "repetition_penalty_range", + "rep_pen_range", + ), + description=( + "Aliases: repetition_range, repetition_penalty_range, " + "rep_pen_range" ), - description="Aliases: repetition_range, repetition_penalty_range", ) cfg_scale: Optional[float] = Field( diff --git a/endpoints/Kobold/types/generation.py b/endpoints/Kobold/types/generation.py index 5468741..eab214c 100644 --- a/endpoints/Kobold/types/generation.py +++ b/endpoints/Kobold/types/generation.py @@ -1,30 +1,18 @@ from typing import List, Optional from pydantic import BaseModel, Field -from common.sampling import BaseSamplerRequest, get_default_sampler_value +from common.sampling import BaseSamplerRequest class GenerateRequest(BaseSamplerRequest): prompt: str - use_default_badwordsids: Optional[bool] = False genkey: Optional[str] = None - - max_length: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("max_tokens"), - examples=[150], - ) - rep_pen_range: Optional[int] = Field( - default_factory=lambda: get_default_sampler_value("penalty_range", -1), - ) - rep_pen: Optional[float] = Field( - default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0), - ) + use_default_badwordsids: Optional[bool] = False def to_gen_params(self, **kwargs): - # Swap kobold generation params to OAI/Exl2 ones - self.max_tokens = self.max_length - self.repetition_penalty = self.rep_pen - self.penalty_range = -1 if self.rep_pen_range == 0 else self.rep_pen_range + # Exl2 uses -1 to include all tokens in repetition penalty + if self.penalty_range == 0: + self.penalty_range = -1 return super().to_gen_params(**kwargs)