From 420fd84f6b8f1fa97f499892a2f03ab2ca64d4b4 Mon Sep 17 00:00:00 2001 From: TerminalMan <84923604+SecretiveShell@users.noreply.github.com> Date: Fri, 6 Sep 2024 15:05:48 +0100 Subject: [PATCH] add env var loading automation - load config from env vars (eg. TABBY_NETWORK_HOST) - remove print statements - improve command line args automation --- common/args.py | 12 ++---------- common/tabby_config.py | 25 ++++++++++++++++++------- main.py | 8 ++++++-- 3 files changed, 26 insertions(+), 19 deletions(-) diff --git a/common/args.py b/common/args.py index c4fbd46..42f795a 100644 --- a/common/args.py +++ b/common/args.py @@ -37,26 +37,20 @@ def argument_with_auto(value): def init_argparser(): parser = argparse.ArgumentParser(description="TabbyAPI server") - # Loop through the fields in the top-level model (ModelX in this case) for field_name, field_type in config.__annotations__.items(): - # Get the sub-model type (e.g., ModelA, ModelB) - sub_model = field_type.__base__ - - # Create argument group for the sub-model group = parser.add_argument_group( field_name, description=f"Arguments for {field_name}" ) - # Loop through each field in the sub-model (e.g., ModelA, ModelB) + # Loop through each field in the sub-model for sub_field_name, sub_field_type in field_type.__annotations__.items(): field = field_type.__fields__[sub_field_name] help_text = ( field.description if field.description else "No description available" ) - # Handle Optional types or other generic types origin = get_origin(sub_field_type) - if origin is Union: # Check if the type is Union (which includes Optional) + if origin is Union: sub_field_type = next( t for t in get_args(sub_field_type) if t is not type(None) ) @@ -64,7 +58,6 @@ def init_argparser(): sub_field_type = get_args(sub_field_type)[0] # Map Pydantic types to argparse types - print(sub_field_type, type(sub_field_type)) if isinstance(sub_field_type, type) and issubclass( sub_field_type, (int, float, str, bool) ): @@ -72,7 +65,6 @@ def init_argparser(): else: arg_type = str # Default to string for unknown types - # Add the argument for each field in the sub-model group.add_argument(f"--{sub_field_name}", type=arg_type, help=help_text) return parser diff --git a/common/tabby_config.py b/common/tabby_config.py index 5aac0b8..a379ebb 100644 --- a/common/tabby_config.py +++ b/common/tabby_config.py @@ -1,7 +1,8 @@ import yaml import pathlib from loguru import logger -from typing import Optional +from typing import Optional, Union, get_origin, get_args +from os import getenv from common.utils import unwrap, merge_dicts from common.config_models import tabby_config_model @@ -15,6 +16,7 @@ class TabbyConfig(tabby_config_model): # config is applied in order of items in the list configs = [ self._from_file(pathlib.Path("config.yml")), + self._from_environment(), self._from_args(unwrap(arguments, {})), ] @@ -54,7 +56,7 @@ class TabbyConfig(tabby_config_model): config = self.from_file(pathlib.Path(config_override)) return config # Return early if loading from file - for key in ["network", "model", "logging", "developer", "embeddings"]: + for key in tabby_config_model.model_fields.keys(): override = args.get(key) if override: if key == "logging": @@ -67,11 +69,20 @@ class TabbyConfig(tabby_config_model): def _from_environment(self): """loads configuration from environment variables""" - # TODO: load config from environment variables - # this means that we can have host default to 0.0.0.0 in docker for example - # this would also mean that docker containers no longer require a non - # default config file to be used - pass + config = {} + + for field_name in tabby_config_model.model_fields.keys(): + section_config = {} + for sub_field_name in getattr( + tabby_config_model(), field_name + ).model_fields.keys(): + setting = getenv(f"TABBY_{field_name}_{sub_field_name}".upper(), None) + if setting is not None: + section_config[sub_field_name] = setting + + config[field_name] = section_config + + return config # Create an empty instance of the shared var to make sure nothing breaks diff --git a/main.py b/main.py index 7385a1d..6254bf2 100644 --- a/main.py +++ b/main.py @@ -76,7 +76,9 @@ async def entrypoint_async(): if config.lora.loras: lora_dir = pathlib.Path(config.lora.lora_dir) # TODO: remove model_dump() - await model.container.load_loras(lora_dir.resolve(), **config.lora.model_dump()) + await model.container.load_loras( + lora_dir.resolve(), **config.lora.model_dump() + ) # If an initial embedding model name is specified, create a separate container # and load the model @@ -87,7 +89,9 @@ async def entrypoint_async(): try: # TODO: remove model_dump() - await model.load_embedding_model(embedding_model_path, **config.embeddings.model_dump()) + await model.load_embedding_model( + embedding_model_path, **config.embeddings.model_dump() + ) except ImportError as ex: logger.error(ex.msg)