refactor config loading

- improve DRY
- alter logging
- allow extensibility
- add foundation for environment variables as config
This commit is contained in:
Jake 2024-09-04 12:22:49 +01:00
parent 8854269121
commit fa6404a95a
3 changed files with 48 additions and 42 deletions

View file

@ -1,6 +1,8 @@
import yaml
import pathlib
from loguru import logger
from mergedeep import merge, Strategy
from typing import Any
from common.utils import unwrap
@ -8,61 +10,66 @@ from common.utils import unwrap
GLOBAL_CONFIG: dict = {}
def from_file(config_path: pathlib.Path):
"""Sets the global config from a given file path"""
def load(arguments: dict[str, Any]):
"""load the global application config"""
global GLOBAL_CONFIG
# config is applied in order of items in the list
configs = [
from_file(pathlib.Path("config.yml")),
from_environment(),
from_args(arguments),
]
GLOBAL_CONFIG = merge({}, *configs, strategy=Strategy.REPLACE)
def from_file(config_path: pathlib.Path) -> dict[str, Any]:
"""loads config from a given file path"""
# try loading from file
try:
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
GLOBAL_CONFIG = unwrap(yaml.safe_load(config_file), {})
return unwrap(yaml.safe_load(config_file), {})
except FileNotFoundError:
logger.info("The config.yml file cannot be found")
except Exception as exc:
logger.error(
"The YAML config couldn't load because of the following error: "
f"\n\n{exc}"
"\n\nTabbyAPI will start anyway and not parse this config file."
f"The YAML config couldn't load because of the following error:\n\n{exc}"
)
GLOBAL_CONFIG = {}
# if no config file was loaded
return {}
def from_args(args: dict):
"""Overrides the config based on a dict representation of args"""
def from_args(args: dict[str, Any]) -> dict[str, Any]:
"""loads config from the provided arguments"""
config = {}
config_override = unwrap(args.get("options", {}).get("config"))
if config_override:
logger.info("Attempting to override config.yml from args.")
from_file(pathlib.Path(config_override))
return
logger.info("Config file override detected in args.")
config = from_file(pathlib.Path(config_override))
return config # Return early if loading from file
# Network config
network_override = args.get("network")
if network_override:
cur_network_config = network_config()
GLOBAL_CONFIG["network"] = {**cur_network_config, **network_override}
for key in ["network", "model", "logging", "developer", "embeddings"]:
override = args.get(key)
if override:
if key == "logging":
# Strip the "log_" prefix from logging keys if present
override = {k.replace("log_", ""): v for k, v in override.items()}
config[key] = override
# Model config
model_override = args.get("model")
if model_override:
cur_model_config = model_config()
GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override}
return config
# Generation Logging config
logging_override = args.get("logging")
if logging_override:
cur_logging_config = logging_config()
GLOBAL_CONFIG["logging"] = {
**cur_logging_config,
**{k.replace("log_", ""): logging_override[k] for k in logging_override},
}
developer_override = args.get("developer")
if developer_override:
cur_developer_config = developer_config()
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
def from_environment() -> dict[str, Any]:
"""loads configuration from environment variables"""
embeddings_override = args.get("embeddings")
if embeddings_override:
cur_embeddings_config = embeddings_config()
GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}
# TODO: load config from environment variables
# this means that we can have host default to 0.0.0.0 in docker for example
# this would also mean that docker containers no longer require a non
# default config file to be used
return {}
def sampling_config():

View file

@ -110,15 +110,13 @@ def entrypoint(arguments: Optional[dict] = None):
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
# Load from YAML config
config.from_file(pathlib.Path("config.yml"))
# Parse and override config from args
if arguments is None:
parser = init_argparser()
arguments = convert_args_to_dict(parser.parse_args(), parser)
config.from_args(arguments)
# load config
config.load(arguments)
if do_export_openapi:
openapi_json = export_openapi()

View file

@ -32,6 +32,7 @@ dependencies = [
"huggingface_hub",
"psutil",
"httptools>=0.5.0",
"mergedeep",
# Improved asyncio loops
"uvloop ; platform_system == 'Linux' and platform_machine == 'x86_64'",