diff --git a/common/config.py b/common/config.py index 9b2f654..d59e56b 100644 --- a/common/config.py +++ b/common/config.py @@ -1,6 +1,8 @@ import yaml import pathlib from loguru import logger +from mergedeep import merge, Strategy +from typing import Any from common.utils import unwrap @@ -8,61 +10,66 @@ from common.utils import unwrap GLOBAL_CONFIG: dict = {} -def from_file(config_path: pathlib.Path): - """Sets the global config from a given file path""" +def load(arguments: dict[str, Any]): + """load the global application config""" global GLOBAL_CONFIG + # config is applied in order of items in the list + configs = [ + from_file(pathlib.Path("config.yml")), + from_environment(), + from_args(arguments), + ] + + GLOBAL_CONFIG = merge({}, *configs, strategy=Strategy.REPLACE) + +def from_file(config_path: pathlib.Path) -> dict[str, Any]: + """loads config from a given file path""" + + # try loading from file try: with open(str(config_path.resolve()), "r", encoding="utf8") as config_file: - GLOBAL_CONFIG = unwrap(yaml.safe_load(config_file), {}) + return unwrap(yaml.safe_load(config_file), {}) + except FileNotFoundError: + logger.info("The config.yml file cannot be found") except Exception as exc: logger.error( - "The YAML config couldn't load because of the following error: " - f"\n\n{exc}" - "\n\nTabbyAPI will start anyway and not parse this config file." + f"The YAML config couldn't load because of the following error:\n\n{exc}" ) - GLOBAL_CONFIG = {} + + # if no config file was loaded + return {} -def from_args(args: dict): - """Overrides the config based on a dict representation of args""" +def from_args(args: dict[str, Any]) -> dict[str, Any]: + """loads config from the provided arguments""" + config = {} config_override = unwrap(args.get("options", {}).get("config")) if config_override: - logger.info("Attempting to override config.yml from args.") - from_file(pathlib.Path(config_override)) - return + logger.info("Config file override detected in args.") + config = from_file(pathlib.Path(config_override)) + return config # Return early if loading from file - # Network config - network_override = args.get("network") - if network_override: - cur_network_config = network_config() - GLOBAL_CONFIG["network"] = {**cur_network_config, **network_override} + for key in ["network", "model", "logging", "developer", "embeddings"]: + override = args.get(key) + if override: + if key == "logging": + # Strip the "log_" prefix from logging keys if present + override = {k.replace("log_", ""): v for k, v in override.items()} + config[key] = override - # Model config - model_override = args.get("model") - if model_override: - cur_model_config = model_config() - GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override} + return config - # Generation Logging config - logging_override = args.get("logging") - if logging_override: - cur_logging_config = logging_config() - GLOBAL_CONFIG["logging"] = { - **cur_logging_config, - **{k.replace("log_", ""): logging_override[k] for k in logging_override}, - } - developer_override = args.get("developer") - if developer_override: - cur_developer_config = developer_config() - GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override} +def from_environment() -> dict[str, Any]: + """loads configuration from environment variables""" - embeddings_override = args.get("embeddings") - if embeddings_override: - cur_embeddings_config = embeddings_config() - GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override} + # TODO: load config from environment variables + # this means that we can have host default to 0.0.0.0 in docker for example + # this would also mean that docker containers no longer require a non + # default config file to be used + return {} def sampling_config(): diff --git a/main.py b/main.py index b0c5108..e25a9ff 100644 --- a/main.py +++ b/main.py @@ -110,15 +110,13 @@ def entrypoint(arguments: Optional[dict] = None): signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - # Load from YAML config - config.from_file(pathlib.Path("config.yml")) - # Parse and override config from args if arguments is None: parser = init_argparser() arguments = convert_args_to_dict(parser.parse_args(), parser) - config.from_args(arguments) + # load config + config.load(arguments) if do_export_openapi: openapi_json = export_openapi() diff --git a/pyproject.toml b/pyproject.toml index b9e80fe..89c9661 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ dependencies = [ "huggingface_hub", "psutil", "httptools>=0.5.0", + "mergedeep", # Improved asyncio loops "uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",