Kobold: Move params to aliases
Some of the parameters the API provides are aliases for their OAI equivalents. It makes more sense to move them to the common file. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
b7cb6f0b91
commit
545e26608f
2 changed files with 18 additions and 18 deletions
|
|
@ -16,11 +16,15 @@ class BaseSamplerRequest(BaseModel):
|
|||
|
||||
max_tokens: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("max_tokens"),
|
||||
validation_alias=AliasChoices("max_tokens", "max_length"),
|
||||
description="Aliases: max_length",
|
||||
examples=[150],
|
||||
)
|
||||
|
||||
min_tokens: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("min_tokens", 0),
|
||||
validation_alias=AliasChoices("min_tokens", "min_length"),
|
||||
description="Aliases: min_length",
|
||||
examples=[0],
|
||||
)
|
||||
|
||||
|
|
@ -91,6 +95,8 @@ class BaseSamplerRequest(BaseModel):
|
|||
|
||||
repetition_penalty: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
|
||||
validation_alias=AliasChoices("repetition_penalty", "rep_pen"),
|
||||
description="Aliases: rep_pen",
|
||||
examples=[1.0],
|
||||
)
|
||||
|
||||
|
|
@ -118,6 +124,8 @@ class BaseSamplerRequest(BaseModel):
|
|||
|
||||
ban_eos_token: Optional[bool] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
|
||||
validation_alias=AliasChoices("ban_eos_token", "ignore_eos"),
|
||||
description="Aliases: ignore_eos",
|
||||
examples=[False],
|
||||
)
|
||||
|
||||
|
|
@ -165,8 +173,12 @@ class BaseSamplerRequest(BaseModel):
|
|||
"penalty_range",
|
||||
"repetition_range",
|
||||
"repetition_penalty_range",
|
||||
"rep_pen_range",
|
||||
),
|
||||
description=(
|
||||
"Aliases: repetition_range, repetition_penalty_range, "
|
||||
"rep_pen_range"
|
||||
),
|
||||
description="Aliases: repetition_range, repetition_penalty_range",
|
||||
)
|
||||
|
||||
cfg_scale: Optional[float] = Field(
|
||||
|
|
|
|||
|
|
@ -1,30 +1,18 @@
|
|||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from common.sampling import BaseSamplerRequest, get_default_sampler_value
|
||||
from common.sampling import BaseSamplerRequest
|
||||
|
||||
|
||||
class GenerateRequest(BaseSamplerRequest):
|
||||
prompt: str
|
||||
use_default_badwordsids: Optional[bool] = False
|
||||
genkey: Optional[str] = None
|
||||
|
||||
max_length: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("max_tokens"),
|
||||
examples=[150],
|
||||
)
|
||||
rep_pen_range: Optional[int] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("penalty_range", -1),
|
||||
)
|
||||
rep_pen: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
|
||||
)
|
||||
use_default_badwordsids: Optional[bool] = False
|
||||
|
||||
def to_gen_params(self, **kwargs):
|
||||
# Swap kobold generation params to OAI/Exl2 ones
|
||||
self.max_tokens = self.max_length
|
||||
self.repetition_penalty = self.rep_pen
|
||||
self.penalty_range = -1 if self.rep_pen_range == 0 else self.rep_pen_range
|
||||
# Exl2 uses -1 to include all tokens in repetition penalty
|
||||
if self.penalty_range == 0:
|
||||
self.penalty_range = -1
|
||||
|
||||
return super().to_gen_params(**kwargs)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue