Sampling: Cleanup and update
Cleanup how overrides are handled, class naming, and adopt exllamav2's model class to enforce latest stable version methods rather than adding multiple backwards compatibility checks. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
2ea063cea9
commit
b827bcbb44
4 changed files with 34 additions and 86 deletions
|
|
@ -2,7 +2,7 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from common.sampling import CommonSamplerRequest
|
||||
from common.sampling import BaseSamplerRequest
|
||||
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
|
|
@ -49,5 +49,5 @@ class CommonCompletionRequest(BaseSamplerRequest):
|
|||
description="Not parsed. Only used for OAI compliance.", default=None
|
||||
)
|
||||
|
||||
# Generation info (remainder is in CommonSamplerRequest superclass)
|
||||
# Generation info (remainder is in BaseSamplerRequest superclass)
|
||||
stream: Optional[bool] = False
|
||||
|
|
|
|||
|
|
@ -468,56 +468,9 @@ class ExllamaV2Container:
|
|||
}
|
||||
|
||||
def check_unsupported_settings(self, **kwargs):
|
||||
# Warn of unsupported settings if the setting is enabled
|
||||
if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(
|
||||
ExLlamaV2Sampler.Settings, "mirostat"
|
||||
):
|
||||
logger.warning(
|
||||
"Mirostat sampling is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
"""Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
|
||||
|
||||
if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(
|
||||
ExLlamaV2Sampler.Settings, "min_p"
|
||||
):
|
||||
logger.warning(
|
||||
"Min-P sampling is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(
|
||||
ExLlamaV2Sampler.Settings, "tfs"
|
||||
):
|
||||
logger.warning(
|
||||
"Tail-free sampling (TFS) is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(
|
||||
ExLlamaV2Sampler.Settings, "temperature_last"
|
||||
):
|
||||
logger.warning(
|
||||
"Temperature last is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
if (unwrap(kwargs.get("top_a"), False)) and not hasattr(
|
||||
ExLlamaV2Sampler.Settings, "top_a"
|
||||
):
|
||||
logger.warning(
|
||||
"Top-A is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
if (unwrap(kwargs.get("presence_penalty"), 0.0)) != 0.0 and not hasattr(
|
||||
ExLlamaV2Sampler.Settings, "token_presence_penalty"
|
||||
):
|
||||
logger.warning(
|
||||
"Presence penalty is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
if (unwrap(kwargs.get("max_temp"), 0.0)) > 0.0 and not hasattr(
|
||||
if kwargs.get("max_temp") > 0.0 and not hasattr(
|
||||
ExLlamaV2Sampler.Settings, "max_temp"
|
||||
):
|
||||
logger.warning(
|
||||
|
|
@ -597,7 +550,7 @@ class ExllamaV2Container:
|
|||
# Sampler settings
|
||||
gen_settings = ExLlamaV2Sampler.Settings()
|
||||
|
||||
# TODO: Migrate settings validation to different function
|
||||
# Check unsupported settings for dev wheels
|
||||
self.check_unsupported_settings(**kwargs)
|
||||
|
||||
# Apply settings
|
||||
|
|
@ -646,44 +599,31 @@ class ExllamaV2Container:
|
|||
else:
|
||||
logger.warn(
|
||||
"CFG is currently disabled. "
|
||||
+ "Please reload your model with use_cfg = True.",
|
||||
"Please reload your model with use_cfg = True.",
|
||||
)
|
||||
|
||||
gen_settings.token_presence_penalty = unwrap(
|
||||
kwargs.get("presence_penalty"), 0.0
|
||||
)
|
||||
gen_settings.token_repetition_penalty = unwrap(
|
||||
kwargs.get("repetition_penalty"), 1.0
|
||||
)
|
||||
gen_settings.token_frequency_penalty = unwrap(
|
||||
kwargs.get("frequency_penalty"), 0.0
|
||||
)
|
||||
gen_settings.token_presence_penalty = unwrap(
|
||||
kwargs.get("presence_penalty"), 0.0
|
||||
)
|
||||
|
||||
# Applies for all penalties despite being called token_repetition_range
|
||||
gen_settings.token_repetition_range = unwrap(
|
||||
kwargs.get("penalty_range"), self.config.max_seq_len
|
||||
)
|
||||
auto_scale_penalty_range = False
|
||||
|
||||
frequency_penalty = unwrap(kwargs.get("frequency_penalty"), 0.0)
|
||||
if hasattr(gen_settings, "token_frequency_penalty"):
|
||||
gen_settings.token_frequency_penalty = frequency_penalty
|
||||
|
||||
# Dynamically scale penalty range to output tokens
|
||||
# Only do this if freq/pres pen is enabled
|
||||
# and the repetition range is -1
|
||||
auto_scale_penalty_range = (
|
||||
gen_settings.token_frequency_penalty != 0
|
||||
or gen_settings.token_presence_penalty != 0
|
||||
) and gen_settings.token_repetition_range == -1
|
||||
elif frequency_penalty != 0.0:
|
||||
logger.warning(
|
||||
"Frequency penalty is not supported by the currently "
|
||||
"installed ExLlamaV2 version."
|
||||
)
|
||||
|
||||
# Override the repetition penalty value if it isn't set already
|
||||
# if the user is on an older exl2 version
|
||||
if unwrap(gen_settings.token_repetition_penalty, 1.0) == 1.0:
|
||||
gen_settings.token_repetition_penalty = frequency_penalty
|
||||
logger.warning("Setting this value to repetition penalty instead.")
|
||||
# Dynamically scale penalty range to output tokens
|
||||
# Only do this if freq/pres pen is enabled
|
||||
# and the repetition range is -1
|
||||
auto_scale_penalty_range = (
|
||||
gen_settings.token_frequency_penalty != 0
|
||||
or gen_settings.token_presence_penalty != 0
|
||||
) and gen_settings.token_repetition_range == -1
|
||||
|
||||
# Always make sure the fallback is 0 if range < 0
|
||||
# It's technically fine to use -1, but this just validates the passed
|
||||
|
|
@ -820,7 +760,7 @@ class ExllamaV2Container:
|
|||
gen_settings.token_repetition_range = generated_tokens
|
||||
|
||||
# Generate
|
||||
chunk, eos, tokens = self.generator.stream()
|
||||
chunk, eos, tokens, _, *extra_parts = self.generator.stream()
|
||||
|
||||
if token_healing:
|
||||
# Extract healed token
|
||||
|
|
|
|||
|
|
@ -6,14 +6,14 @@ from pydantic import AliasChoices, BaseModel, Field
|
|||
import yaml
|
||||
|
||||
from common.logger import init_logger
|
||||
from common.utils import unwrap
|
||||
from common.utils import unwrap, prune_dict
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
# Common class for sampler params
|
||||
class SamplerParams(BaseModel):
|
||||
class BaseSamplerRequest(BaseModel):
|
||||
"""Common class for sampler params that are used in APIs"""
|
||||
|
||||
max_tokens: Optional[int] = Field(
|
||||
|
|
@ -164,7 +164,7 @@ class SamplerParams(BaseModel):
|
|||
if isinstance(self.stop, str):
|
||||
self.stop = [self.stop]
|
||||
|
||||
return {
|
||||
gen_params = {
|
||||
"max_tokens": self.max_tokens,
|
||||
"generate_window": self.generate_window,
|
||||
"stop": self.stop,
|
||||
|
|
@ -196,6 +196,8 @@ class SamplerParams(BaseModel):
|
|||
"negative_prompt": self.negative_prompt,
|
||||
}
|
||||
|
||||
return gen_params
|
||||
|
||||
|
||||
# Global for default overrides
|
||||
DEFAULT_OVERRIDES = {}
|
||||
|
|
@ -211,7 +213,7 @@ def set_overrides_from_dict(new_overrides: dict):
|
|||
global DEFAULT_OVERRIDES
|
||||
|
||||
if isinstance(new_overrides, dict):
|
||||
DEFAULT_OVERRIDES = new_overrides
|
||||
DEFAULT_OVERRIDES = prune_dict(new_overrides)
|
||||
else:
|
||||
raise TypeError("New sampler overrides must be a dict!")
|
||||
|
||||
|
|
@ -243,7 +245,7 @@ def get_default_sampler_value(key, fallback=None):
|
|||
return unwrap(DEFAULT_OVERRIDES.get(key, {}).get("override"), fallback)
|
||||
|
||||
|
||||
def apply_forced_sampler_overrides(params: SamplerParams):
|
||||
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
|
||||
"""Forcefully applies overrides if specified by the user"""
|
||||
|
||||
for var, value in DEFAULT_OVERRIDES.items():
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def get_generator_error(message: str):
|
|||
generator_error = TabbyGeneratorError(error=error_message)
|
||||
|
||||
# Log and send the exception
|
||||
logger.error(generator_error.error.message)
|
||||
logger.error(generator_error.error.trace)
|
||||
return get_sse_packet(generator_error.model_dump_json())
|
||||
|
||||
|
||||
|
|
@ -56,3 +56,9 @@ def unwrap(wrapped, default=None):
|
|||
def coalesce(*args):
|
||||
"""Coalesce function for multiple unwraps."""
|
||||
return next((arg for arg in args if arg is not None), None)
|
||||
|
||||
|
||||
def prune_dict(input_dict):
|
||||
"""Trim out instances of None from a dictionary"""
|
||||
|
||||
return {k: v for k, v in input_dict.items() if v is not None}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue