tabbyAPI-ollama/OAI/types/common.py
AlpinDale fa47f51f85
feat: workflows for formatting/linting (#35)
* add github workflows for pylint and yapf

* yapf

* docstrings for auth

* fix auth.py

* fix generators.py

* fix gen_logging.py

* fix main.py

* fix model.py

* fix templating.py

* fix utils.py

* update formatting.sh to include subdirs for pylint

* fix model_test.py

* fix wheel_test.py

* rename utils to utils_oai

* fix OAI/utils_oai.py

* fix completion.py

* fix token.py

* fix lora.py

* fix common.py

* add pylintrc and fix model.py

* finish up pylint

* fix attribute error

* main.py formatting

* add formatting batch script

* Main: Remove unnecessary global

Linter suggestion.

Signed-off-by: kingbri <bdashore3@proton.me>

* switch to ruff

* Formatting + Linting: Add ruff.toml

Signed-off-by: kingbri <bdashore3@proton.me>

* Formatting + Linting: Switch scripts to use ruff

Also remove the file and recent file change functions from both
scripts.

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Format and lint

Signed-off-by: kingbri <bdashore3@proton.me>

* Scripts + Workflows: Format

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Remove pylint flags

We use ruff now

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Format

Signed-off-by: kingbri <bdashore3@proton.me>

* Formatting: Line length is 88

Use the same value as Black.

Signed-off-by: kingbri <bdashore3@proton.me>

* Tree: Format

Update to new line length rules.

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com>
Co-authored-by: kingbri <bdashore3@proton.me>
2023-12-22 16:20:35 +00:00

123 lines
4.1 KiB
Python

""" Common types for OAI. """
from typing import List, Dict, Optional, Union
from pydantic import BaseModel, Field, AliasChoices
from utils import unwrap
class LogProbs(BaseModel):
"""Represents log probabilities."""
text_offset: List[int] = Field(default_factory=list)
token_logprobs: List[float] = Field(default_factory=list)
tokens: List[str] = Field(default_factory=list)
top_logprobs: List[Dict[str, float]] = Field(default_factory=list)
class UsageStats(BaseModel):
"""Represents usage stats."""
prompt_tokens: int
completion_tokens: int
total_tokens: int
class CommonCompletionRequest(BaseModel):
"""Represents a common completion request."""
# Model information
# This parameter is not used, the loaded model is used instead
model: Optional[str] = None
# Extra OAI request stuff
best_of: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
echo: Optional[bool] = Field(
description="Not parsed. Only used for OAI compliance.", default=False
)
logprobs: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
n: Optional[int] = Field(
description="Not parsed. Only used for OAI compliance.", default=1
)
suffix: Optional[str] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
user: Optional[str] = Field(
description="Not parsed. Only used for OAI compliance.", default=None
)
# Generation info
# seed: Optional[int] = -1
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = []
# Default to 150 as 16 makes no sense as a default
max_tokens: Optional[int] = 150
# Aliased to repetition_penalty
frequency_penalty: Optional[float] = Field(
description="Aliased to Repetition Penalty", default=0.0
)
# Sampling params
token_healing: Optional[bool] = False
temperature: Optional[float] = 1.0
temperature_last: Optional[bool] = False
top_k: Optional[int] = 0
top_p: Optional[float] = 1.0
typical: Optional[float] = 1.0
min_p: Optional[float] = 0.0
tfs: Optional[float] = 1.0
repetition_penalty: Optional[float] = 1.0
repetition_decay: Optional[int] = 0
mirostat_mode: Optional[int] = 0
mirostat_tau: Optional[float] = 1.5
mirostat_eta: Optional[float] = 0.1
add_bos_token: Optional[bool] = True
ban_eos_token: Optional[bool] = False
logit_bias: Optional[Dict[int, float]] = None
# Aliased variables
repetition_range: Optional[int] = Field(
default=None,
validation_alias=AliasChoices("repetition_range", "repetition_penalty_range"),
)
def to_gen_params(self):
"""Converts to internal generation parameters."""
# Convert stop to an array of strings
if isinstance(self.stop, str):
self.stop = [self.stop]
# Set repetition_penalty to frequency_penalty if repetition_penalty
# isn't already defined
if (
self.repetition_penalty is None or self.repetition_penalty == 1.0
) and self.frequency_penalty:
self.repetition_penalty = self.frequency_penalty
return {
"stop": self.stop,
"max_tokens": self.max_tokens,
"add_bos_token": self.add_bos_token,
"ban_eos_token": self.ban_eos_token,
"token_healing": self.token_healing,
"logit_bias": self.logit_bias,
"temperature": self.temperature,
"temperature_last": self.temperature_last,
"top_k": self.top_k,
"top_p": self.top_p,
"typical": self.typical,
"min_p": self.min_p,
"tfs": self.tfs,
"repetition_penalty": self.repetition_penalty,
"repetition_range": unwrap(self.repetition_range, -1),
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,
"mirostat_eta": self.mirostat_eta,
}