OAI: Add field aliasing
Repetition penalty range needs field aliases to support multiple parameter calls. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
124e39df26
commit
bc21f0bbc0
1 changed files with 10 additions and 8 deletions
|
|
@ -1,6 +1,6 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, AliasChoices
|
||||
from typing import List, Dict, Optional, Union
|
||||
from utils import coalesce
|
||||
from utils import unwrap
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
text_offset: List[int] = Field(default_factory=list)
|
||||
|
|
@ -36,6 +36,7 @@ class CommonCompletionRequest(BaseModel):
|
|||
max_tokens: Optional[int] = 150
|
||||
|
||||
# Aliased to repetition_penalty
|
||||
# TODO: Maybe make this an alias to rep pen
|
||||
frequency_penalty: Optional[float] = Field(description = "Aliased to Repetition Penalty", default = 0.0)
|
||||
|
||||
# Sampling params
|
||||
|
|
@ -56,20 +57,21 @@ class CommonCompletionRequest(BaseModel):
|
|||
ban_eos_token: Optional[bool] = False
|
||||
|
||||
# Aliased variables
|
||||
# TODO: Add a function to iterate through aliases and return a default value if all are None
|
||||
repetition_range: Optional[int] = None
|
||||
repetition_penalty_range: Optional[int] = None
|
||||
repetition_range: Optional[int] = Field(
|
||||
default = None,
|
||||
validation_alias = AliasChoices('repetition_range', 'repetition_penalty_range')
|
||||
)
|
||||
|
||||
# Converts to internal generation parameters
|
||||
def to_gen_params(self):
|
||||
# Convert stop to an array of strings
|
||||
if isinstance(self.stop, str):
|
||||
self.stop = [self.stop]
|
||||
|
||||
|
||||
# Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
|
||||
if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
|
||||
self.repetition_penalty = self.frequency_penalty
|
||||
|
||||
|
||||
return {
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens,
|
||||
|
|
@ -84,7 +86,7 @@ class CommonCompletionRequest(BaseModel):
|
|||
"min_p": self.min_p,
|
||||
"tfs": self.tfs,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"repetition_range": coalesce(self.repetition_range, self.repetition_penalty_range, -1),
|
||||
"repetition_range": unwrap(self.repetition_range, -1),
|
||||
"repetition_decay": self.repetition_decay,
|
||||
"mirostat": self.mirostat_mode == 2,
|
||||
"mirostat_tau": self.mirostat_tau,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue