diff --git a/common/args.py b/common/args.py index bd9c67c..7d2427f 100644 --- a/common/args.py +++ b/common/args.py @@ -1,20 +1,10 @@ """Argparser for overriding config values""" import argparse -from typing import Any - from pydantic import BaseModel from common.config_models import TabbyConfigModel - - -def is_list_type(type_hint): - if hasattr(type_hint, "__origin__") and type_hint.__origin__ is list: - return True - if hasattr(type_hint, "__args__"): - # Recursively check for lists inside type arguments - return any(is_list_type(arg) for arg in type_hint.__args__) - return False +from common.utils import is_list_type def add_field_to_group(group, field_name, field_type, field) -> None: @@ -26,6 +16,7 @@ def add_field_to_group(group, field_name, field_type, field) -> None: "help": field.description if field.description else "No description available", } + # If the inner type contains a list, specify argparse as such if is_list_type(field_type): kwargs["nargs"] = "+" @@ -63,7 +54,7 @@ def init_argparser() -> argparse.ArgumentParser: def convert_args_to_dict( args: argparse.Namespace, parser: argparse.ArgumentParser -) -> dict[str, dict[str, Any]]: +) -> dict: """Broad conversion of surface level arg groups to dictionaries""" arg_groups = {} diff --git a/common/config_models.py b/common/config_models.py index 1b371d5..4c12445 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -6,6 +6,8 @@ CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] class ConfigOverrideConfig(BaseModel): + """Model for overriding a provided config file.""" + # TODO: convert this to a pathlib.path? config: Optional[str] = Field( None, description=("Path to an overriding config.yml file") @@ -13,6 +15,9 @@ class ConfigOverrideConfig(BaseModel): class UtilityActions(BaseModel): + """Model used for arg actions.""" + + # YAML export options export_config: Optional[str] = Field( None, description="generate a template config file" ) @@ -20,6 +25,7 @@ class UtilityActions(BaseModel): "config_sample.yml", description="path to export configuration file to" ) + # OpenAPI JSON export options export_openapi: Optional[bool] = Field( False, description="export openapi schema files" ) @@ -29,6 +35,8 @@ class UtilityActions(BaseModel): class NetworkConfig(BaseModel): + """Model for network configuration.""" + host: Optional[str] = Field("127.0.0.1", description=("The IP to host on")) port: Optional[int] = Field(5000, description=("The port to host on")) disable_auth: Optional[bool] = Field( @@ -47,6 +55,8 @@ class NetworkConfig(BaseModel): class LoggingConfig(BaseModel): + """Model for logging configuration.""" + log_prompt: Optional[bool] = Field(False, description=("Enable prompt logging")) log_generation_params: Optional[bool] = Field( False, description=("Enable generation parameter logging") @@ -55,6 +65,8 @@ class LoggingConfig(BaseModel): class ModelConfig(BaseModel): + """Model for LLM configuration.""" + # TODO: convert this to a pathlib.path? model_dir: str = Field( "models", @@ -201,6 +213,8 @@ class ModelConfig(BaseModel): class DraftModelConfig(BaseModel): + """Model for draft LLM model configuration.""" + # TODO: convert this to a pathlib.path? draft_model_dir: Optional[str] = Field( "models", @@ -239,6 +253,8 @@ class DraftModelConfig(BaseModel): class LoraInstanceModel(BaseModel): + """Model representing an instance of a Lora.""" + name: str = Field(..., description=("Name of the LoRA model")) scaling: float = Field( 1.0, @@ -248,6 +264,8 @@ class LoraInstanceModel(BaseModel): class LoraConfig(BaseModel): + """Model for lora configuration.""" + # TODO: convert this to a pathlib.path? lora_dir: Optional[str] = Field( "loras", description=("Directory to look for LoRAs (default: 'loras')") @@ -262,12 +280,16 @@ class LoraConfig(BaseModel): class SamplingConfig(BaseModel): + """Model for sampling (overrides) config.""" + override_preset: Optional[str] = Field( None, description=("Select a sampler override preset") ) class DeveloperConfig(BaseModel): + """Model for developer settings configuration.""" + unsafe_launch: Optional[bool] = Field( False, description=("Skip Exllamav2 version check") ) @@ -290,6 +312,8 @@ class DeveloperConfig(BaseModel): class EmbeddingsConfig(BaseModel): + """Model for embeddings configuration.""" + # TODO: convert this to a pathlib.path? embedding_model_dir: Optional[str] = Field( "models", @@ -310,6 +334,8 @@ class EmbeddingsConfig(BaseModel): class TabbyConfigModel(BaseModel): + """Base model for a TabbyConfig.""" + config: ConfigOverrideConfig = Field( default_factory=ConfigOverrideConfig.model_construct ) @@ -331,6 +357,8 @@ class TabbyConfigModel(BaseModel): def generate_config_file(filename="config_sample.yml", indentation=2): + """Creates a config.yml file from Pydantic models.""" + schema = TabbyConfigModel.model_json_schema() def dump_def(id: str, indent=2): @@ -356,6 +384,3 @@ def generate_config_file(filename="config_sample.yml", indentation=2): with open(filename, "w") as f: f.write(yaml) - - -# generate_config_file("test.yml") diff --git a/common/utils.py b/common/utils.py index d5723a0..d933fb6 100644 --- a/common/utils.py +++ b/common/utils.py @@ -1,5 +1,7 @@ """Common utility functions""" +from typing import get_args, get_origin + def unwrap(wrapped, default=None): """Unwrap function for Optionals.""" @@ -43,3 +45,17 @@ def flat_map(input_list): """Flattens a list of lists into a single list.""" return [item for sublist in input_list for item in sublist] + + +def is_list_type(type_hint): + """Checks if a type contains a list.""" + + if get_origin(type_hint) is list: + return True + + # Recursively check for lists inside type arguments + type_args = get_args(type_hint) + if type_args: + return any(is_list_type(arg) for arg in type_args) + + return False