API: Add preset listing for sampler overrides
Querying the overrides list endpoint now returns the selected preset and a list of presets to use. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
b4bc941cbe
commit
6f4012d20d
3 changed files with 35 additions and 9 deletions
|
|
@ -313,17 +313,20 @@ class BaseSamplerRequest(BaseModel):
|
|||
return {**gen_params, **kwargs}
|
||||
|
||||
|
||||
class SamplerOverridesContainer(BaseModel):
|
||||
selected_preset: Optional[str] = None
|
||||
overrides: dict = {}
|
||||
|
||||
|
||||
# Global for default overrides
|
||||
overrides = {}
|
||||
overrides_container = SamplerOverridesContainer()
|
||||
|
||||
|
||||
def overrides_from_dict(new_overrides: dict):
|
||||
"""Wrapper function to update sampler overrides"""
|
||||
|
||||
global overrides
|
||||
|
||||
if isinstance(new_overrides, dict):
|
||||
overrides = prune_dict(new_overrides)
|
||||
overrides_container.overrides = prune_dict(new_overrides)
|
||||
else:
|
||||
raise TypeError("New sampler overrides must be a dict!")
|
||||
|
||||
|
|
@ -333,6 +336,7 @@ def overrides_from_file(preset_name: str):
|
|||
|
||||
preset_path = pathlib.Path(f"sampler_overrides/{preset_name}.yml")
|
||||
if preset_path.exists():
|
||||
overrides_container.selected_preset = preset_path.stem
|
||||
with open(preset_path, "r", encoding="utf8") as raw_preset:
|
||||
preset = yaml.safe_load(raw_preset)
|
||||
overrides_from_dict(preset)
|
||||
|
|
@ -347,18 +351,27 @@ def overrides_from_file(preset_name: str):
|
|||
raise FileNotFoundError(error_message)
|
||||
|
||||
|
||||
def get_all_presets():
|
||||
"""Fetches all sampler override presets from the overrides directory"""
|
||||
|
||||
override_directory = pathlib.Path("sampler_overrides")
|
||||
preset_files = map(lambda file: file.stem, override_directory.glob("*.yml"))
|
||||
|
||||
return preset_files
|
||||
|
||||
|
||||
# TODO: Maybe move these into the class
|
||||
# Classmethods aren't recognized in pydantic default_factories
|
||||
def get_default_sampler_value(key, fallback=None):
|
||||
"""Gets an overridden default sampler value"""
|
||||
|
||||
return unwrap(overrides.get(key, {}).get("override"), fallback)
|
||||
return unwrap(overrides_container.overrides.get(key, {}).get("override"), fallback)
|
||||
|
||||
|
||||
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
|
||||
"""Forcefully applies overrides if specified by the user"""
|
||||
|
||||
for var, value in overrides.items():
|
||||
for var, value in overrides_container.overrides.items():
|
||||
override = value.get("override")
|
||||
original_value = getattr(params, var, None)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue