Config: Allow existing values to get included in generated file

Allows for generation from an existing config file. Primarily used
for migration purposes.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-09-16 12:19:58 -04:00
parent 7f03003437
commit 81ae461eb8
3 changed files with 20 additions and 8 deletions

View file

@ -4,7 +4,7 @@ import argparse
from pydantic import BaseModel
from common.config_models import TabbyConfigModel
from common.utils import is_list_type, unwrap_optional
from common.utils import is_list_type, unwrap_optional_type
def add_field_to_group(group, field_name, field_type, field) -> None:
@ -32,7 +32,7 @@ def init_argparser() -> argparse.ArgumentParser:
# Loop through each top-level field in the config
for field_name, field_info in TabbyConfigModel.model_fields.items():
field_type = unwrap_optional(field_info.annotation)
field_type = unwrap_optional_type(field_info.annotation)
group = parser.add_argument_group(
field_name, description=f"Arguments for {field_name}"
)

View file

@ -1,10 +1,11 @@
from inspect import getdoc
from pathlib import Path
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
from pydantic_core import PydanticUndefined
from textwrap import dedent
from typing import List, Literal, Optional, Union
from pydantic_core import PydanticUndefined
from common.utils import unwrap
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
@ -488,12 +489,17 @@ def generate_config_file(
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n
""")
schema = model if model else TabbyConfigModel()
schema = unwrap(model, TabbyConfigModel())
# TODO: Make the disordered iteration look cleaner
iter_once = False
for field, field_data in schema.model_fields.items():
subfield_model = field_data.default_factory()
# Fetch from the existing model class if it's passed
# Probably can use this on schema too, but play it safe
if model:
subfield_model = getattr(model, field, None)
else:
subfield_model = field_data.default_factory()
if not subfield_model._metadata.include_in_config:
continue
@ -519,7 +525,10 @@ def generate_config_file(
else:
sub_iter_once = True
if subfield_data.default_factory:
# If a value already exists, use it
if hasattr(subfield_model, subfield):
value = getattr(subfield_model, subfield)
elif subfield_data.default_factory:
value = subfield_data.default_factory()
else:
value = subfield_data.default

View file

@ -62,8 +62,11 @@ def is_list_type(type_hint) -> bool:
return False
def unwrap_optional(type_hint) -> Type:
"""unwrap Optional[type] annotations"""
def unwrap_optional_type(type_hint) -> Type:
"""
Unwrap Optional[type] annotations.
This is not the same as unwrap.
"""
if get_origin(type_hint) is Union:
args = get_args(type_hint)