Model: Migrate inline config to new format
This matches config.yml and all model overrides should go under the "model" block. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
a3c780ae58
commit
322f9b773a
1 changed files with 18 additions and 8 deletions
|
|
@ -76,6 +76,9 @@ async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
|
|||
if not override_config_path.exists():
|
||||
return kwargs
|
||||
|
||||
# Initialize overrides dict
|
||||
overrides = {}
|
||||
|
||||
async with aiofiles.open(
|
||||
override_config_path, "r", encoding="utf8"
|
||||
) as override_config_file:
|
||||
|
|
@ -83,18 +86,25 @@ async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs):
|
|||
|
||||
# Create a temporary YAML parser
|
||||
yaml = YAML(typ="safe")
|
||||
override_args = unwrap(yaml.load(contents), {})
|
||||
inline_config = unwrap(yaml.load(contents), {})
|
||||
|
||||
# Check for inline model overrides
|
||||
model_inline_config = unwrap(inline_config.get("model"), {})
|
||||
if model_inline_config:
|
||||
overrides = {**model_inline_config}
|
||||
else:
|
||||
logger.warning(
|
||||
"Cannot find inline model overrides. "
|
||||
"Make sure they are nested under a \"model:\" key"
|
||||
)
|
||||
|
||||
# Merge draft overrides beforehand
|
||||
draft_override_args = unwrap(override_args.get("draft_model"), {})
|
||||
if draft_override_args:
|
||||
kwargs["draft_model"] = {
|
||||
**draft_override_args,
|
||||
**unwrap(kwargs.get("draft_model"), {}),
|
||||
}
|
||||
draft_inline_config = unwrap(inline_config.get("draft_model"), {})
|
||||
if draft_inline_config:
|
||||
overrides["draft_model"] = {**draft_inline_config}
|
||||
|
||||
# Merge the override and model kwargs
|
||||
merged_kwargs = {**override_args, **kwargs}
|
||||
merged_kwargs = {**overrides, **kwargs}
|
||||
return merged_kwargs
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue