diff --git a/OAI/types/common.py b/OAI/types/common.py
index dd25ad6..7ba44c4 100644
--- a/OAI/types/common.py
+++ b/OAI/types/common.py
@@ -2,7 +2,7 @@
 from pydantic import BaseModel, Field
 from typing import List, Dict, Optional
 
-from common.sampling import CommonSamplerRequest
+from common.sampling import BaseSamplerRequest
 
 
 class LogProbs(BaseModel):
@@ -49,5 +49,5 @@ class CommonCompletionRequest(BaseSamplerRequest):
         description="Not parsed. Only used for OAI compliance.", default=None
     )
 
-    # Generation info (remainder is in CommonSamplerRequest superclass)
+    # Generation info (remainder is in BaseSamplerRequest superclass)
     stream: Optional[bool] = False
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 54cb416..a9b1549 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -468,56 +468,9 @@ class ExllamaV2Container:
         }
 
     def check_unsupported_settings(self, **kwargs):
-        # Warn of unsupported settings if the setting is enabled
-        if (unwrap(kwargs.get("mirostat"), False)) and not hasattr(
-            ExLlamaV2Sampler.Settings, "mirostat"
-        ):
-            logger.warning(
-                "Mirostat sampling is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
+        """Check and warn the user if a sampler is unsupported. Meant for dev wheels!"""
 
-        if (unwrap(kwargs.get("min_p"), 0.0)) not in [0.0, 1.0] and not hasattr(
-            ExLlamaV2Sampler.Settings, "min_p"
-        ):
-            logger.warning(
-                "Min-P sampling is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
-        if (unwrap(kwargs.get("tfs"), 0.0)) not in [0.0, 1.0] and not hasattr(
-            ExLlamaV2Sampler.Settings, "tfs"
-        ):
-            logger.warning(
-                "Tail-free sampling (TFS) is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
-        if (unwrap(kwargs.get("temperature_last"), False)) and not hasattr(
-            ExLlamaV2Sampler.Settings, "temperature_last"
-        ):
-            logger.warning(
-                "Temperature last is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
-        if (unwrap(kwargs.get("top_a"), False)) and not hasattr(
-            ExLlamaV2Sampler.Settings, "top_a"
-        ):
-            logger.warning(
-                "Top-A is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
-        if (unwrap(kwargs.get("presence_penalty"), 0.0)) != 0.0 and not hasattr(
-            ExLlamaV2Sampler.Settings, "token_presence_penalty"
-        ):
-            logger.warning(
-                "Presence penalty is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
-        if (unwrap(kwargs.get("max_temp"), 0.0)) > 0.0 and not hasattr(
+        if unwrap(kwargs.get("max_temp"), 0.0) > 0.0 and not hasattr(
             ExLlamaV2Sampler.Settings, "max_temp"
         ):
             logger.warning(
@@ -597,7 +550,7 @@ class ExllamaV2Container:
         # Sampler settings
         gen_settings = ExLlamaV2Sampler.Settings()
 
-        # TODO: Migrate settings validation to different function
+        # Check unsupported settings for dev wheels
         self.check_unsupported_settings(**kwargs)
 
         # Apply settings
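
Reviewer note on the max_temp check above: kwargs.get() returns None when the request omits the key, so a bare `kwargs.get("max_temp") > 0.0` comparison would raise a TypeError on Python 3; the hunk keeps the unwrap() guard for that reason. A minimal, self-contained sketch of the pattern (the unwrap body here is paraphrased from common/utils.py, not copied from this patch):

    # Minimal sketch of the None-guard pattern; standalone, not container code.
    def unwrap(wrapped, default=None):
        """Return `wrapped` unless it is None, otherwise fall back to `default`."""
        return default if wrapped is None else wrapped

    kwargs = {}  # request did not supply "max_temp"
    # kwargs.get("max_temp") > 0.0 would raise TypeError (None vs. float);
    # the unwrapped form compares 0.0 > 0.0 and is simply False.
    print(unwrap(kwargs.get("max_temp"), 0.0) > 0.0)  # False
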
" - + "Please reload your model with use_cfg = True.", + "Please reload your model with use_cfg = True.", ) - gen_settings.token_presence_penalty = unwrap( - kwargs.get("presence_penalty"), 0.0 - ) gen_settings.token_repetition_penalty = unwrap( kwargs.get("repetition_penalty"), 1.0 ) + gen_settings.token_frequency_penalty = unwrap( + kwargs.get("frequency_penalty"), 0.0 + ) + gen_settings.token_presence_penalty = unwrap( + kwargs.get("presence_penalty"), 0.0 + ) # Applies for all penalties despite being called token_repetition_range gen_settings.token_repetition_range = unwrap( kwargs.get("penalty_range"), self.config.max_seq_len ) - auto_scale_penalty_range = False - frequency_penalty = unwrap(kwargs.get("frequency_penalty"), 0.0) - if hasattr(gen_settings, "token_frequency_penalty"): - gen_settings.token_frequency_penalty = frequency_penalty - - # Dynamically scale penalty range to output tokens - # Only do this if freq/pres pen is enabled - # and the repetition range is -1 - auto_scale_penalty_range = ( - gen_settings.token_frequency_penalty != 0 - or gen_settings.token_presence_penalty != 0 - ) and gen_settings.token_repetition_range == -1 - elif frequency_penalty != 0.0: - logger.warning( - "Frequency penalty is not supported by the currently " - "installed ExLlamaV2 version." - ) - - # Override the repetition penalty value if it isn't set already - # if the user is on an older exl2 version - if unwrap(gen_settings.token_repetition_penalty, 1.0) == 1.0: - gen_settings.token_repetition_penalty = frequency_penalty - logger.warning("Setting this value to repetition penalty instead.") + # Dynamically scale penalty range to output tokens + # Only do this if freq/pres pen is enabled + # and the repetition range is -1 + auto_scale_penalty_range = ( + gen_settings.token_frequency_penalty != 0 + or gen_settings.token_presence_penalty != 0 + ) and gen_settings.token_repetition_range == -1 # Always make sure the fallback is 0 if range < 0 # It's technically fine to use -1, but this just validates the passed @@ -820,7 +760,7 @@ class ExllamaV2Container: gen_settings.token_repetition_range = generated_tokens # Generate - chunk, eos, tokens = self.generator.stream() + chunk, eos, tokens, _, *extra_parts = self.generator.stream() if token_healing: # Extract healed token diff --git a/common/sampling.py b/common/sampling.py index 3f2da1b..53c7b2e 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -6,14 +6,14 @@ from pydantic import AliasChoices, BaseModel, Field import yaml from common.logger import init_logger -from common.utils import unwrap +from common.utils import unwrap, prune_dict logger = init_logger(__name__) # Common class for sampler params -class SamplerParams(BaseModel): +class BaseSamplerRequest(BaseModel): """Common class for sampler params that are used in APIs""" max_tokens: Optional[int] = Field( @@ -164,7 +164,7 @@ class SamplerParams(BaseModel): if isinstance(self.stop, str): self.stop = [self.stop] - return { + gen_params = { "max_tokens": self.max_tokens, "generate_window": self.generate_window, "stop": self.stop, @@ -196,6 +196,8 @@ class SamplerParams(BaseModel): "negative_prompt": self.negative_prompt, } + return gen_params + # Global for default overrides DEFAULT_OVERRIDES = {} @@ -211,7 +213,7 @@ def set_overrides_from_dict(new_overrides: dict): global DEFAULT_OVERRIDES if isinstance(new_overrides, dict): - DEFAULT_OVERRIDES = new_overrides + DEFAULT_OVERRIDES = prune_dict(new_overrides) else: raise TypeError("New sampler overrides must be a dict!") @@ -243,7 
@@ -243,7 +245,7 @@ def get_default_sampler_value(key, fallback=None):
     return unwrap(DEFAULT_OVERRIDES.get(key, {}).get("override"), fallback)
 
 
-def apply_forced_sampler_overrides(params: SamplerParams):
+def apply_forced_sampler_overrides(params: BaseSamplerRequest):
     """Forcefully applies overrides if specified by the user"""
 
     for var, value in DEFAULT_OVERRIDES.items():
diff --git a/common/utils.py b/common/utils.py
index 2db97e9..3b84a01 100644
--- a/common/utils.py
+++ b/common/utils.py
@@ -36,7 +36,7 @@ def get_generator_error(message: str):
     generator_error = TabbyGeneratorError(error=error_message)
 
     # Log and send the exception
-    logger.error(generator_error.error.message)
+    logger.error(generator_error.error.trace)
 
     return get_sse_packet(generator_error.model_dump_json())
 
@@ -56,3 +56,9 @@ def unwrap(wrapped, default=None):
 def coalesce(*args):
     """Coalesce function for multiple unwraps."""
     return next((arg for arg in args if arg is not None), None)
+
+
+def prune_dict(input_dict):
+    """Trim out instances of None from a dictionary"""
+
+    return {k: v for k, v in input_dict.items() if v is not None}
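
Finally, one behavioral detail from the model.py hunks that is easy to miss in review: when penalty_range comes through as -1 and a frequency or presence penalty is active, the repetition range is rescaled to the current output length on each generation step. A standalone sketch of that decision (GenSettings is a stand-in for ExLlamaV2Sampler.Settings, and generated_tokens is an arbitrary example value; neither comes from this patch):

    # Stand-in for ExLlamaV2Sampler.Settings with only the fields used here.
    class GenSettings:
        token_frequency_penalty = 0.5  # a freq/pres penalty is active
        token_presence_penalty = 0.0
        token_repetition_range = -1    # user asked for an unbounded range

    gen_settings = GenSettings()

    # Mirrors the patch: scale the penalty window with the output only when
    # a freq/pres penalty is enabled and the repetition range is -1.
    auto_scale_penalty_range = (
        gen_settings.token_frequency_penalty != 0
        or gen_settings.token_presence_penalty != 0
    ) and gen_settings.token_repetition_range == -1

    generated_tokens = 42  # hypothetical number of tokens emitted so far
    if auto_scale_penalty_range:
        gen_settings.token_repetition_range = generated_tokens

    print(gen_settings.token_repetition_range)  # 42
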