Args: Switch to use model_field for everything
Pydantic provides these helpers. Better to use these instead of the inspect lib. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
6e935c565e
commit
21747bf9e4
2 changed files with 18 additions and 12 deletions
|
|
@ -1,8 +1,7 @@
|
|||
"""Argparser for overriding config values"""
|
||||
|
||||
import argparse
|
||||
from typing import Any, Type, get_origin, get_args, Union, List
|
||||
from inspect import get_annotations, isclass
|
||||
from typing import Any, get_origin, get_args, Union, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -42,6 +41,7 @@ def map_pydantic_type_to_argparse(pydantic_type: Any):
|
|||
Maps Pydantic types to argparse compatible types.
|
||||
Handles special cases like Union and List.
|
||||
"""
|
||||
|
||||
origin = get_origin(pydantic_type)
|
||||
|
||||
# Handle optional types
|
||||
|
|
@ -65,6 +65,7 @@ def add_field_to_group(group, field_name, field_type, field) -> None:
|
|||
"""
|
||||
Adds a Pydantic field to an argparse argument group.
|
||||
"""
|
||||
|
||||
arg_type = map_pydantic_type_to_argparse(field_type)
|
||||
help_text = field.description if field.description else "No description available"
|
||||
|
||||
|
|
@ -75,23 +76,26 @@ def init_argparser() -> argparse.ArgumentParser:
|
|||
"""
|
||||
Initializes an argparse parser based on a Pydantic config schema.
|
||||
"""
|
||||
|
||||
parser = argparse.ArgumentParser(description="TabbyAPI server")
|
||||
|
||||
field_type: Union[Type[BaseModel], Any]
|
||||
|
||||
# Loop through each top-level field in the config
|
||||
for field_name, field_type in get_annotations(TabbyConfigModel).items():
|
||||
for field_name, field_info in TabbyConfigModel.model_fields.items():
|
||||
field_type = field_info.annotation
|
||||
group = parser.add_argument_group(
|
||||
field_name, description=f"Arguments for {field_name}"
|
||||
)
|
||||
|
||||
# Check if the field_type is a Pydantic model
|
||||
if isclass(field_type):
|
||||
for sub_field_name, sub_field_type in get_annotations(field_type).items():
|
||||
field = field_type.model_fields[sub_field_name]
|
||||
add_field_to_group(group, sub_field_name, sub_field_type, field)
|
||||
if issubclass(field_type, BaseModel):
|
||||
for sub_field_name, sub_field_info in field_type.model_fields.items():
|
||||
sub_field_name = sub_field_name.replace("_", "-")
|
||||
sub_field_type = sub_field_info.annotation
|
||||
add_field_to_group(
|
||||
group, sub_field_name, sub_field_type, sub_field_info
|
||||
)
|
||||
else:
|
||||
# Handle cases where the field_type is not a Pydantic mode
|
||||
field_name = field_name.replace("_", "-")
|
||||
arg_type = map_pydantic_type_to_argparse(field_type)
|
||||
group.add_argument(
|
||||
f"--{field_name}", type=arg_type, help=f"Argument for {field_name}"
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import List, Optional, Union
|
|||
from common.utils import unwrap
|
||||
|
||||
|
||||
class ConfigConfig(BaseModel):
|
||||
class ConfigOverrideConfig(BaseModel):
|
||||
config: Optional[str] = Field(
|
||||
None, description=("Path to an overriding config.yml file")
|
||||
)
|
||||
|
|
@ -279,7 +279,9 @@ class EmbeddingsConfig(BaseModel):
|
|||
|
||||
|
||||
class TabbyConfigModel(BaseModel):
|
||||
config: ConfigConfig = Field(default_factory=ConfigConfig.model_construct)
|
||||
config: ConfigOverrideConfig = Field(
|
||||
default_factory=ConfigOverrideConfig.model_construct
|
||||
)
|
||||
network: NetworkConfig = Field(default_factory=NetworkConfig.model_construct)
|
||||
logging: LoggingConfig = Field(default_factory=LoggingConfig.model_construct)
|
||||
model: ModelConfig = Field(default_factory=ModelConfig.model_construct)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue