Tree: Unify sampler parameters and add override support

Unify API sampler params into a superclass, which should make them
easier to manage and lets request classes inherit generic functions
from it.

Not all frontends expose every sampling parameter because of their
ties to OAI (which handles sampling itself, apart from a few
sliders).

Add the ability for users to customize fallback parameters
server-side.

In addition, parameters can be forced to a specific value
server-side, in case a frontend automatically sets sampler values in
the background that the user doesn't want.
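
A rough sketch of the intended semantics, using names introduced by
this commit (the override values below are made up for illustration):

    from common.sampling import SamplerParams, set_overrides_from_dict

    # "override" replaces a sampler's default when the client omits it;
    # "force" additionally replaces whatever the client sent.
    set_overrides_from_dict({
        "max_tokens": {"override": 200, "force": False},
        "temperature": {"override": 0.7, "force": True},
    })

    params = SamplerParams(temperature=1.2)  # client asked for 1.2
    gen_params = params.to_gen_params()
    # gen_params["max_tokens"] == 200   (fallback: client sent nothing)
    # gen_params["temperature"] == 0.7  (forced: client value is ignored)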

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-01-21 23:34:44 -05:00 committed by Brian Dashore
parent 78f920eeda
commit 6c30f24c83
7 changed files with 337 additions and 86 deletions

.gitignore

@@ -192,3 +192,7 @@ templates/*
!templates/place_your_templates_here.txt
!templates/alpaca.jinja
!templates/chatml.jinja
# Sampler overrides folder
sampler_overrides/*
!sampler_overrides/sample_preset.yml


@@ -1,6 +1,8 @@
""" Common types for OAI. """
from pydantic import BaseModel, Field, AliasChoices
from typing import List, Dict, Optional, Union
from pydantic import BaseModel, Field
from typing import List, Dict, Optional
from common.sampling import SamplerParams
class LogProbs(BaseModel):
@@ -20,7 +22,7 @@ class UsageStats(BaseModel):
total_tokens: int
class CommonCompletionRequest(BaseModel):
class CommonCompletionRequest(SamplerParams):
"""Represents a common completion request."""
# Model information
@@ -47,87 +49,5 @@ class CommonCompletionRequest(BaseModel):
description="Not parsed. Only used for OAI compliance.", default=None
)
# Generation info
# seed: Optional[int] = -1
# Generation info (remainder is in SamplerParams superclass)
stream: Optional[bool] = False
stop: Optional[Union[str, List[str]]] = []
# Default to 150 as 16 makes no sense as a default
max_tokens: Optional[int] = 150
# Sampling params
token_healing: Optional[bool] = False
temperature: Optional[float] = 1.0
temperature_last: Optional[bool] = False
top_k: Optional[int] = 0
top_p: Optional[float] = 1.0
top_a: Optional[float] = 0.0
min_p: Optional[float] = 0.0
tfs: Optional[float] = 1.0
frequency_penalty: Optional[float] = 0.0
presence_penalty: Optional[float] = 0.0
repetition_penalty: Optional[float] = 1.0
repetition_decay: Optional[int] = 0
mirostat_mode: Optional[int] = 0
mirostat_tau: Optional[float] = 1.5
mirostat_eta: Optional[float] = 0.1
add_bos_token: Optional[bool] = True
ban_eos_token: Optional[bool] = False
logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]])
negative_prompt: Optional[str] = None
# Aliased variables
typical: Optional[float] = Field(
default=1.0,
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
)
penalty_range: Optional[int] = Field(
default=-1,
validation_alias=AliasChoices(
"penalty_range",
"repetition_range",
"repetition_penalty_range",
),
description="Aliases: repetition_range, repetition_penalty_range",
)
cfg_scale: Optional[float] = Field(
default=1.0,
validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
description="Aliases: guidance_scale",
)
def to_gen_params(self):
"""Converts to internal generation parameters."""
# Convert stop to an array of strings
if isinstance(self.stop, str):
self.stop = [self.stop]
return {
"stop": self.stop,
"max_tokens": self.max_tokens,
"add_bos_token": self.add_bos_token,
"ban_eos_token": self.ban_eos_token,
"token_healing": self.token_healing,
"logit_bias": self.logit_bias,
"temperature": self.temperature,
"temperature_last": self.temperature_last,
"top_k": self.top_k,
"top_p": self.top_p,
"top_a": self.top_a,
"typical": self.typical,
"min_p": self.min_p,
"tfs": self.tfs,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"repetition_penalty": self.repetition_penalty,
"penalty_range": self.penalty_range,
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,
"mirostat_eta": self.mirostat_eta,
"cfg_scale": self.cfg_scale,
"negative_prompt": self.negative_prompt,
}
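
With these fields and to_gen_params() moved into the SamplerParams
superclass (see common/sampling.py below), CommonCompletionRequest
inherits the alias handling rather than duplicating it. A minimal
sketch with made-up values:

    from common.sampling import SamplerParams

    params = SamplerParams.model_validate({
        "typical_p": 0.95,         # accepted via the "typical" alias
        "guidance_scale": 1.3,     # accepted via the "cfg_scale" alias
        "repetition_range": 1024,  # accepted via the "penalty_range" alias
    })
    gen_params = params.to_gen_params()
    # gen_params["typical"] == 0.95
    # gen_params["cfg_scale"] == 1.3
    # gen_params["penalty_range"] == 1024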


@@ -56,6 +56,11 @@ def override_config_from_args(args: dict):
}
def get_sampling_config():
"""Returns the sampling parameter config from the global config"""
return unwrap(GLOBAL_CONFIG.get("sampling"), {})
def get_model_config():
"""Returns the model config from the global config"""
return unwrap(GLOBAL_CONFIG.get("model"), {})
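
For context, a sketch of how the new accessor is meant to be read;
the config contents shown here are hypothetical:

    # Assuming config.yml, already loaded into GLOBAL_CONFIG, contains:
    #   sampling:
    #     override_preset: my_preset
    from common.config import get_sampling_config

    sampling_config = get_sampling_config()
    preset_name = sampling_config.get("override_preset")  # "my_preset", or None if unset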

common/sampling.py (new file)

@@ -0,0 +1,212 @@
"""Common functions for sampling parameters"""
import pathlib
from typing import Dict, List, Optional, Union
from pydantic import AliasChoices, BaseModel, Field
import yaml
from common.logger import init_logger
from common.utils import unwrap
logger = init_logger(__name__)
# Common class for sampler params
class SamplerParams(BaseModel):
"""Common class for sampler params that are used in APIs"""
max_tokens: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("max_tokens", 150)
)
stop: Optional[Union[str, List[str]]] = Field(
default_factory=lambda: get_default_sampler_value("stop", [])
)
token_healing: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("token_healing", False)
)
temperature: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("temperature", 1.0)
)
temperature_last: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("temperature_last", False)
)
top_k: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("top_k", 0)
)
top_p: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("top_p", 1.0)
)
top_a: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("top_a", 0.0)
)
min_p: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("min_p", 0.0)
)
tfs: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("tfs", 0.0)
)
frequency_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0)
)
presence_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0)
)
repetition_penalty: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0)
)
repetition_decay: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("repetition_decay", 0)
)
mirostat_mode: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("mirostat_mode", 0)
)
mirostat_tau: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5)
)
mirostat_eta: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3)
)
add_bos_token: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("add_bos_token", True)
)
ban_eos_token: Optional[bool] = Field(
default_factory=lambda: get_default_sampler_value("ban_eos_token", False)
)
logit_bias: Optional[Dict[int, float]] = Field(
default_factory=lambda: get_default_sampler_value("logit_bias"),
examples=[[{"1": 10}]],
)
negative_prompt: Optional[str] = Field(
default_factory=lambda: get_default_sampler_value("negative_prompt")
)
# Aliased variables
typical: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("typical", 1.0),
validation_alias=AliasChoices("typical", "typical_p"),
description="Aliases: typical_p",
)
penalty_range: Optional[int] = Field(
default_factory=lambda: get_default_sampler_value("penalty_range", -1),
validation_alias=AliasChoices(
"penalty_range",
"repetition_range",
"repetition_penalty_range",
),
description="Aliases: repetition_range, repetition_penalty_range",
)
cfg_scale: Optional[float] = Field(
default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0),
validation_alias=AliasChoices("cfg_scale", "guidance_scale"),
description="Aliases: guidance_scale",
)
def to_gen_params(self):
"""Converts samplers to internal generation params"""
# Add forced overrides if present
apply_forced_sampler_overrides(self)
# Convert stop to an array of strings
if isinstance(self.stop, str):
self.stop = [self.stop]
return {
"stop": self.stop,
"max_tokens": self.max_tokens,
"add_bos_token": self.add_bos_token,
"ban_eos_token": self.ban_eos_token,
"token_healing": self.token_healing,
"logit_bias": self.logit_bias,
"temperature": self.temperature,
"temperature_last": self.temperature_last,
"top_k": self.top_k,
"top_p": self.top_p,
"top_a": self.top_a,
"typical": self.typical,
"min_p": self.min_p,
"tfs": self.tfs,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"repetition_penalty": self.repetition_penalty,
"penalty_range": self.penalty_range,
"repetition_decay": self.repetition_decay,
"mirostat": self.mirostat_mode == 2,
"mirostat_tau": self.mirostat_tau,
"mirostat_eta": self.mirostat_eta,
"cfg_scale": self.cfg_scale,
"negative_prompt": self.negative_prompt,
}
# Global for default overrides
DEFAULT_OVERRIDES = {}
def set_overrides_from_dict(new_overrides: dict):
"""Wrapper function to update sampler overrides"""
global DEFAULT_OVERRIDES
if isinstance(new_overrides, dict):
DEFAULT_OVERRIDES = new_overrides
else:
raise TypeError("new sampler overrides must be a dict!")
def get_overrides_from_file(preset_name: str):
"""Fetches an override preset from a file"""
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
if preset_path.exists():
with open(preset_path, "r", encoding="utf8") as raw_preset:
preset = yaml.safe_load(raw_preset)
set_overrides_from_dict(preset)
logger.info("Applied sampler overrides from file.")
else:
logger.warn(
f"Sampler override file named \"{preset_name}\" was not found. "
+ "Make sure it's located in the sampler_overrides folder."
)
# TODO: Maybe move these into the class
# Classmethods aren't recognized in pydantic default_factories
def get_default_sampler_value(key, fallback=None):
"""Gets an overridden default sampler value"""
return unwrap(DEFAULT_OVERRIDES.get(key, {}).get("override"), fallback)
def apply_forced_sampler_overrides(params: SamplerParams):
"""Forcefully applies overrides if specified by the user"""
for var, value in DEFAULT_OVERRIDES.items():
override = value.get("override")
force = unwrap(value.get("force"), False)
if force and override:
setattr(params, var, override)
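
A minimal usage sketch of the file-based path (the preset name and
values below are hypothetical):

    from common.sampling import (
        SamplerParams,
        get_default_sampler_value,
        get_overrides_from_file,
    )

    # Hypothetical sampler_overrides/my_preset.yml:
    #   temperature:
    #     override: 0.8
    #     force: false
    get_overrides_from_file("my_preset")

    # The override becomes the server-side fallback...
    get_default_sampler_value("temperature", 1.0)  # -> 0.8

    # ...so requests that omit the field pick it up at construction time,
    # while an explicit client value still wins because force is false.
    SamplerParams().temperature                 # -> 0.8
    SamplerParams(temperature=0.5).temperature  # -> 0.5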


@@ -27,6 +27,14 @@ logging:
# Enable generation parameter logging (default: False)
generation_params: False
# Options for sampling
sampling:
# Override preset name. Find this in the sampler_overrides folder (default: None)
# This overrides default fallbacks for sampler values that are passed to the API
# Server-side overrides are NOT needed by default
# WARNING: Using this can result in a generation speed penalty
#override_preset:
# Options for model overrides and loading
model:
# Overrides the directory to look for models (default: models)


@@ -16,6 +16,7 @@ from backends.exllamav2.model import ExllamaV2Container
from common.args import convert_args_to_dict, init_argparser
from common.auth import check_admin_key, check_api_key, load_auth_keys
from common.config import (
get_sampling_config,
override_config_from_args,
read_config_from_file,
get_gen_logging_config,
@@ -25,6 +26,7 @@ from common.config import (
get_network_config,
)
from common.generators import call_with_semaphore, generate_with_semaphore
from common.sampling import get_overrides_from_file
from common.templating import get_all_templates, get_prompt_from_template
from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap
from common.logger import init_logger
@@ -522,6 +524,12 @@ def entrypoint(args: Optional[dict] = None):
gen_logging.broadcast_status()
# Set sampler parameter overrides if provided
sampling_config = get_sampling_config()
sampling_override_preset = sampling_config.get("override_preset")
if sampling_override_preset:
get_overrides_from_file(sampling_override_preset)
# If an initial model name is specified, create a container
# and load the model
model_config = get_model_config()


@@ -0,0 +1,94 @@
# Sample YAML file for override presets.
# Each block corresponds to a sampler fallback override. Remove ones that you don't need.
# "force" always overrides the sampler to the specified value.
# For example, a top_p override of 0.5 with force: true will make every API request use a top_p of 0.5
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.
# TODO: Improve documentation for each field
# MARK: Misc generation parameters
max_tokens:
override: 150
force: false
stop:
override: []
force: false
token_healing:
override: false
force: false
# MARK: Temperature
temperature:
override: 1.0
force: false
temperature_last:
override: false
force: false
# MARK: Alphabet soup
top_k:
override: 0
force: false
top_p:
override: 1.0
force: false
top_a:
override: 0.0
force: false
min_p:
override: 0.0
force: false
tfs:
override: 0.0
force: false
typical:
override: 1.0
force: false
# MARK: Penalty settings
frequency_penalty:
override: 0.0
force: false
presence_penalty:
override: 0.0
force: false
repetition_penalty:
override: 1.0
force: false
repetition_decay:
override: 0
force: false
penalty_range:
override: -1
force: false
# MARK: Mirostat
mirostat_mode:
override: 0
force: false
mirostat_tau:
override: 1.5
force: false
mirostat_eta:
override: 0.3
force: false
# MARK: Token options
add_bos_token:
override: true
force: false
ban_eos_token:
override: false
force: false
logit_bias:
override:
force: false
# MARK: CFG scale
cfg_scale:
override: 1.0
force: false
negative_prompt:
override:
force: false