fix arg parser for dict types

This commit is contained in:
TerminalMan 2024-09-11 16:13:31 +01:00
parent e8fcecd56a
commit 0d7459191c
2 changed files with 48 additions and 24 deletions

View file

@ -34,39 +34,64 @@ def argument_with_auto(value):
) from ex
def map_pydantic_type_to_argparse(pydantic_type):
"""
Maps Pydantic types to argparse compatible types.
Handles special cases like Union and List.
"""
origin = get_origin(pydantic_type)
# Handle optional types
if origin is Union:
# Filter out NoneType
pydantic_type = next(t for t in get_args(pydantic_type) if t is not type(None))
elif origin is List:
pydantic_type = get_args(pydantic_type)[0] # Get the list item type
# Map basic types (int, float, str, bool)
if isinstance(pydantic_type, type) and issubclass(
pydantic_type, (int, float, str, bool)
):
return pydantic_type
return str
def add_field_to_group(group, field_name, field_type, field):
"""
Adds a Pydantic field to an argparse argument group.
"""
arg_type = map_pydantic_type_to_argparse(field_type)
help_text = field.description if field.description else "No description available"
group.add_argument(f"--{field_name}", type=arg_type, help=help_text)
def init_argparser():
"""
Initializes an argparse parser based on a Pydantic config schema.
"""
parser = argparse.ArgumentParser(description="TabbyAPI server")
# Loop through each top-level field in the config
for field_name, field_type in config.__annotations__.items():
group = parser.add_argument_group(
field_name, description=f"Arguments for {field_name}"
)
# Loop through each field in the sub-model
for sub_field_name, sub_field_type in field_type.__annotations__.items():
field = field_type.__fields__[sub_field_name]
help_text = (
field.description if field.description else "No description available"
# Check if the field_type is a Pydantic model
if hasattr(field_type, "__annotations__"):
for sub_field_name, sub_field_type in field_type.__annotations__.items():
field = field_type.__fields__[sub_field_name]
add_field_to_group(group, sub_field_name, sub_field_type, field)
else:
# Handle cases where the field_type is not a Pydantic mode
arg_type = map_pydantic_type_to_argparse(field_type)
group.add_argument(
f"--{field_name}", type=arg_type, help=f"Argument for {field_name}"
)
origin = get_origin(sub_field_type)
if origin is Union:
sub_field_type = next(
t for t in get_args(sub_field_type) if t is not type(None)
)
elif origin is List:
sub_field_type = get_args(sub_field_type)[0]
# Map Pydantic types to argparse types
if isinstance(sub_field_type, type) and issubclass(
sub_field_type, (int, float, str, bool)
):
arg_type = sub_field_type
else:
arg_type = str # Default to string for unknown types
group.add_argument(f"--{sub_field_name}", type=arg_type, help=help_text)
return parser

View file

@ -10,7 +10,6 @@ import common.config_models
class TabbyConfig(tabby_config_model):
# Persistent defaults
# TODO: make this pydantic?
model_defaults: dict = {}