diff --git a/common/model.py b/common/model.py index 44a35fa..1935025 100644 --- a/common/model.py +++ b/common/model.py @@ -76,6 +76,9 @@ async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs): if not override_config_path.exists(): return kwargs + # Initialize overrides dict + overrides = {} + async with aiofiles.open( override_config_path, "r", encoding="utf8" ) as override_config_file: @@ -83,18 +86,25 @@ async def apply_inline_overrides(model_dir: pathlib.Path, **kwargs): # Create a temporary YAML parser yaml = YAML(typ="safe") - override_args = unwrap(yaml.load(contents), {}) + inline_config = unwrap(yaml.load(contents), {}) + + # Check for inline model overrides + model_inline_config = unwrap(inline_config.get("model"), {}) + if model_inline_config: + overrides = {**model_inline_config} + else: + logger.warning( + "Cannot find inline model overrides. " + "Make sure they are nested under a \"model:\" key" + ) # Merge draft overrides beforehand - draft_override_args = unwrap(override_args.get("draft_model"), {}) - if draft_override_args: - kwargs["draft_model"] = { - **draft_override_args, - **unwrap(kwargs.get("draft_model"), {}), - } + draft_inline_config = unwrap(inline_config.get("draft_model"), {}) + if draft_inline_config: + overrides["draft_model"] = {**draft_inline_config} # Merge the override and model kwargs - merged_kwargs = {**override_args, **kwargs} + merged_kwargs = {**overrides, **kwargs} return merged_kwargs