Config: Allow existing values to get included in generated file
Allows for generation from an existing config file. Primarily used for migration purposes. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
7f03003437
commit
81ae461eb8
3 changed files with 20 additions and 8 deletions
|
|
@ -4,7 +4,7 @@ import argparse
|
|||
from pydantic import BaseModel
|
||||
|
||||
from common.config_models import TabbyConfigModel
|
||||
from common.utils import is_list_type, unwrap_optional
|
||||
from common.utils import is_list_type, unwrap_optional_type
|
||||
|
||||
|
||||
def add_field_to_group(group, field_name, field_type, field) -> None:
|
||||
|
|
@ -32,7 +32,7 @@ def init_argparser() -> argparse.ArgumentParser:
|
|||
|
||||
# Loop through each top-level field in the config
|
||||
for field_name, field_info in TabbyConfigModel.model_fields.items():
|
||||
field_type = unwrap_optional(field_info.annotation)
|
||||
field_type = unwrap_optional_type(field_info.annotation)
|
||||
group = parser.add_argument_group(
|
||||
field_name, description=f"Arguments for {field_name}"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
from inspect import getdoc
|
||||
from pathlib import Path
|
||||
from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
|
||||
from pydantic_core import PydanticUndefined
|
||||
from textwrap import dedent
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from pydantic_core import PydanticUndefined
|
||||
from common.utils import unwrap
|
||||
|
||||
CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"]
|
||||
|
||||
|
|
@ -488,12 +489,17 @@ def generate_config_file(
|
|||
# You can use https://www.yamllint.com/ if you want to check your YAML formatting.\n
|
||||
""")
|
||||
|
||||
schema = model if model else TabbyConfigModel()
|
||||
schema = unwrap(model, TabbyConfigModel())
|
||||
|
||||
# TODO: Make the disordered iteration look cleaner
|
||||
iter_once = False
|
||||
for field, field_data in schema.model_fields.items():
|
||||
subfield_model = field_data.default_factory()
|
||||
# Fetch from the existing model class if it's passed
|
||||
# Probably can use this on schema too, but play it safe
|
||||
if model:
|
||||
subfield_model = getattr(model, field, None)
|
||||
else:
|
||||
subfield_model = field_data.default_factory()
|
||||
|
||||
if not subfield_model._metadata.include_in_config:
|
||||
continue
|
||||
|
|
@ -519,7 +525,10 @@ def generate_config_file(
|
|||
else:
|
||||
sub_iter_once = True
|
||||
|
||||
if subfield_data.default_factory:
|
||||
# If a value already exists, use it
|
||||
if hasattr(subfield_model, subfield):
|
||||
value = getattr(subfield_model, subfield)
|
||||
elif subfield_data.default_factory:
|
||||
value = subfield_data.default_factory()
|
||||
else:
|
||||
value = subfield_data.default
|
||||
|
|
|
|||
|
|
@ -62,8 +62,11 @@ def is_list_type(type_hint) -> bool:
|
|||
return False
|
||||
|
||||
|
||||
def unwrap_optional(type_hint) -> Type:
|
||||
"""unwrap Optional[type] annotations"""
|
||||
def unwrap_optional_type(type_hint) -> Type:
|
||||
"""
|
||||
Unwrap Optional[type] annotations.
|
||||
This is not the same as unwrap.
|
||||
"""
|
||||
|
||||
if get_origin(type_hint) is Union:
|
||||
args = get_args(type_hint)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue