Args: Switch to use model_field for everything

Pydantic provides these helpers. Better to use these instead of
the inspect lib.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-09-12 22:17:51 -04:00
parent 6e935c565e
commit 21747bf9e4
2 changed files with 18 additions and 12 deletions

View file

@ -1,8 +1,7 @@
"""Argparser for overriding config values"""
import argparse
from typing import Any, Type, get_origin, get_args, Union, List
from inspect import get_annotations, isclass
from typing import Any, get_origin, get_args, Union, List
from pydantic import BaseModel
@ -42,6 +41,7 @@ def map_pydantic_type_to_argparse(pydantic_type: Any):
Maps Pydantic types to argparse compatible types.
Handles special cases like Union and List.
"""
origin = get_origin(pydantic_type)
# Handle optional types
@ -65,6 +65,7 @@ def add_field_to_group(group, field_name, field_type, field) -> None:
"""
Adds a Pydantic field to an argparse argument group.
"""
arg_type = map_pydantic_type_to_argparse(field_type)
help_text = field.description if field.description else "No description available"
@ -75,23 +76,26 @@ def init_argparser() -> argparse.ArgumentParser:
"""
Initializes an argparse parser based on a Pydantic config schema.
"""
parser = argparse.ArgumentParser(description="TabbyAPI server")
field_type: Union[Type[BaseModel], Any]
# Loop through each top-level field in the config
for field_name, field_type in get_annotations(TabbyConfigModel).items():
for field_name, field_info in TabbyConfigModel.model_fields.items():
field_type = field_info.annotation
group = parser.add_argument_group(
field_name, description=f"Arguments for {field_name}"
)
# Check if the field_type is a Pydantic model
if isclass(field_type):
for sub_field_name, sub_field_type in get_annotations(field_type).items():
field = field_type.model_fields[sub_field_name]
add_field_to_group(group, sub_field_name, sub_field_type, field)
if issubclass(field_type, BaseModel):
for sub_field_name, sub_field_info in field_type.model_fields.items():
sub_field_name = sub_field_name.replace("_", "-")
sub_field_type = sub_field_info.annotation
add_field_to_group(
group, sub_field_name, sub_field_type, sub_field_info
)
else:
# Handle cases where the field_type is not a Pydantic mode
field_name = field_name.replace("_", "-")
arg_type = map_pydantic_type_to_argparse(field_type)
group.add_argument(
f"--{field_name}", type=arg_type, help=f"Argument for {field_name}"

View file

@ -4,7 +4,7 @@ from typing import List, Optional, Union
from common.utils import unwrap
class ConfigConfig(BaseModel):
class ConfigOverrideConfig(BaseModel):
config: Optional[str] = Field(
None, description=("Path to an overriding config.yml file")
)
@ -279,7 +279,9 @@ class EmbeddingsConfig(BaseModel):
class TabbyConfigModel(BaseModel):
config: ConfigConfig = Field(default_factory=ConfigConfig.model_construct)
config: ConfigOverrideConfig = Field(
default_factory=ConfigOverrideConfig.model_construct
)
network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct)
logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct)
model: ModelConfig = Field(default_factory=ModelConfig.model_construct)