From b9fd8555fec484f07b43b91efb824c8b42854928 Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 17 May 2024 21:27:56 -0400 Subject: [PATCH] Sampling: Copy over iterable overrides If an override was iterable, any modifications to the returned value would alter the reference to the global storage dict. Therefore, copy the structure if it's an iterable so any modification won't alter the original override. Also apply this for the function that checks for forced overrides. Signed-off-by: kingbri --- common/sampling.py | 10 ++++++++-- common/utils.py | 2 +- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/common/sampling.py b/common/sampling.py index 0d808be..bee4040 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -2,6 +2,7 @@ import pathlib import yaml +from copy import deepcopy from loguru import logger from pydantic import AliasChoices, BaseModel, Field from typing import Dict, List, Optional, Union @@ -376,14 +377,19 @@ def get_all_presets(): def get_default_sampler_value(key, fallback=None): """Gets an overridden default sampler value""" - return unwrap(overrides_container.overrides.get(key, {}).get("override"), fallback) + default_value = unwrap( + deepcopy(overrides_container.overrides.get(key, {}).get("override")), + fallback, + ) + + return default_value def apply_forced_sampler_overrides(params: BaseSamplerRequest): """Forcefully applies overrides if specified by the user""" for var, value in overrides_container.overrides.items(): - override = value.get("override") + override = deepcopy(value.get("override")) original_value = getattr(params, var, None) # Force takes precedence over additive diff --git a/common/utils.py b/common/utils.py index 079a380..6787f39 100644 --- a/common/utils.py +++ b/common/utils.py @@ -15,6 +15,6 @@ def coalesce(*args): def prune_dict(input_dict): - """Trim out instances of None from a dictionary""" + """Trim out instances of None from a dictionary.""" return {k: v for k, v in input_dict.items() if v is not None}