Sampling: Cleanup and update

Cleanup how overrides are handled, class naming, and adopt exllamav2's
model class to enforce latest stable version methods rather than
adding multiple backwards compatability checks.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-02-01 12:58:55 -05:00
parent 2ea063cea9
commit b827bcbb44
4 changed files with 34 additions and 86 deletions

View file

@ -6,14 +6,14 @@ from pydantic import AliasChoices, BaseModel, Field
import yaml
from common.logger import init_logger
from common.utils import unwrap
from common.utils import unwrap, prune_dict
logger = init_logger(__name__)
# Common class for sampler params
class SamplerParams(BaseModel):
class BaseSamplerRequest(BaseModel):
"""Common class for sampler params that are used in APIs"""
max_tokens: Optional[int] = Field(
@ -164,7 +164,7 @@ class SamplerParams(BaseModel):
if isinstance(self.stop, str):
self.stop = [self.stop]
return {
gen_params = {
"max_tokens": self.max_tokens,
"generate_window": self.generate_window,
"stop": self.stop,
@ -196,6 +196,8 @@ class SamplerParams(BaseModel):
"negative_prompt": self.negative_prompt,
}
return gen_params
# Global for default overrides
DEFAULT_OVERRIDES = {}
@ -211,7 +213,7 @@ def set_overrides_from_dict(new_overrides: dict):
global DEFAULT_OVERRIDES
if isinstance(new_overrides, dict):
DEFAULT_OVERRIDES = new_overrides
DEFAULT_OVERRIDES = prune_dict(new_overrides)
else:
raise TypeError("New sampler overrides must be a dict!")
@ -243,7 +245,7 @@ def get_default_sampler_value(key, fallback=None):
return unwrap(DEFAULT_OVERRIDES.get(key, {}).get("override"), fallback)
def apply_forced_sampler_overrides(params: SamplerParams):
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
"""Forcefully applies overrides if specified by the user"""
for var, value in DEFAULT_OVERRIDES.items():

View file

@ -36,7 +36,7 @@ def get_generator_error(message: str):
generator_error = TabbyGeneratorError(error=error_message)
# Log and send the exception
logger.error(generator_error.error.message)
logger.error(generator_error.error.trace)
return get_sse_packet(generator_error.model_dump_json())
@ -56,3 +56,9 @@ def unwrap(wrapped, default=None):
def coalesce(*args):
"""Coalesce function for multiple unwraps."""
return next((arg for arg in args if arg is not None), None)
def prune_dict(input_dict):
"""Trim out instances of None from a dictionary"""
return {k: v for k, v in input_dict.items() if v is not None}