migrate to ruamel.yaml

This commit is contained in:
TerminalMan 2024-09-18 01:06:34 +01:00
parent bb4dd7200e
commit 948fcb7f5b
3 changed files with 69 additions and 70 deletions

View file

@ -1,14 +1,19 @@
import yaml
import pathlib
from inspect import getdoc
from pydantic_core import PydanticUndefined
from loguru import logger
from textwrap import dedent
from typing import Optional
from os import getenv
from textwrap import dedent
from typing import Any, Optional
from common.utils import unwrap, merge_dicts
from common.config_models import BaseConfigModel, TabbyConfigModel
from loguru import logger
from pydantic import BaseModel
from pydantic_core import PydanticUndefined
from ruamel.yaml import YAML
from ruamel.yaml.comments import CommentedMap, CommentedSeq
from common.config_models import TabbyConfigModel
from common.utils import merge_dicts, unwrap
yaml = YAML()
class TabbyConfig(TabbyConfigModel):
@ -57,7 +62,7 @@ class TabbyConfig(TabbyConfigModel):
# try loading from file
try:
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
cfg = yaml.safe_load(config_file)
cfg = yaml.load(config_file)
# NOTE: Remove migration wrapper after a period of time
# load legacy config files
@ -130,7 +135,7 @@ class TabbyConfig(TabbyConfigModel):
"""loads config from the provided arguments"""
config = {}
config_override = unwrap(args.get("options", {}).get("config"))
config_override = args.get("options", {}).get("config", None)
if config_override:
logger.info("Config file override detected in args.")
config = self._from_file(pathlib.Path(config_override))
@ -166,15 +171,25 @@ class TabbyConfig(TabbyConfigModel):
config: TabbyConfig = TabbyConfig()
# TODO: Possibly switch to ruamel.yaml for a more native implementation
def generate_config_file(
model: BaseConfigModel = None,
model: BaseModel = None,
filename: str = "config_sample.yml",
indentation: int = 2,
) -> None:
"""Creates a config.yml file from Pydantic models."""
# Add a cleaned up preamble
schema = unwrap(model, TabbyConfigModel())
preamble = get_preamble()
yaml_content = pydantic_model_to_yaml(schema)
with open(filename, "w") as f:
f.write(preamble)
yaml.dump(yaml_content, f)
def get_preamble() -> str:
"""Returns the cleaned up preamble for the config file."""
preamble = """
# Sample YAML file for configuration.
# Comment and uncomment values as needed.
@ -184,61 +199,43 @@ def generate_config_file(
# Unless specified in the comments, DO NOT put these options in quotes!
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n
"""
return dedent(preamble).lstrip()
# Trim and cleanup preamble
yaml = dedent(preamble).lstrip()
schema = unwrap(model, TabbyConfigModel())
# Function to convert pydantic model to dict with field descriptions as comments
def pydantic_model_to_yaml(model: BaseModel) -> CommentedMap:
"""
Recursively converts a Pydantic model into a CommentedMap,
with descriptions as comments in YAML.
"""
# Create a CommentedMap to hold the output data
yaml_data = CommentedMap()
# TODO: Make the disordered iteration look cleaner
iter_once = False
for field, field_data in schema.model_fields.items():
# Fetch from the existing model class if it's passed
# Probably can use this on schema too, but play it safe
if model and hasattr(model, field):
subfield_model = getattr(model, field)
# Loop through all fields in the model
for field_name, field_info in model.model_fields.items():
value = getattr(model, field_name)
# If the field is another Pydantic model
if isinstance(value, BaseModel):
yaml_data[field_name] = pydantic_model_to_yaml(value)
# If the field is a list of Pydantic models
elif (
isinstance(value, list)
and len(value) > 0
and isinstance(value[0], BaseModel)
):
yaml_list = CommentedSeq()
for item in value:
yaml_list.append(pydantic_model_to_yaml(item))
yaml_data[field_name] = yaml_list
# Otherwise, just assign the value
else:
subfield_model = field_data.default_factory()
yaml_data[field_name] = value
if not subfield_model._metadata.include_in_config:
continue
# Add field description as a comment if available
if field_info.description:
yaml_data.yaml_set_comment_before_after_key(
field_name, before=field_info.description
)
# Since the list is out of order with the length
# Add newlines from the beginning once one iteration finishes
# This is a sanity check for formatting
if iter_once:
yaml += "\n"
else:
iter_once = True
for line in getdoc(subfield_model).splitlines():
yaml += f"# {line}\n"
yaml += f"{field}:\n"
sub_iter_once = False
for subfield, subfield_data in subfield_model.model_fields.items():
# Same logic as iter_once
if sub_iter_once:
yaml += "\n"
else:
sub_iter_once = True
# If a value already exists, use it
if hasattr(subfield_model, subfield):
value = getattr(subfield_model, subfield)
elif subfield_data.default_factory:
value = subfield_data.default_factory()
else:
value = subfield_data.default
value = value if value is not None else ""
value = value if value is not PydanticUndefined else ""
for line in subfield_data.description.splitlines():
yaml += f"{' ' * indentation}# {line}\n"
yaml += f"{' ' * indentation}{subfield}: {value}\n"
with open(filename, "w") as f:
f.write(yaml)
return yaml_data

View file

@ -1,10 +1,12 @@
"""Common utility functions"""
from types import NoneType
from typing import Type, Union, get_args, get_origin
from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar
T = TypeVar("T")
def unwrap(wrapped, default=None):
def unwrap(wrapped: Optional[T], default: T = None) -> T:
"""Unwrap function for Optionals."""
if wrapped is None:
return default
@ -17,13 +19,13 @@ def coalesce(*args):
return next((arg for arg in args if arg is not None), None)
def prune_dict(input_dict):
def prune_dict(input_dict: Dict) -> Dict:
"""Trim out instances of None from a dictionary."""
return {k: v for k, v in input_dict.items() if v is not None}
def merge_dict(dict1, dict2):
def merge_dict(dict1: Dict, dict2: Dict) -> Dict:
"""Merge 2 dictionaries"""
for key, value in dict2.items():
if isinstance(value, dict) and key in dict1 and isinstance(dict1[key], dict):
@ -33,7 +35,7 @@ def merge_dict(dict1, dict2):
return dict1
def merge_dicts(*dicts):
def merge_dicts(*dicts: Dict) -> Dict:
"""Merge an arbitrary amount of dictionaries"""
result = {}
for dictionary in dicts: