remove private attributes in args

This commit is contained in:
TerminalMan 2024-09-13 00:37:17 +01:00
parent eb5f42c845
commit 6e935c565e

View file

@ -1,8 +1,12 @@
"""Argparser for overriding config values"""
import argparse
from typing import get_origin, get_args, Union, List
from common.tabby_config import config
from typing import Any, Type, get_origin, get_args, Union, List
from inspect import get_annotations, isclass
from pydantic import BaseModel
from common.config_models import TabbyConfigModel
def str_to_bool(value):
@ -33,7 +37,7 @@ def argument_with_auto(value):
) from ex
def map_pydantic_type_to_argparse(pydantic_type):
def map_pydantic_type_to_argparse(pydantic_type: Any):
"""
Maps Pydantic types to argparse compatible types.
Handles special cases like Union and List.
@ -57,7 +61,7 @@ def map_pydantic_type_to_argparse(pydantic_type):
return str
def add_field_to_group(group, field_name, field_type, field):
def add_field_to_group(group, field_name, field_type, field) -> None:
"""
Adds a Pydantic field to an argparse argument group.
"""
@ -67,22 +71,24 @@ def add_field_to_group(group, field_name, field_type, field):
group.add_argument(f"--{field_name}", type=arg_type, help=help_text)
def init_argparser():
def init_argparser() -> argparse.ArgumentParser:
"""
Initializes an argparse parser based on a Pydantic config schema.
"""
parser = argparse.ArgumentParser(description="TabbyAPI server")
field_type: Union[Type[BaseModel], Any]
# Loop through each top-level field in the config
for field_name, field_type in config.__annotations__.items():
for field_name, field_type in get_annotations(TabbyConfigModel).items():
group = parser.add_argument_group(
field_name, description=f"Arguments for {field_name}"
)
# Check if the field_type is a Pydantic model
if hasattr(field_type, "__annotations__"):
for sub_field_name, sub_field_type in field_type.__annotations__.items():
field = field_type.__fields__[sub_field_name]
if isclass(field_type):
for sub_field_name, sub_field_type in get_annotations(field_type).items():
field = field_type.model_fields[sub_field_name]
add_field_to_group(group, sub_field_name, sub_field_type, field)
else:
# Handle cases where the field_type is not a Pydantic mode
@ -94,7 +100,9 @@ def init_argparser():
return parser
def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
def convert_args_to_dict(
args: argparse.Namespace, parser: argparse.ArgumentParser
) -> dict[str, dict[str, Any]]:
"""Broad conversion of surface level arg groups to dictionaries"""
arg_groups = {}