OAI: Add field aliasing

Repetition penalty range needs field aliases to support multiple
parameter calls.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2023-12-17 21:57:46 -05:00 committed by Brian Dashore
parent 124e39df26
commit bc21f0bbc0

View file

@ -1,6 +1,6 @@
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, AliasChoices
from typing import List, Dict, Optional, Union
from utils import coalesce
from utils import unwrap
class LogProbs(BaseModel):
text_offset: List[int] = Field(default_factory=list)
@ -36,6 +36,7 @@ class CommonCompletionRequest(BaseModel):
max_tokens: Optional[int] = 150
# Aliased to repetition_penalty
# TODO: Maybe make this an alias to rep pen
frequency_penalty: Optional[float] = Field(description = "Aliased to Repetition Penalty", default = 0.0)
# Sampling params
@ -56,20 +57,21 @@ class CommonCompletionRequest(BaseModel):
ban_eos_token: Optional[bool] = False
# Aliased variables
# TODO: Add a function to iterate through aliases and return a default value if all are None
repetition_range: Optional[int] = None
repetition_penalty_range: Optional[int] = None
repetition_range: Optional[int] = Field(
default = None,
validation_alias = AliasChoices('repetition_range', 'repetition_penalty_range')
)
# Converts to internal generation parameters
def to_gen_params(self):
# Convert stop to an array of strings
if isinstance(self.stop, str):
self.stop = [self.stop]
# Set repetition_penalty to frequency_penalty if repetition_penalty isn't already defined
if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
self.repetition_penalty = self.frequency_penalty
return {
"stop": self.stop,
"max_tokens": self.max_tokens,
@ -84,7 +86,7 @@ class CommonCompletionRequest(BaseModel):
"min_p": self.min_p,
"tfs": self.tfs,
"repetition_penalty": self.repetition_penalty,
"repetition_range": coalesce(self.repetition_range, self.repetition_penalty_range, -1),
"repetition_range": unwrap(self.repetition_range, -1),
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,