tabbyAPI-ollama/OAI/types/common.py
kingbri 5e8419ec0c OAI: Add chat completions endpoint
Chat completions is the endpoint that will be used by OAI in the
future. Makes sense to support it even though the completions
endpoint will be used more often.

Also unify common parameters between the chat completion and completion
requests since they're very similar.

Signed-off-by: kingbri <bdashore3@proton.me>
2023-11-16 01:06:07 -05:00

85 lines
3 KiB
Python

from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Union
class LogProbs(BaseModel):
    """Log-probability data attached to an OAI-style completion response.

    All four fields are optional on construction: each one falls back to a
    fresh empty list through ``default_factory``, so instances never share
    a default list.
    """

    # NOTE(review): the field names mirror OpenAI's completions ``logprobs``
    # object; the per-field meanings are presumed from that API — confirm
    # against whatever code populates this model.
    text_offset: List[int] = Field(default_factory=list)
    token_logprobs: List[float] = Field(default_factory=list)
    tokens: List[str] = Field(default_factory=list)
    top_logprobs: List[Dict[str, float]] = Field(default_factory=list)
class UsageStats(BaseModel):
    """Token counts reported for a single request.

    Every field is required; none has a default. Nothing in this model
    checks that ``total_tokens`` equals the sum of the other two, so the
    code that builds the instance must keep the numbers consistent.
    """

    completion_tokens: int
    prompt_tokens: int
    total_tokens: int
class CommonCompletionRequest(BaseModel):
    """Request parameters shared by the completion and chat completion requests.

    Holds the OAI-compatible request fields together with the extra sampler
    settings. ``to_gen_params`` turns the sampling-related subset into a
    plain dict of internal generation parameters.
    """

    # Model information
    # This parameter is not used, the loaded model is used instead
    model: Optional[str] = None

    # Extra OAI request stuff
    # None of these are forwarded by to_gen_params.
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    n: Optional[int] = 1
    suffix: Optional[str] = None
    user: Optional[str] = None

    # Generation info
    seed: Optional[int] = -1
    stream: Optional[bool] = False
    # Accepts either a single stop string or a list of them
    stop: Optional[Union[str, List[str]]] = None

    # Default to 150 as 16 makes no sense as a default
    max_tokens: Optional[int] = 150

    # Aliased to repetition_penalty (see to_gen_params for the exact rule)
    frequency_penalty: Optional[float] = 0.0

    # Sampling params
    token_healing: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_p: Optional[float] = 1.0
    typical: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    repetition_penalty_range: Optional[int] = 0
    repetition_decay: Optional[int] = 0
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 1.5
    mirostat_eta: Optional[float] = 0.1
    add_bos_token: Optional[bool] = True
    ban_eos_token: Optional[bool] = False

    # Converts to internal generation parameters
    def to_gen_params(self) -> dict:
        """Build the internal generation parameters for this request.

        The request is not modified: ``stop`` and ``repetition_penalty`` are
        normalised into local values, so the model keeps reporting exactly
        what the client sent.

        Returns:
            A dict of generation settings. ``stop`` is a list of strings
            when the client supplied a string or a list, and ``None`` when
            it was omitted. ``mirostat`` is ``True`` only when
            ``mirostat_mode`` equals 2. Request-level fields such as
            ``seed``, ``stream`` and ``n`` are not part of the result.
        """
        # Convert stop to an array of strings; a list (or None) passes through.
        stop = [self.stop] if isinstance(self.stop, str) else self.stop

        # Use frequency_penalty as the repetition penalty only when
        # repetition_penalty was left unset or at its 1.0 default and a
        # non-zero frequency_penalty was sent. The value is copied verbatim,
        # without any rescaling between the two parameters.
        repetition_penalty = self.repetition_penalty
        penalty_is_default = repetition_penalty is None or repetition_penalty == 1.0
        if penalty_is_default and self.frequency_penalty:
            repetition_penalty = self.frequency_penalty

        return {
            "stop": stop,
            "max_tokens": self.max_tokens,
            "add_bos_token": self.add_bos_token,
            "ban_eos_token": self.ban_eos_token,
            "token_healing": self.token_healing,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "typical": self.typical,
            "min_p": self.min_p,
            "tfs": self.tfs,
            "repetition_penalty": repetition_penalty,
            "repetition_penalty_range": self.repetition_penalty_range,
            "repetition_decay": self.repetition_decay,
            # Only mode 2 switches mirostat on; every other value disables it.
            "mirostat": self.mirostat_mode == 2,
            "mirostat_tau": self.mirostat_tau,
            "mirostat_eta": self.mirostat_eta,
        }