API: Add preset listing for sampler overrides

Querying the overrides list endpoint now returns the selected preset
and a list of presets to use.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-05-12 01:34:51 -04:00
parent b4bc941cbe
commit 6f4012d20d
3 changed files with 35 additions and 9 deletions

View file

@ -313,17 +313,20 @@ class BaseSamplerRequest(BaseModel):
return {**gen_params, **kwargs}
class SamplerOverridesContainer(BaseModel):
selected_preset: Optional[str] = None
overrides: dict = {}
# Global for default overrides
overrides = {}
overrides_container = SamplerOverridesContainer()
def overrides_from_dict(new_overrides: dict):
"""Wrapper function to update sampler overrides"""
global overrides
if isinstance(new_overrides, dict):
overrides = prune_dict(new_overrides)
overrides_container.overrides = prune_dict(new_overrides)
else:
raise TypeError("New sampler overrides must be a dict!")
@ -333,6 +336,7 @@ def overrides_from_file(preset_name: str):
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
if preset_path.exists():
overrides_container.selected_preset = preset_path.stem
with open(preset_path, "r", encoding="utf8") as raw_preset:
preset = yaml.safe_load(raw_preset)
overrides_from_dict(preset)
@ -347,18 +351,27 @@ def overrides_from_file(preset_name: str):
raise FileNotFoundError(error_message)
def get_all_presets():
"""Fetches all sampler override presets from the overrides directory"""
override_directory = pathlib.Path("sampler_overrides")
preset_files = map(lambda file: file.stem, override_directory.glob("*.yml"))
return preset_files
# TODO: Maybe move these into the class
# Classmethods aren't recognized in pydantic default_factories
def get_default_sampler_value(key, fallback=None):
"""Gets an overridden default sampler value"""
return unwrap(overrides.get(key, {}).get("override"), fallback)
return unwrap(overrides_container.overrides.get(key, {}).get("override"), fallback)
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
"""Forcefully applies overrides if specified by the user"""
for var, value in overrides.items():
for var, value in overrides_container.overrides.items():
override = value.get("override")
original_value = getattr(params, var, None)