diff --git a/.gitignore b/.gitignore index 8dde2c2..f77b7f9 100644 --- a/.gitignore +++ b/.gitignore @@ -192,3 +192,7 @@ templates/* !templates/place_your_templates_here.txt !templates/alpaca.jinja !templates/chatml.jinja + +# Sampler overrides folder +sampler_overrides/* +!sampler_overrides/sample_preset.yml diff --git a/OAI/types/common.py b/OAI/types/common.py index df54349..e90919e 100644 --- a/OAI/types/common.py +++ b/OAI/types/common.py @@ -1,6 +1,8 @@ """ Common types for OAI. """ -from pydantic import BaseModel, Field, AliasChoices -from typing import List, Dict, Optional, Union +from pydantic import BaseModel, Field +from typing import List, Dict, Optional + +from common.sampling import SamplerParams class LogProbs(BaseModel): @@ -20,7 +22,7 @@ class UsageStats(BaseModel): total_tokens: int -class CommonCompletionRequest(BaseModel): +class CommonCompletionRequest(SamplerParams): """Represents a common completion request.""" # Model information @@ -47,87 +49,5 @@ class CommonCompletionRequest(BaseModel): description="Not parsed. 
Only used for OAI compliance.", default=None ) - # Generation info - # seed: Optional[int] = -1 + # Generation info (remainder is in SamplerParams superclass) stream: Optional[bool] = False - stop: Optional[Union[str, List[str]]] = [] - - # Default to 150 as 16 makes no sense as a default - max_tokens: Optional[int] = 150 - - # Sampling params - token_healing: Optional[bool] = False - temperature: Optional[float] = 1.0 - temperature_last: Optional[bool] = False - top_k: Optional[int] = 0 - top_p: Optional[float] = 1.0 - top_a: Optional[float] = 0.0 - min_p: Optional[float] = 0.0 - tfs: Optional[float] = 1.0 - frequency_penalty: Optional[float] = 0.0 - presence_penalty: Optional[float] = 0.0 - repetition_penalty: Optional[float] = 1.0 - repetition_decay: Optional[int] = 0 - mirostat_mode: Optional[int] = 0 - mirostat_tau: Optional[float] = 1.5 - mirostat_eta: Optional[float] = 0.1 - add_bos_token: Optional[bool] = True - ban_eos_token: Optional[bool] = False - logit_bias: Optional[Dict[int, float]] = Field(default=None, examples=[[{"1": 10}]]) - negative_prompt: Optional[str] = None - - # Aliased variables - typical: Optional[float] = Field( - default=1.0, - validation_alias=AliasChoices("typical", "typical_p"), - description="Aliases: typical_p", - ) - - penalty_range: Optional[int] = Field( - default=-1, - validation_alias=AliasChoices( - "penalty_range", - "repetition_range", - "repetition_penalty_range", - ), - description="Aliases: repetition_range, repetition_penalty_range", - ) - - cfg_scale: Optional[float] = Field( - default=1.0, - validation_alias=AliasChoices("cfg_scale", "guidance_scale"), - description="Aliases: guidance_scale", - ) - - def to_gen_params(self): - """Converts to internal generation parameters.""" - # Convert stop to an array of strings - if isinstance(self.stop, str): - self.stop = [self.stop] - - return { - "stop": self.stop, - "max_tokens": self.max_tokens, - "add_bos_token": self.add_bos_token, - "ban_eos_token": self.ban_eos_token, - 
"token_healing": self.token_healing, - "logit_bias": self.logit_bias, - "temperature": self.temperature, - "temperature_last": self.temperature_last, - "top_k": self.top_k, - "top_p": self.top_p, - "top_a": self.top_a, - "typical": self.typical, - "min_p": self.min_p, - "tfs": self.tfs, - "frequency_penalty": self.frequency_penalty, - "presence_penalty": self.presence_penalty, - "repetition_penalty": self.repetition_penalty, - "penalty_range": self.penalty_range, - "repetition_decay": self.repetition_decay, - "mirostat": self.mirostat_mode == 2, - "mirostat_tau": self.mirostat_tau, - "mirostat_eta": self.mirostat_eta, - "cfg_scale": self.cfg_scale, - "negative_prompt": self.negative_prompt, - } diff --git a/common/config.py b/common/config.py index e46be62..9a4b7b1 100644 --- a/common/config.py +++ b/common/config.py @@ -56,6 +56,11 @@ def override_config_from_args(args: dict): } +def get_sampling_config(): + """Returns the sampling parameter config from the global config""" + return unwrap(GLOBAL_CONFIG.get("sampling"), {}) + + def get_model_config(): """Returns the model config from the global config""" return unwrap(GLOBAL_CONFIG.get("model"), {}) diff --git a/common/sampling.py b/common/sampling.py new file mode 100644 index 0000000..01d11f1 --- /dev/null +++ b/common/sampling.py @@ -0,0 +1,212 @@ +"""Common functions for sampling parameters""" + +import pathlib +from typing import Dict, List, Optional, Union +from pydantic import AliasChoices, BaseModel, Field +import yaml + +from common.logger import init_logger +from common.utils import unwrap + + +logger = init_logger(__name__) + + +# Common class for sampler params +class SamplerParams(BaseModel): + """Common class for sampler params that are used in APIs""" + + max_tokens: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("max_tokens", 150) + ) + + stop: Optional[Union[str, List[str]]] = Field( + default_factory=lambda: get_default_sampler_value("stop", []) + ) + + token_healing: 
Optional[bool] = Field( + default_factory=lambda: get_default_sampler_value("token_healing", False) + ) + + temperature: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("temperature", 1.0) + ) + + temperature_last: Optional[bool] = Field( + default_factory=lambda: get_default_sampler_value("temperature_last", False) + ) + + top_k: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("top_k", 0) + ) + + top_p: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("top_p", 1.0) + ) + + top_a: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("top_a", 0.0) + ) + + min_p: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("min_p", 0.0) + ) + + tfs: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("tfs", 0.0) + ) + + frequency_penalty: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("frequency_penalty", 0.0) + ) + + presence_penalty: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("presence_penalty", 0.0) + ) + + repetition_penalty: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0) + ) + + repetition_decay: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("repetition_decay", 0) + ) + + mirostat_mode: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("mirostat_mode", 0) + ) + + mirostat_tau: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("mirostat_tau", 1.5) + ) + + mirostat_eta: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("mirostat_eta", 0.3) + ) + + add_bos_token: Optional[bool] = Field( + default_factory=lambda: get_default_sampler_value("add_bos_token", True) + ) + + ban_eos_token: Optional[bool] = Field( + default_factory=lambda: get_default_sampler_value("ban_eos_token", False) + ) + + 
logit_bias: Optional[Dict[int, float]] = Field( + default_factory=lambda: get_default_sampler_value("logit_bias"), + examples=[[{"1": 10}]], + ) + + negative_prompt: Optional[str] = Field( + default_factory=lambda: get_default_sampler_value("negative_prompt") + ) + + # Aliased variables + typical: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("typical", 1.0), + validation_alias=AliasChoices("typical", "typical_p"), + description="Aliases: typical_p", + ) + + penalty_range: Optional[int] = Field( + default_factory=lambda: get_default_sampler_value("penalty_range", -1), + validation_alias=AliasChoices( + "penalty_range", + "repetition_range", + "repetition_penalty_range", + ), + description="Aliases: repetition_range, repetition_penalty_range", + ) + + cfg_scale: Optional[float] = Field( + default_factory=lambda: get_default_sampler_value("cfg_scale", 1.0), + validation_alias=AliasChoices("cfg_scale", "guidance_scale"), + description="Aliases: guidance_scale", + ) + + def to_gen_params(self): + """Converts samplers to internal generation params""" + + # Add forced overrides if present + apply_forced_sampler_overrides(self) + + # Convert stop to an array of strings + if isinstance(self.stop, str): + self.stop = [self.stop] + + return { + "stop": self.stop, + "max_tokens": self.max_tokens, + "add_bos_token": self.add_bos_token, + "ban_eos_token": self.ban_eos_token, + "token_healing": self.token_healing, + "logit_bias": self.logit_bias, + "temperature": self.temperature, + "temperature_last": self.temperature_last, + "top_k": self.top_k, + "top_p": self.top_p, + "top_a": self.top_a, + "typical": self.typical, + "min_p": self.min_p, + "tfs": self.tfs, + "frequency_penalty": self.frequency_penalty, + "presence_penalty": self.presence_penalty, + "repetition_penalty": self.repetition_penalty, + "penalty_range": self.penalty_range, + "repetition_decay": self.repetition_decay, + "mirostat": self.mirostat_mode == 2, + "mirostat_tau": 
self.mirostat_tau, + "mirostat_eta": self.mirostat_eta, + "cfg_scale": self.cfg_scale, + "negative_prompt": self.negative_prompt, + } + + +# Global for default overrides +DEFAULT_OVERRIDES = {} + + +def set_overrides_from_dict(new_overrides: dict): + """Wrapper function to update sampler overrides""" + + global DEFAULT_OVERRIDES + + if isinstance(new_overrides, dict): + DEFAULT_OVERRIDES = new_overrides + else: + raise TypeError("new sampler overrides must be a dict!") + + +def get_overrides_from_file(preset_name: str): + """Fetches an override preset from a file""" + + preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml") + if preset_path.exists(): + with open(preset_path, "r", encoding="utf8") as raw_preset: + preset = yaml.safe_load(raw_preset) + set_overrides_from_dict(preset) + + logger.info("Applied sampler overrides from file.") + else: + logger.warn( + f"Sampler override file named \"{preset_name}\" was not found. " + + "Make sure it's located in the sampler_overrides folder." + ) + + +# TODO: Maybe move these into the class +# Classmethods aren't recognized in pydantic default_factories +def get_default_sampler_value(key, fallback=None): + """Gets an overridden default sampler value""" + + return unwrap(DEFAULT_OVERRIDES.get(key, {}).get("override"), fallback) + + +def apply_forced_sampler_overrides(params: SamplerParams): + """Forcefully applies overrides if specified by the user""" + + for var, value in DEFAULT_OVERRIDES.items(): + override = value.get("override") + force = unwrap(value.get("force"), False) + if force and override: + setattr(params, var, override) diff --git a/config_sample.yml b/config_sample.yml index 7f88d94..89368ac 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -27,6 +27,14 @@ logging: # Enable generation parameter logging (default: False) generation_params: False +# Options for sampling +sampling: + # Override preset name. 
Find this in the sampler_overrides folder (default: None) + # This overrides default fallbacks for sampler values that are passed to the API + # Server-side overrides are NOT needed by default + # WARNING: Using this can result in a generation speed penalty + #override_preset: + # Options for model overrides and loading model: # Overrides the directory to look for models (default: models) diff --git a/main.py b/main.py index 0ba9650..218d9c0 100644 --- a/main.py +++ b/main.py @@ -16,6 +16,7 @@ from backends.exllamav2.model import ExllamaV2Container from common.args import convert_args_to_dict, init_argparser from common.auth import check_admin_key, check_api_key, load_auth_keys from common.config import ( + get_sampling_config, override_config_from_args, read_config_from_file, get_gen_logging_config, @@ -25,6 +26,7 @@ from common.config import ( get_network_config, ) from common.generators import call_with_semaphore, generate_with_semaphore +from common.sampling import get_overrides_from_file from common.templating import get_all_templates, get_prompt_from_template from common.utils import get_generator_error, get_sse_packet, load_progress, unwrap from common.logger import init_logger @@ -522,6 +524,12 @@ def entrypoint(args: Optional[dict] = None): gen_logging.broadcast_status() + # Set sampler parameter overrides if provided + sampling_config = get_sampling_config() + sampling_override_preset = sampling_config.get("override_preset") + if sampling_override_preset: + get_overrides_from_file(sampling_override_preset) + # If an initial model name is specified, create a container # and load the model model_config = get_model_config() diff --git a/sampler_overrides/sample_preset.yml b/sampler_overrides/sample_preset.yml new file mode 100644 index 0000000..9c661a1 --- /dev/null +++ b/sampler_overrides/sample_preset.yml @@ -0,0 +1,94 @@ +# Sample YAML file for override presets. +# Each block corresponds to a sampler fallback override. Remove ones that you don't need. 
+# "force" always overrides the sampler to the specified value. +# For example, a top-p override of 1.5 with force = true will make every API request have a top_p value of 1.5 + +# You can use https://www.yamllint.com/ if you want to check your YAML formatting. + +# TODO: Improve documentation for each field + +# MARK: Misc generation parameters +max_tokens: + override: 150 + force: false +stop: + override: [] + force: false +token_healing: + override: false + force: false + +# MARK: Temperature +temperature: + override: 1.0 + force: false +temperature_last: + override: false + force: false + +# MARK: Alphabet soup +top_k: + override: 0 + force: false +top_p: + override: 1.0 + force: false +top_a: + override: 0.0 + force: false +min_p: + override: 0.0 + force: false +tfs: + override: 0.0 + force: false +typical: + override: 1.0 + force: false + +# MARK: Penalty settings +frequency_penalty: + override: 0.0 + force: false +presence_penalty: + override: 0.0 + force: false +repetition_penalty: + override: 1.0 + force: false +repetition_decay: + override: 0 + force: false +penalty_range: + override: -1 + force: false + +# MARK: Mirostat +mirostat_mode: + override: 0 + force: false +mirostat_tau: + override: 1.5 + force: false +mirostat_eta: + override: 0.3 + force: false + +# MARK: Token options +add_bos_token: + override: true + force: false +ban_eos_token: + override: false + force: false +logit_bias: + override: + force: false + +# MARK: CFG scale +cfg_scale: + override: 1.0 + force: false +negative_prompt: + override: + force: false