tabbyAPI-ollama/OAI/types/completion.py
kingbri d0b6b11068 OAI: Make freq and presence pen floats
Also rename the completions typing file.

Signed-off-by: kingbri <bdashore3@proton.me>
2023-11-15 00:55:15 -05:00

103 lines
3.4 KiB
Python

from uuid import uuid4
from time import time
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Union
from OAI.types.common import LogProbs, UsageStats
class CompletionRespChoice(BaseModel):
    """One generated choice, as carried in ``CompletionResponse.choices``."""

    # Reason string for why generation of this choice ended (set by the caller).
    finish_reason: str
    # Index of this choice; presumably its position within ``choices``.
    index: int
    # Optional log-probability payload; left as None unless the caller fills it.
    logprobs: Union[LogProbs, None] = None
    # The generated text for this choice.
    text: str
class CompletionRequest(BaseModel):
    """Request body for the OpenAI-style text completion endpoint.

    Several OAI fields (``best_of``, ``echo``, ``logit_bias``, ``logprobs``,
    ``n``, ``suffix``, ``user``, ``seed``, ``stream``, ``presence_penalty``)
    are accepted for API compatibility but are not forwarded by
    ``to_gen_params``.
    """

    # Model information
    model: str

    # Prompt can also contain token ids, but that's out of scope for this project.
    prompt: Union[str, List[str]]

    # Extra OAI request stuff (accepted, not forwarded by to_gen_params)
    best_of: Optional[int] = None
    echo: Optional[bool] = False
    logit_bias: Optional[Dict[str, float]] = None
    logprobs: Optional[int] = None
    n: Optional[int] = 1
    suffix: Optional[str] = None
    user: Optional[str] = None

    # Generation info
    # NOTE(review): seed and stream are not forwarded by to_gen_params;
    # stream is presumably read directly by the endpoint handler — verify.
    seed: Optional[int] = -1
    stream: Optional[bool] = False
    # A single string is normalized to a one-element list in to_gen_params.
    stop: Optional[Union[str, List[str]]] = None

    # Default to 150 as 16 makes no sense as a default
    max_tokens: Optional[int] = 150

    # Not supported sampling param (accepted but ignored)
    presence_penalty: Optional[float] = 0.0

    # Aliased to repetition_penalty in to_gen_params when that is unset
    frequency_penalty: Optional[float] = 0.0

    # Sampling params
    token_healing: Optional[bool] = False
    temperature: Optional[float] = 1.0
    top_k: Optional[int] = 0
    top_p: Optional[float] = 1.0
    typical: Optional[float] = 0.0
    min_p: Optional[float] = 0.0
    tfs: Optional[float] = 1.0
    repetition_penalty: Optional[float] = 1.0
    repetition_penalty_range: Optional[int] = 0
    repetition_decay: Optional[int] = 0
    # Only mode 2 turns mirostat on; see to_gen_params.
    mirostat_mode: Optional[int] = 0
    mirostat_tau: Optional[float] = 1.5
    mirostat_eta: Optional[float] = 0.1
    add_bos_token: Optional[bool] = True
    ban_eos_token: Optional[bool] = False

    # Converts to internal generation parameters
    def to_gen_params(self):
        """Build the internal generation-parameter dict from this request.

        Side effects: the request is normalized in place before the dict is
        built.

        * A list ``prompt`` is joined with ``"\\n"`` into a single string.
        * A string ``stop`` is wrapped in a one-element list.
        * ``repetition_penalty`` is overwritten with ``frequency_penalty``
          when it is None or 1.0 and ``frequency_penalty`` is truthy.

        Returns:
            dict: prompt, stop list, token limit, BOS/EOS flags and sampler
            settings. ``mirostat`` is a bool that is True only when
            ``mirostat_mode == 2``.
        """
        # Convert prompt to a string
        if isinstance(self.prompt, list):
            self.prompt = "\n".join(self.prompt)

        # Convert stop to an array of strings
        if isinstance(self.stop, str):
            self.stop = [self.stop]

        # Set repetition_penalty to frequency_penalty if repetition_penalty
        # isn't already defined (None or the 1.0 default count as undefined).
        # NOTE(review): frequency_penalty defaults to 0.0 here while
        # repetition_penalty defaults to 1.0, yet the value is copied
        # unscaled — confirm this mapping is intended.
        if (self.repetition_penalty is None or self.repetition_penalty == 1.0) and self.frequency_penalty:
            self.repetition_penalty = self.frequency_penalty

        return {
            "prompt": self.prompt,
            "stop": self.stop,
            "max_tokens": self.max_tokens,
            "add_bos_token": self.add_bos_token,
            "ban_eos_token": self.ban_eos_token,
            "token_healing": self.token_healing,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "top_p": self.top_p,
            "typical": self.typical,
            "min_p": self.min_p,
            "tfs": self.tfs,
            "repetition_penalty": self.repetition_penalty,
            "repetition_penalty_range": self.repetition_penalty_range,
            "repetition_decay": self.repetition_decay,
            # Only mirostat mode 2 enables mirostat; other modes disable it.
            "mirostat": self.mirostat_mode == 2,
            "mirostat_tau": self.mirostat_tau,
            "mirostat_eta": self.mirostat_eta,
        }
class CompletionResponse(BaseModel):
    """OpenAI-style response body for a text completion request."""

    # Unique response id: "cmpl-" followed by a random uuid4 hex string.
    id: str = Field(default_factory=lambda: f"cmpl-{uuid4().hex}")
    choices: List[CompletionRespChoice]
    # Creation time as integer Unix seconds, taken when the model is built.
    created: int = Field(default_factory=lambda: int(time()))
    model: str
    # The OpenAI completions API always reports "text_completion" (underscore).
    object: str = "text_completion"

    # TODO: Add usage stats
    usage: Optional[UsageStats] = None