diff --git a/common/args.py b/common/args.py index ffecaaa..c4fbd46 100644 --- a/common/args.py +++ b/common/args.py @@ -5,6 +5,7 @@ from typing import get_origin, get_args, Optional, Union, List from pydantic import BaseModel from common.tabby_config import config + def str_to_bool(value): """Converts a string into a boolean value""" @@ -40,34 +41,43 @@ def init_argparser(): for field_name, field_type in config.__annotations__.items(): # Get the sub-model type (e.g., ModelA, ModelB) sub_model = field_type.__base__ - + # Create argument group for the sub-model - group = parser.add_argument_group(field_name, description=f"Arguments for {field_name}") - + group = parser.add_argument_group( + field_name, description=f"Arguments for {field_name}" + ) + # Loop through each field in the sub-model (e.g., ModelA, ModelB) for sub_field_name, sub_field_type in field_type.__annotations__.items(): field = field_type.__fields__[sub_field_name] - help_text = field.description if field.description else "No description available" + help_text = ( + field.description if field.description else "No description available" + ) # Handle Optional types or other generic types origin = get_origin(sub_field_type) if origin is Union: # Check if the type is Union (which includes Optional) - sub_field_type = next(t for t in get_args(sub_field_type) if t is not type(None)) - elif origin is List : sub_field_type = get_args(sub_field_type)[0] - + sub_field_type = next( + t for t in get_args(sub_field_type) if t is not type(None) + ) + elif origin is List: + sub_field_type = get_args(sub_field_type)[0] # Map Pydantic types to argparse types print(sub_field_type, type(sub_field_type)) - if isinstance(sub_field_type, type) and issubclass(sub_field_type, (int, float, str, bool)): + if isinstance(sub_field_type, type) and issubclass( + sub_field_type, (int, float, str, bool) + ): arg_type = sub_field_type else: arg_type = str # Default to string for unknown types - + # Add the argument for each field in the sub-model group.add_argument(f"--{sub_field_name}", type=arg_type, help=help_text) return parser + def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser): """Broad conversion of surface level arg groups to dictionaries""" @@ -81,4 +91,4 @@ def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentPars arg_groups[group.title] = group_dict - return arg_groups \ No newline at end of file + return arg_groups diff --git a/main.py b/main.py index 89cc6bf..7385a1d 100644 --- a/main.py +++ b/main.py @@ -69,12 +69,14 @@ async def entrypoint_async(): model_path = pathlib.Path(config.model.model_dir) model_path = model_path / model_name - await model.load_model(model_path.resolve(), **config.model) + # TODO: remove model_dump() + await model.load_model(model_path.resolve(), **config.model.model_dump()) # Load loras after loading the model if config.lora.loras: lora_dir = pathlib.Path(config.lora.lora_dir) - await model.container.load_loras(lora_dir.resolve(), **config.lora) + # TODO: remove model_dump() + await model.container.load_loras(lora_dir.resolve(), **config.lora.model_dump()) # If an initial embedding model name is specified, create a separate container # and load the model @@ -84,7 +86,8 @@ async def entrypoint_async(): embedding_model_path = embedding_model_path / embedding_model_name try: - await model.load_embedding_model(embedding_model_path, **config.embeddings) + # TODO: remove model_dump() + await model.load_embedding_model(embedding_model_path, **config.embeddings.model_dump()) except ImportError as ex: logger.error(ex.msg)