From 6c30f24c832b96fe93f0271dfe3e1815fb29d237 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 21 Jan 2024 23:34:44 -0500 Subject: [PATCH] Tree: Unify sampler parameters and add override support Unify API sampler params into a superclass which should make them easier to manage and inherit generic functions from. Not all frontends expose all sampling parameters due to connections with OAI (that handles sampling themselves with the exception of a few sliders). Add the ability for the user to customize fallback parameters from server-side. In addition, parameters can be forced to a certain value server-side in case the repo automatically sets other sampler values in the background that the user doesn't want. Signed-off-by: kingbri --- .gitignore | 4 + OAI/types/common.py | 92 +----------- common/config.py | 5 + common/sampling.py | 212 ++++++++++++++++++++++++++++ config_sample.yml | 8 ++ main.py | 8 ++ sampler_overrides/sample_preset.yml | 94 ++++++++++++ 7 files changed, 337 insertions(+), 86 deletions(-) create mode 100644 common/sampling.py create mode 100644 sampler_overrides/sample_preset.yml diff --git a/.gitignore b/.gitignore index 8dde2c2..f77b7f9 100644 --- a/.gitignore +++ b/.gitignore @@ -192,3 +192,7 @@ templates/* !templates/place_your_templates_here.txt !templates/alpaca.jinja !templates/chatml.jinja + +# Sampler overrides folder +sampler_overrides/* +!sampler_overrides/sample_preset.yml diff --git a/OAI/types/common.py b/OAI/types/common.py index df54349..e90919e 100644 --- a/OAI/types/common.py +++ b/OAI/types/common.py @@ -1,6 +1,8 @@ """ Common types for OAI. """ -from pydantic import BaseModel, Field, AliasChoices -from typing import List, Dict, Optional, Union +from pydantic import BaseModel, Field +from typing import List, Dict, Optional + +from common.sampling import SamplerParams class LogProbs(BaseModel): @@ -20,7 +22,7 @@ class UsageStats(BaseModel): total_tokens: int -class CommonCompletionRequest(BaseModel): +class CommonCompletionRequest(SamplerParams): """Represents a common completion request.""" # Model information @@ -47,87 +49,5 @@ class CommonCompletionRequest(BaseModel): description="Not parsed. Only used for OAI compliance.", default=None ) - # Generation info - # seed: Optional[int] = -1 + # Generation info (remainder is in SamplerParams superclass) stream: Optional[bool] = False - stop: Optional[Union[str, List[str]]] = [] - - # Default to 150 as 16 makes no sense as a default - max_tokens: Optional[int] = 150 - - # Sampling params - token_healing: Optional[bool] = False - temperature: Optional[float] = 1.0 - temperature_last: Optional[bool] = False - top_k: Optional[int] = 0 - top_p: Optional[float] = 1.0 - top_a: Optional[float] = 0.0 - min_p: Optional[float] = 0.0 - tfs: Optional[float] = 1.0 - frequency_penalty: Optional[float] = 0.0 - presence_penalty: Optional[float] = 0.0 - repetition_penalty: Optional[float] = 1.0 - repetition_decay: Optional[int] = 0 - mirostat_mode: Optional[int] = 0 - mirostat_tau: Optional[float] = 1.5 - mirostat_eta: Optional[float] = 0.1 - add_bos_token: Optional[bool] = True - ban_eos_token: Optional[bool] = False - logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]]) - negative_prompt: Optional[str] = None - - # Aliased variables - typical: Optional[float] = Field( - default=1.0, - validation_alias=AliasChoices("typical", "typical_p"), - description="Aliases: typical_p", - ) - - penalty_range: Optional[int] = Field( - default=-1, - validation_alias=AliasChoices( - "penalty_range", - "repetition_range", - "repetition_penalty_range", - ), - description="Aliases: repetition_range, repetition_penalty_range", - ) - - cfg_scale: Optional[float] = Field( - default=1.0, - validation_alias=AliasChoices("cfg_scale", "guidance_scale"), - description="Aliases: guidance_scale", - ) - - def to_gen_params(self): - """Converts to internal generation parameters.""" - # Convert stop to an array of strings - if isinstance(self.stop, str): - self.stop = [self.stop] - - return { - "stop": self.stop, - "max_tokens": self.max_tokens, - "add_bos_token": self.add_bos_token, - "ban_eos_token": self.ban_eos_token, - "token_healing": self.token_healing, - "logit_bias": self.logit_bias, - "temperature": self.temperature, - "temperature_last": self.temperature_last, - "top_k": self.top_k, - "top_p": self.top_p, - "top_a": self.top_a, - "typical": self.typical, - "min_p": self.min_p, - "tfs": self.tfs, - "frequency_penalty": self.frequency_penalty, - "presence_penalty": self.presence_penalty, - "repetition_penalty": self.repetition_penalty, - "penalty_range": self.penalty_range, - "repetition_decay": self.repetition_decay, - "mirostat": self.mirostat_mode == 2, - "mirostat_tau": self.mirostat_tau, - "mirostat_eta": self.mirostat_eta, - "cfg_scale": self.cfg_scale, - "negative_prompt": self.negative_prompt, - } diff --git a/common/config.py b/common/config.py index e46be62..9a4b7b1 100644 --- a/common/config.py +++ b/common/config.py @@ -56,6 +56,11 @@ def override_config_from_args(args: dict): } +def get_sampling_config(): + """Returns the sampling parameter config from the global config""" + return unwrap(GLOBAL_CONFIG.get("sampling"), {}) + + def get_model_config(): """Returns the model config from the global config""" return unwrap(GLOBAL_CONFIG.get("model"), {}) diff --git a/common/sampling.py b/common/sampling.py new file mode 100644 index 0000000..01d11f1 --- /dev/null +++ b/common/sampling.py @@ -0,0 +1,212 @@ +"""Common functions for sampling parameters""" + +import pathlib +from typing import Dict, List, Optional, Union +from pydantic import AliasChoices, BaseModel, Field +import yaml + +from common.logger import init_logger +from common.utils import unwrap + + +logger = init_logger(__name__) + + +# Common class for sampler params +class SamplerParams(BaseModel): + """Common class for sampler params that are used in APIs""" + + max_tokens: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("max_tokens", 150) + ) + + stop: Optional[Union[str, List[str]]] = Field( + default_factory=lambda: get_default_sampler_value("stop", []) + ) + + token_healing: Optional[bool] = Field( + default_factory=lambda: get_default_sampler_value("token_healing", False) + ) + + temperature: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("temperature", 1.0) + ) + + temperature_last: Optional[bool] = Field( + default_factory=lambda: get_default_sampler_value("temperature_last", False) + ) + + top_k: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("top_k", 0) + ) + + top_p: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("top_p", 1.0) + ) + + top_a: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("top_a", 0.0) + ) + + min_p: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("min_p", 0.0) + ) + + tfs: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("tfs", 0.0) + ) + + frequency_penalty: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0) + ) + + presence_penalty: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0) + ) + + repetition_penalty: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0) + ) + + repetition_decay: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("repetition_decay", 0) + ) + + mirostat_mode: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("mirostat_mode", 0) + ) + + mirostat_tau: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5) + ) + + mirostat_eta: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3) + ) + + add_bos_token: Optional[bool] = Field( + default_factory=lambda: get_default_sampler_value("add_bos_token", True) + ) + + ban_eos_token: Optional[bool] = Field( + default_factory=lambda: get_default_sampler_value("ban_eos_token", False) + ) + + logit_bias: Optional[Dict[int, float]] = Field( + default_factory=lambda: get_default_sampler_value("logit_bias"), + examples=[[{"1": 10}]], + ) + + negative_prompt: Optional[str] = Field( + default_factory=lambda: get_default_sampler_value("negative_prompt") + ) + + # Aliased variables + typical: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("typical", 1.0), + validation_alias=AliasChoices("typical", "typical_p"), + description="Aliases: typical_p", + ) + + penalty_range: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("penalty_range", -1), + validation_alias=AliasChoices( + "penalty_range", + "repetition_range", + "repetition_penalty_range", + ), + description="Aliases: repetition_range, repetition_penalty_range", + ) + + cfg_scale: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0), + validation_alias=AliasChoices("cfg_scale", "guidance_scale"), + description="Aliases: guidance_scale", + ) + + def to_gen_params(self): + """Converts samplers to internal generation params""" + + # Add forced overrides if present + apply_forced_sampler_overrides(self) + + # Convert stop to an array of strings + if isinstance(self.stop, str): + self.stop = [self.stop] + + return { + "stop": self.stop, + "max_tokens": self.max_tokens, + "add_bos_token": self.add_bos_token, + "ban_eos_token": self.ban_eos_token, + "token_healing": self.token_healing, + "logit_bias": self.logit_bias, + "temperature": self.temperature, + "temperature_last": self.temperature_last, + "top_k": self.top_k, + "top_p": self.top_p, + "top_a": self.top_a, + "typical": self.typical, + "min_p": self.min_p, + "tfs": self.tfs, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "repetition_penalty": self.repetition_penalty, + "penalty_range": self.penalty_range, + "repetition_decay": self.repetition_decay, + "mirostat": self.mirostat_mode == 2, + "mirostat_tau": self.mirostat_tau, + "mirostat_eta": self.mirostat_eta, + "cfg_scale": self.cfg_scale, + "negative_prompt": self.negative_prompt, + } + + +# Global for default overrides +DEFAULT_OVERRIDES = {} + + +def set_overrides_from_dict(new_overrides: dict): + """Wrapper function to update sampler overrides""" + + global DEFAULT_OVERRIDES + + if isinstance(new_overrides, dict): + DEFAULT_OVERRIDES = new_overrides + else: + raise TypeError("new sampler overrides must be a dict!") + + +def get_overrides_from_file(preset_name: str): + """Fetches an override preset from a file""" + + preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml") + if preset_path.exists(): + with open(preset_path, "r", encoding="utf8") as raw_preset: + preset = yaml.safe_load(raw_preset) + set_overrides_from_dict(preset) + + logger.info("Applied sampler overrides from file.") + else: + logger.warn( + f"Sampler override file named \"{preset_name}\" was not found. " + + "Make sure it's located in the sampler_overrides folder." + ) + + +# TODO: Maybe move these into the class +# Classmethods aren't recognized in pydantic default_factories +def get_default_sampler_value(key, fallback=None): + """Gets an overridden default sampler value""" + + return unwrap(DEFAULT_OVERRIDES.get(key, {}).get("override"), fallback) + + +def apply_forced_sampler_overrides(params: SamplerParams): + """Forcefully applies overrides if specified by the user""" + + for var, value in DEFAULT_OVERRIDES.items(): + override = value.get("override") + force = unwrap(value.get("force"), False) + if force and override: + setattr(params, var, override) diff --git a/config_sample.yml b/config_sample.yml index 7f88d94..89368ac 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -27,6 +27,14 @@ logging: # Enable generation parameter logging (default: False) generation_params: False +# Options for sampling +sampling: + # Override preset name. Find this in the sampler-overrides folder (default: None) + # This overrides default fallbacks for sampler values that are passed to the API + # Server-side overrides are NOT needed by default + # WARNING: Using this can result in a generation speed penalty + #override_preset: + # Options for model overrides and loading model: # Overrides the directory to look for models (default: models) diff --git a/main.py b/main.py index 0ba9650..218d9c0 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,7 @@ from backends.exllamav2.model import ExllamaV2Container from common.args import convert_args_to_dict, init_argparser from common.auth import check_admin_key, check_api_key, load_auth_keys from common.config import ( + get_sampling_config, override_config_from_args, read_config_from_file, get_gen_logging_config, @@ -25,6 +26,7 @@ from common.config import ( get_network_config, ) from common.generators import call_with_semaphore, generate_with_semaphore +from common.sampling import get_overrides_from_file from common.templating import get_all_templates, get_prompt_from_template from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap from common.logger import init_logger @@ -522,6 +524,12 @@ def entrypoint(args: Optional[dict] = None): gen_logging.broadcast_status() + # Set sampler parameter overrides if provided + sampling_config = get_sampling_config() + sampling_override_preset = sampling_config.get("override_preset") + if sampling_override_preset: + get_overrides_from_file(sampling_override_preset) + # If an initial model name is specified, create a container # and load the model model_config = get_model_config() diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml new file mode 100644 index 0000000..9c661a1 --- /dev/null +++ b/sampler_overrides/sample_preset.yml @@ -0,0 +1,94 @@ +# Sample YAML file for override presets. +# Each block corresponds to a sampler fallback override. Remove ones that you don't need. +# "force" always overrides the sampler to the specified value. +# For example, a top-p override of 1.5 with force = true will make every API request have a top_p value of 1.5 + +# You can use https://www.yamllint.com/ if you want to check your YAML formatting. + +# TODO: Improve documentation for each field + +# MARK: Misc generation parameters +max_tokens: + override: 150 + force: false +stop: + override: [] + force: false +token_healing: + override: false + force: false + +# MARK: Temperature +temperature: + override: 1.0 + force: false +temperature_last: + override: false + force: false + +# MARK: Alphabet soup +top_k: + override: 0 + force: false +top_p: + override: 1.0 + force: false +top_a: + override: 0.0 + force: false +min_p: + override: 0.0 + force: false +tfs: + override: 0.0 + force: false +typical: + override: 1.0 + force: false + +# MARK: Penalty settings +frequency_penalty: + override: 0.0 + force: false +presence_penalty: + override: 0.0 + force: false +repetition_penalty: + override: 1.0 + force: false +repetition_decay: + override: 0 + force: false +penalty_range: + override: -1 + force: false + +# MARK: Mirostat +mirostat_mode: + override: 0 + force: false +mirostat_tau: + override: 1.5 + force: false +mirostat_eta: + override: 0.3 + force: false + +# MARK: Token options +add_bos_token: + override: true + force: false +ban_eos_token: + override: false + force: false +logit_bias: + override: + force: false + +# MARK: CFG scale +cfg_scale: + override: 1.0 + force: false +negative_prompt: + override: + force: false