Kobold: Move params to aliases

Some of the parameters the API provides are aliases for their OAI
equivalents. It makes more sense to move them to the common file.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-07-26 16:45:29 -04:00
parent b7cb6f0b91
commit 545e26608f
2 changed files with 18 additions and 18 deletions

View file

@ -16,11 +16,15 @@ class BaseSamplerRequest(BaseModel):
max_tokens: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("max_tokens"),
validation_alias=AliasChoices("max_tokens", "max_length"),
description="Aliases: max_length",
examples=[150],
)
min_tokens: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("min_tokens", 0),
validation_alias=AliasChoices("min_tokens", "min_length"),
description="Aliases: min_length",
examples=[0],
)
@ -91,6 +95,8 @@ class BaseSamplerRequest(BaseModel):
repetition_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
validation_alias=AliasChoices("repetition_penalty", "rep_pen"),
description="Aliases: rep_pen",
examples=[1.0],
)
@ -118,6 +124,8 @@ class BaseSamplerRequest(BaseModel):
ban_eos_token: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
validation_alias=AliasChoices("ban_eos_token", "ignore_eos"),
description="Aliases: ignore_eos",
examples=[False],
)
@ -165,8 +173,12 @@ class BaseSamplerRequest(BaseModel):
"penalty_range",
"repetition_range",
"repetition_penalty_range",
"rep_pen_range",
),
description=(
"Aliases: repetition_range, repetition_penalty_range, "
"rep_pen_range"
),
description="Aliases: repetition_range, repetition_penalty_range",
)
cfg_scale: Optional[float] = Field(

View file

@ -1,30 +1,18 @@
from typing import List, Optional
from pydantic import BaseModel, Field
from common.sampling import BaseSamplerRequest, get_default_sampler_value
from common.sampling import BaseSamplerRequest
class GenerateRequest(BaseSamplerRequest):
prompt: str
use_default_badwordsids: Optional[bool] = False
genkey: Optional[str] = None
max_length: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("max_tokens"),
examples=[150],
)
rep_pen_range: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("penalty_range", -1),
)
rep_pen: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
)
use_default_badwordsids: Optional[bool] = False
def to_gen_params(self, **kwargs):
# Swap kobold generation params to OAI/Exl2 ones
self.max_tokens = self.max_length
self.repetition_penalty = self.rep_pen
self.penalty_range = -1 if self.rep_pen_range == 0 else self.rep_pen_range
# Exl2 uses -1 to include all tokens in repetition penalty
if self.penalty_range == 0:
self.penalty_range = -1
return super().to_gen_params(**kwargs)