Sampling: Copy over iterable overrides
If an override was iterable, any modifications to the returned value would alter the reference to the global storage dict. Therefore, copy the structure if it's an iterable so any modification won't alter the original override. Also apply this for the function that checks for forced overrides. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
0e9385e023
commit
b9fd8555fe
2 changed files with 9 additions and 3 deletions
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
import pathlib
|
||||
import yaml
|
||||
from copy import deepcopy
|
||||
from loguru import logger
|
||||
from pydantic import AliasChoices, BaseModel, Field
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
|
@ -376,14 +377,19 @@ def get_all_presets():
|
|||
def get_default_sampler_value(key, fallback=None):
|
||||
"""Gets an overridden default sampler value"""
|
||||
|
||||
return unwrap(overrides_container.overrides.get(key, {}).get("override"), fallback)
|
||||
default_value = unwrap(
|
||||
deepcopy(overrides_container.overrides.get(key, {}).get("override")),
|
||||
fallback,
|
||||
)
|
||||
|
||||
return default_value
|
||||
|
||||
|
||||
def apply_forced_sampler_overrides(params: BaseSamplerRequest):
|
||||
"""Forcefully applies overrides if specified by the user"""
|
||||
|
||||
for var, value in overrides_container.overrides.items():
|
||||
override = value.get("override")
|
||||
override = deepcopy(value.get("override"))
|
||||
original_value = getattr(params, var, None)
|
||||
|
||||
# Force takes precedence over additive
|
||||
|
|
|
|||
|
|
@ -15,6 +15,6 @@ def coalesce(*args):
|
|||
|
||||
|
||||
def prune_dict(input_dict):
|
||||
"""Trim out instances of None from a dictionary"""
|
||||
"""Trim out instances of None from a dictionary."""
|
||||
|
||||
return {k: v for k, v in input_dict.items() if v is not None}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue