Sampling: Copy over iterable overrides

If an override was iterable, any modifications to the returned value
would alter the reference to the global storage dict.

Therefore, copy the structure if it's an iterable so any modification
won't alter the original override. Also apply this for the function
that checks for forced overrides.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-05-17 21:27:56 -04:00
parent 0e9385e023
commit b9fd8555fe
2 changed files with 9 additions and 3 deletions

View file

@ -2,6 +2,7 @@
import pathlib
import yaml
from copy import deepcopy
from loguru import logger
from pydantic import AliasChoices, BaseModel, Field
from typing import Dict, List, Optional, Union
@ -376,14 +377,19 @@ def get_all_presets():
def get_default_sampler_value(key, fallback=None):
"""Gets an overridden default sampler value"""
return unwrap(overrides_container.overrides.get(key, {}).get("override"), fallback)
default_value = unwrap(
deepcopy(overrides_container.overrides.get(key, {}).get("override")),
fallback,
)
return default_value
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
"""Forcefully applies overrides if specified by the user"""
for var, value in overrides_container.overrides.items():
override = value.get("override")
override = deepcopy(value.get("override"))
original_value = getattr(params, var, None)
# Force takes precedence over additive

View file

@ -15,6 +15,6 @@ def coalesce(*args):
def prune_dict(input_dict):
"""Trim out instances of None from a dictionary"""
"""Trim out instances of None from a dictionary."""
return {k: v for k, v in input_dict.items() if v is not None}