migrate to ruamel.yaml
This commit is contained in:
parent
bb4dd7200e
commit
948fcb7f5b
3 changed files with 69 additions and 70 deletions
|
|
@ -1,14 +1,19 @@
|
|||
import yaml
|
||||
import pathlib
|
||||
from inspect import getdoc
|
||||
from pydantic_core import PydanticUndefined
|
||||
from loguru import logger
|
||||
from textwrap import dedent
|
||||
from typing import Optional
|
||||
from os import getenv
|
||||
from textwrap import dedent
|
||||
from typing import Any, Optional
|
||||
|
||||
from common.utils import unwrap, merge_dicts
|
||||
from common.config_models import BaseConfigModel, TabbyConfigModel
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from pydantic_core import PydanticUndefined
|
||||
from ruamel.yaml import YAML
|
||||
from ruamel.yaml.comments import CommentedMap, CommentedSeq
|
||||
|
||||
from common.config_models import TabbyConfigModel
|
||||
from common.utils import merge_dicts, unwrap
|
||||
|
||||
yaml = YAML()
|
||||
|
||||
|
||||
class TabbyConfig(TabbyConfigModel):
|
||||
|
|
@ -57,7 +62,7 @@ class TabbyConfig(TabbyConfigModel):
|
|||
# try loading from file
|
||||
try:
|
||||
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
|
||||
cfg = yaml.safe_load(config_file)
|
||||
cfg = yaml.load(config_file)
|
||||
|
||||
# NOTE: Remove migration wrapper after a period of time
|
||||
# load legacy config files
|
||||
|
|
@ -130,7 +135,7 @@ class TabbyConfig(TabbyConfigModel):
|
|||
"""loads config from the provided arguments"""
|
||||
config = {}
|
||||
|
||||
config_override = unwrap(args.get("options", {}).get("config"))
|
||||
config_override = args.get("options", {}).get("config", None)
|
||||
if config_override:
|
||||
logger.info("Config file override detected in args.")
|
||||
config = self._from_file(pathlib.Path(config_override))
|
||||
|
|
@ -166,15 +171,25 @@ class TabbyConfig(TabbyConfigModel):
|
|||
config: TabbyConfig = TabbyConfig()
|
||||
|
||||
|
||||
# TODO: Possibly switch to ruamel.yaml for a more native implementation
|
||||
def generate_config_file(
|
||||
model: BaseConfigModel = None,
|
||||
model: BaseModel = None,
|
||||
filename: str = "config_sample.yml",
|
||||
indentation: int = 2,
|
||||
) -> None:
|
||||
"""Creates a config.yml file from Pydantic models."""
|
||||
|
||||
# Add a cleaned up preamble
|
||||
schema = unwrap(model, TabbyConfigModel())
|
||||
preamble = get_preamble()
|
||||
|
||||
yaml_content = pydantic_model_to_yaml(schema)
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.write(preamble)
|
||||
yaml.dump(yaml_content, f)
|
||||
|
||||
|
||||
def get_preamble() -> str:
|
||||
"""Returns the cleaned up preamble for the config file."""
|
||||
preamble = """
|
||||
# Sample YAML file for configuration.
|
||||
# Comment and uncomment values as needed.
|
||||
|
|
@ -184,61 +199,43 @@ def generate_config_file(
|
|||
# Unless specified in the comments, DO NOT put these options in quotes!
|
||||
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n
|
||||
"""
|
||||
return dedent(preamble).lstrip()
|
||||
|
||||
# Trim and cleanup preamble
|
||||
yaml = dedent(preamble).lstrip()
|
||||
|
||||
schema = unwrap(model, TabbyConfigModel())
|
||||
# Function to convert pydantic model to dict with field descriptions as comments
|
||||
def pydantic_model_to_yaml(model: BaseModel) -> CommentedMap:
|
||||
"""
|
||||
Recursively converts a Pydantic model into a CommentedMap,
|
||||
with descriptions as comments in YAML.
|
||||
"""
|
||||
# Create a CommentedMap to hold the output data
|
||||
yaml_data = CommentedMap()
|
||||
|
||||
# TODO: Make the disordered iteration look cleaner
|
||||
iter_once = False
|
||||
for field, field_data in schema.model_fields.items():
|
||||
# Fetch from the existing model class if it's passed
|
||||
# Probably can use this on schema too, but play it safe
|
||||
if model and hasattr(model, field):
|
||||
subfield_model = getattr(model, field)
|
||||
# Loop through all fields in the model
|
||||
for field_name, field_info in model.model_fields.items():
|
||||
value = getattr(model, field_name)
|
||||
|
||||
# If the field is another Pydantic model
|
||||
if isinstance(value, BaseModel):
|
||||
yaml_data[field_name] = pydantic_model_to_yaml(value)
|
||||
# If the field is a list of Pydantic models
|
||||
elif (
|
||||
isinstance(value, list)
|
||||
and len(value) > 0
|
||||
and isinstance(value[0], BaseModel)
|
||||
):
|
||||
yaml_list = CommentedSeq()
|
||||
for item in value:
|
||||
yaml_list.append(pydantic_model_to_yaml(item))
|
||||
yaml_data[field_name] = yaml_list
|
||||
# Otherwise, just assign the value
|
||||
else:
|
||||
subfield_model = field_data.default_factory()
|
||||
yaml_data[field_name] = value
|
||||
|
||||
if not subfield_model._metadata.include_in_config:
|
||||
continue
|
||||
# Add field description as a comment if available
|
||||
if field_info.description:
|
||||
yaml_data.yaml_set_comment_before_after_key(
|
||||
field_name, before=field_info.description
|
||||
)
|
||||
|
||||
# Since the list is out of order with the length
|
||||
# Add newlines from the beginning once one iteration finishes
|
||||
# This is a sanity check for formatting
|
||||
if iter_once:
|
||||
yaml += "\n"
|
||||
else:
|
||||
iter_once = True
|
||||
|
||||
for line in getdoc(subfield_model).splitlines():
|
||||
yaml += f"# {line}\n"
|
||||
|
||||
yaml += f"{field}:\n"
|
||||
|
||||
sub_iter_once = False
|
||||
for subfield, subfield_data in subfield_model.model_fields.items():
|
||||
# Same logic as iter_once
|
||||
if sub_iter_once:
|
||||
yaml += "\n"
|
||||
else:
|
||||
sub_iter_once = True
|
||||
|
||||
# If a value already exists, use it
|
||||
if hasattr(subfield_model, subfield):
|
||||
value = getattr(subfield_model, subfield)
|
||||
elif subfield_data.default_factory:
|
||||
value = subfield_data.default_factory()
|
||||
else:
|
||||
value = subfield_data.default
|
||||
|
||||
value = value if value is not None else ""
|
||||
value = value if value is not PydanticUndefined else ""
|
||||
|
||||
for line in subfield_data.description.splitlines():
|
||||
yaml += f"{' ' * indentation}# {line}\n"
|
||||
|
||||
yaml += f"{' ' * indentation}{subfield}: {value}\n"
|
||||
|
||||
with open(filename, "w") as f:
|
||||
f.write(yaml)
|
||||
return yaml_data
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
"""Common utility functions"""
|
||||
|
||||
from types import NoneType
|
||||
from typing import Type, Union, get_args, get_origin
|
||||
from typing import Dict, Optional, Type, Union, get_args, get_origin, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def unwrap(wrapped, default=None):
|
||||
def unwrap(wrapped: Optional[T], default: T = None) -> T:
|
||||
"""Unwrap function for Optionals."""
|
||||
if wrapped is None:
|
||||
return default
|
||||
|
|
@ -17,13 +19,13 @@ def coalesce(*args):
|
|||
return next((arg for arg in args if arg is not None), None)
|
||||
|
||||
|
||||
def prune_dict(input_dict):
|
||||
def prune_dict(input_dict: Dict) -> Dict:
|
||||
"""Trim out instances of None from a dictionary."""
|
||||
|
||||
return {k: v for k, v in input_dict.items() if v is not None}
|
||||
|
||||
|
||||
def merge_dict(dict1, dict2):
|
||||
def merge_dict(dict1: Dict, dict2: Dict) -> Dict:
|
||||
"""Merge 2 dictionaries"""
|
||||
for key, value in dict2.items():
|
||||
if isinstance(value, dict) and key in dict1 and isinstance(dict1[key], dict):
|
||||
|
|
@ -33,7 +35,7 @@ def merge_dict(dict1, dict2):
|
|||
return dict1
|
||||
|
||||
|
||||
def merge_dicts(*dicts):
|
||||
def merge_dicts(*dicts: Dict) -> Dict:
|
||||
"""Merge an arbitrary amount of dictionaries"""
|
||||
result = {}
|
||||
for dictionary in dicts:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue