API: Add sampler override switching

Allow users to switch the currently overriden samplers via the API
so a restart isn't required to switch the overrides.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-01-24 01:20:58 -05:00 committed by Brian Dashore
parent de0ba7214c
commit b14c5443fd
3 changed files with 87 additions and 6 deletions

View file

@ -166,6 +166,10 @@ class SamplerParams(BaseModel):
DEFAULT_OVERRIDES = {}
def get_sampler_overrides():
return DEFAULT_OVERRIDES
def set_overrides_from_dict(new_overrides: dict):
"""Wrapper function to update sampler overrides"""
@ -174,10 +178,10 @@ def set_overrides_from_dict(new_overrides: dict):
if isinstance(new_overrides, dict):
DEFAULT_OVERRIDES = new_overrides
else:
raise TypeError("new sampler overrides must be a dict!")
raise TypeError("New sampler overrides must be a dict!")
def get_overrides_from_file(preset_name: str):
def set_overrides_from_file(preset_name: str):
"""Fetches an override preset from a file"""
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
@ -188,11 +192,13 @@ def get_overrides_from_file(preset_name: str):
logger.info("Applied sampler overrides from file.")
else:
logger.warn(
f"Sampler override file named \"{preset_name}\" was not found. "
error_message = (
f'Sampler override file named "{preset_name}" was not found. '
+ "Make sure it's located in the sampler_overrides folder."
)
raise FileNotFoundError(error_message)
# TODO: Maybe move these into the class
# Classmethods aren't recognized in pydantic default_factories