# tabbyAPI-ollama/common/sampling.py
"""Common functions for sampling parameters"""
import aiofiles
import json
import pathlib
from pydantic_core import ValidationError
from ruamel.yaml import YAML
from copy import deepcopy
from loguru import logger
from pydantic import (
AliasChoices,
BaseModel,
Field,
field_validator,
model_validator,
)
from typing import Dict, List, Optional, Union
from common.utils import filter_none_values, unwrap
# Common class for sampler params
class BaseSamplerRequest(BaseModel):
    """Common class for sampler params that are used in APIs"""

    max_tokens: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("max_tokens"),
        validation_alias=AliasChoices(
            "max_tokens", "max_completion_tokens", "max_length"
        ),
        description="Aliases: max_completion_tokens, max_length",
        examples=[150],
        ge=0,
    )
    min_tokens: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("min_tokens", 0),
        validation_alias=AliasChoices("min_tokens", "min_length"),
        description="Aliases: min_length",
        examples=[0],
        ge=0,
    )
    generate_window: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("generate_window"),
        examples=[512],
        ge=0,
    )
    stop: Optional[Union[str, List[Union[str, int]]]] = Field(
        default_factory=lambda: get_default_sampler_value("stop", []),
        validation_alias=AliasChoices("stop", "stop_sequence"),
        description="Aliases: stop_sequence",
    )
    banned_strings: Optional[Union[str, List[str]]] = Field(
        default_factory=lambda: get_default_sampler_value("banned_strings", [])
    )
    banned_tokens: Optional[Union[List[int], str]] = Field(
        default_factory=lambda: get_default_sampler_value("banned_tokens", []),
        validation_alias=AliasChoices("banned_tokens", "custom_token_bans"),
        description="Aliases: custom_token_bans",
        examples=[[128, 330]],
    )
    allowed_tokens: Optional[Union[List[int], str]] = Field(
        default_factory=lambda: get_default_sampler_value("allowed_tokens", []),
        validation_alias=AliasChoices("allowed_tokens", "allowed_token_ids"),
        description="Aliases: allowed_token_ids",
        examples=[[128, 330]],
    )
    token_healing: Optional[bool] = Field(
        default_factory=lambda: get_default_sampler_value("token_healing", False)
    )
    temperature: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("temperature", 1.0),
        examples=[1.0],
        ge=0,
        le=10,
    )
    temperature_last: Optional[bool] = Field(
        default_factory=lambda: get_default_sampler_value("temperature_last", False),
    )
    smoothing_factor: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("smoothing_factor", 0.0),
        ge=0,
    )
    top_k: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("top_k", 0),
        ge=-1,
    )
    top_p: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("top_p", 1.0),
        ge=0,
        le=1,
        examples=[1.0],
    )
    top_a: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("top_a", 0.0)
    )
    min_p: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("min_p", 0.0)
    )
    tfs: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("tfs", 1.0),
        examples=[1.0],
    )
    typical: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("typical", 1.0),
        validation_alias=AliasChoices("typical", "typical_p"),
        description="Aliases: typical_p",
        examples=[1.0],
        gt=0,
        le=1,
    )
    skew: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("skew", 0.0),
        examples=[0.0],
    )
    xtc_probability: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("xtc_probability", 0.0),
    )
    xtc_threshold: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("xtc_threshold", 0.1)
    )
    frequency_penalty: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0),
        ge=0,
    )
    presence_penalty: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0),
        ge=0,
    )
    repetition_penalty: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
        validation_alias=AliasChoices("repetition_penalty", "rep_pen"),
        description="Aliases: rep_pen",
        examples=[1.0],
        gt=0,
    )
    penalty_range: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("penalty_range", -1),
        validation_alias=AliasChoices(
            "penalty_range",
            "repetition_range",
            "repetition_penalty_range",
            "rep_pen_range",
        ),
        description=(
            "Aliases: repetition_range, repetition_penalty_range, rep_pen_range"
        ),
    )
    repetition_decay: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
    )
    dry_multiplier: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("dry_multiplier", 0.0)
    )
    dry_base: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("dry_base", 0.0)
    )
    dry_allowed_length: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("dry_allowed_length", 0)
    )
    dry_range: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("dry_range", 0),
        validation_alias=AliasChoices("dry_range", "dry_penalty_last_n"),
        description="Aliases: dry_penalty_last_n",
    )
    dry_sequence_breakers: Optional[Union[str, List[str]]] = Field(
        default_factory=lambda: get_default_sampler_value("dry_sequence_breakers", [])
    )
    mirostat: Optional[bool] = False
    mirostat_mode: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
    )
    mirostat_tau: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5),
        examples=[1.5],
    )
    mirostat_eta: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3),
        examples=[0.3],
    )
    add_bos_token: Optional[bool] = Field(
        default_factory=lambda: get_default_sampler_value("add_bos_token", True)
    )
    ban_eos_token: Optional[bool] = Field(
        default_factory=lambda: get_default_sampler_value("ban_eos_token", False),
        validation_alias=AliasChoices("ban_eos_token", "ignore_eos"),
        description="Aliases: ignore_eos",
        examples=[False],
    )
    skip_special_tokens: Optional[bool] = Field(
        default_factory=lambda: get_default_sampler_value("skip_special_tokens", True),
        examples=[True],
    )
    logit_bias: Optional[Dict[int, float]] = Field(
        default_factory=lambda: get_default_sampler_value("logit_bias"),
        examples=[{"1": 10, "2": 50}],
    )
    negative_prompt: Optional[str] = Field(
        default_factory=lambda: get_default_sampler_value("negative_prompt")
    )
    json_schema: Optional[object] = Field(
        default_factory=lambda: get_default_sampler_value("json_schema"),
    )
    regex_pattern: Optional[str] = Field(
        default_factory=lambda: get_default_sampler_value("regex_pattern"),
    )
    grammar_string: Optional[str] = Field(
        default_factory=lambda: get_default_sampler_value("grammar_string"),
    )
    speculative_ngram: Optional[bool] = Field(
        default_factory=lambda: get_default_sampler_value("speculative_ngram"),
    )
    cfg_scale: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
        validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
        description="Aliases: guidance_scale",
        examples=[1.0],
    )
    max_temp: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("max_temp", 1.0),
        validation_alias=AliasChoices("max_temp", "dynatemp_high"),
        description="Aliases: dynatemp_high",
        examples=[1.0],
        ge=0,
    )
    min_temp: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("min_temp", 1.0),
        validation_alias=AliasChoices("min_temp", "dynatemp_low"),
        description="Aliases: dynatemp_low",
        examples=[1.0],
        ge=0,
    )
    temp_exponent: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("temp_exponent", 1.0),
        validation_alias=AliasChoices("temp_exponent", "dynatemp_exponent"),
        description="Aliases: dynatemp_exponent",
        examples=[1.0],
        ge=0,
    )
    @field_validator("top_k", mode="before")
    def convert_top_k(cls, v):
        """Convert a top_k of -1 (disabled) to 0."""

        if v == -1:
            logger.warning("Provided a top-k value of -1. Converting to 0 instead.")
            return 0
        return v

    @field_validator("stop", "banned_strings", mode="before")
    def convert_str_to_list(cls, v):
        """Convert a single string to a list of strings."""

        if isinstance(v, str):
            return [v]
        return v

    @field_validator("banned_tokens", "allowed_tokens", mode="before")
    def convert_tokens_to_int_list(cls, v):
        """Convert a comma-separated string of numbers to a list of integers."""

        if isinstance(v, str):
            return [int(x) for x in v.replace(" ", "").split(",") if x.isdigit()]
        return v

    @field_validator("dry_sequence_breakers", mode="before")
    def parse_json_if_needed(cls, v):
        """Parse a dry_sequence_breakers string into a JSON array."""

        if isinstance(v, str) and not v.startswith("["):
            v = f"[{v}]"
        try:
            return json.loads(v) if isinstance(v, str) else v
        except Exception:
            logger.warning(
                "Could not parse DRY sequence breakers. Using an empty array."
            )
            return []  # Return an empty list if parsing fails

    @field_validator("mirostat_mode", mode="before")
    def convert_mirostat(cls, v, field_info):
        """Enable mirostat if mirostat_mode == 2."""

        if v == 2:
            field_info.data["mirostat"] = True
        return v
    @model_validator(mode="after")
    def after_validate(self):
        # FIXME: find a better way to register this
        # Maybe make a function to assign values to the
        # model if they don't exist post-creation
        apply_forced_sampler_overrides(self)

        # Raise ValueError so pydantic wraps it into a ValidationError;
        # pydantic_core's ValidationError can't be constructed from a string
        if self.min_temp and self.max_temp and self.min_temp > self.max_temp:
            raise ValueError("min temp cannot be greater than max temp")

        if self.min_tokens and self.max_tokens and self.min_tokens > self.max_tokens:
            raise ValueError("min tokens cannot be greater than max tokens")

        return self
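
# Illustrative sketch (not in the original source): how alias resolution and the
# coercion validators above behave, assuming no sampler overrides are loaded.
#
#   req = BaseSamplerRequest.model_validate(
#       {"max_completion_tokens": 150, "stop": "###", "banned_tokens": "128, 330"}
#   )
#   req.max_tokens     -> 150         (accepted via the max_completion_tokens alias)
#   req.stop           -> ["###"]     (single string coerced to a list)
#   req.banned_tokens  -> [128, 330]  (comma-separated string parsed to ints)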

class SamplerOverridesContainer(BaseModel):
    selected_preset: Optional[str] = None
    overrides: dict = {}


# Global for default overrides
overrides_container = SamplerOverridesContainer()

def overrides_from_dict(new_overrides: dict):
    """Wrapper function to update sampler overrides"""

    if isinstance(new_overrides, dict):
        overrides_container.overrides = filter_none_values(new_overrides)
    else:
        raise TypeError("New sampler overrides must be a dict!")
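
# Illustrative sketch (not in the original source): the expected shape of an
# overrides dict, inferred from the keys this module reads elsewhere
# ("override", "force", "additive"). The sampler values here are hypothetical.
#
#   overrides_from_dict(
#       {
#           "temperature": {"override": 0.7, "force": True},
#           "banned_strings": {"override": ["foo"], "additive": True},
#       }
#   )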

async def overrides_from_file(preset_name: str):
    """Fetches an override preset from a file"""

    preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
    if preset_path.exists():
        overrides_container.selected_preset = preset_path.stem

        async with aiofiles.open(preset_path, "r", encoding="utf8") as raw_preset:
            contents = await raw_preset.read()

            # Create a temporary YAML parser
            yaml = YAML(typ="safe")
            preset = yaml.load(contents)

            overrides_from_dict(preset)
            logger.info("Applied sampler overrides from file.")
    else:
        error_message = (
            f'Sampler override file named "{preset_name}" was not found. '
            "Make sure it's located in the sampler_overrides folder."
        )
        raise FileNotFoundError(error_message)
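
# Illustrative sketch (not in the original source): a hypothetical
# sampler_overrides/my_preset.yml that this loader would accept. The YAML
# mapping mirrors the dict shape shown above.
#
#   temperature:
#     override: 0.7
#     force: true
#   min_p:
#     override: 0.05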

def get_all_presets():
    """Fetches all sampler override presets from the overrides directory"""

    override_directory = pathlib.Path("sampler_overrides")
    preset_files = [file.stem for file in override_directory.glob("*.yml")]
    return preset_files

# TODO: Maybe move these into the class
# Classmethods aren't recognized in pydantic default_factories
def get_default_sampler_value(key, fallback=None):
    """Gets an overridden default sampler value"""

    default_value = unwrap(
        deepcopy(overrides_container.overrides.get(key, {}).get("override")),
        fallback,
    )

    return default_value
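
# Illustrative sketch (not in the original source): fallback behavior of the
# helper above, assuming overrides_from_dict({"top_k": {"override": 20}}) ran.
#
#   get_default_sampler_value("top_k")        -> 20    (overridden default)
#   get_default_sampler_value("min_p", 0.0)   -> 0.0   (no override; fallback)
#   get_default_sampler_value("json_schema")  -> None  (no override, no fallback)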

def apply_forced_sampler_overrides(params: BaseSamplerRequest):
    """Forcefully applies overrides if specified by the user"""

    for var, value in overrides_container.overrides.items():
        override = deepcopy(value.get("override"))
        original_value = getattr(params, var, None)

        # Force takes precedence over additive
        # Additive only works on lists and doesn't remove duplicates
        if override:
            if unwrap(value.get("force"), False):
                setattr(params, var, override)
            elif (
                unwrap(value.get("additive"), False)
                and isinstance(override, list)
                and isinstance(original_value, list)
            ):
                setattr(params, var, override + original_value)
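
if __name__ == "__main__":
    # Illustrative sketch (not in the original source): shows how a forced
    # override replaces a user-supplied value while an additive override
    # prepends to a list. The override values here are hypothetical.
    overrides_from_dict(
        {
            "temperature": {"override": 0.7, "force": True},
            "banned_strings": {"override": ["foo"], "additive": True},
        }
    )
    req = BaseSamplerRequest(temperature=1.2, banned_strings=["bar"])
    print(req.temperature)  # 0.7 -- the forced override wins
    print(req.banned_strings)  # ['foo', 'bar'] -- additive prepends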