refactor config loading
- improve DRY - alter logging - allow extensibility - add foundation for environment variables as config
This commit is contained in:
parent
8854269121
commit
fa6404a95a
3 changed files with 48 additions and 42 deletions
|
|
@ -1,6 +1,8 @@
|
|||
import yaml
|
||||
import pathlib
|
||||
from loguru import logger
|
||||
from mergedeep import merge, Strategy
|
||||
from typing import Any
|
||||
|
||||
from common.utils import unwrap
|
||||
|
||||
|
|
@ -8,61 +10,66 @@ from common.utils import unwrap
|
|||
GLOBAL_CONFIG: dict = {}
|
||||
|
||||
|
||||
def from_file(config_path: pathlib.Path):
|
||||
"""Sets the global config from a given file path"""
|
||||
def load(arguments: dict[str, Any]):
|
||||
"""load the global application config"""
|
||||
global GLOBAL_CONFIG
|
||||
|
||||
# config is applied in order of items in the list
|
||||
configs = [
|
||||
from_file(pathlib.Path("config.yml")),
|
||||
from_environment(),
|
||||
from_args(arguments),
|
||||
]
|
||||
|
||||
GLOBAL_CONFIG = merge({}, *configs, strategy=Strategy.REPLACE)
|
||||
|
||||
def from_file(config_path: pathlib.Path) -> dict[str, Any]:
|
||||
"""loads config from a given file path"""
|
||||
|
||||
# try loading from file
|
||||
try:
|
||||
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
|
||||
GLOBAL_CONFIG = unwrap(yaml.safe_load(config_file), {})
|
||||
return unwrap(yaml.safe_load(config_file), {})
|
||||
except FileNotFoundError:
|
||||
logger.info("The config.yml file cannot be found")
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"The YAML config couldn't load because of the following error: "
|
||||
f"\n\n{exc}"
|
||||
"\n\nTabbyAPI will start anyway and not parse this config file."
|
||||
f"The YAML config couldn't load because of the following error:\n\n{exc}"
|
||||
)
|
||||
GLOBAL_CONFIG = {}
|
||||
|
||||
# if no config file was loaded
|
||||
return {}
|
||||
|
||||
|
||||
def from_args(args: dict):
|
||||
"""Overrides the config based on a dict representation of args"""
|
||||
def from_args(args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""loads config from the provided arguments"""
|
||||
config = {}
|
||||
|
||||
config_override = unwrap(args.get("options", {}).get("config"))
|
||||
if config_override:
|
||||
logger.info("Attempting to override config.yml from args.")
|
||||
from_file(pathlib.Path(config_override))
|
||||
return
|
||||
logger.info("Config file override detected in args.")
|
||||
config = from_file(pathlib.Path(config_override))
|
||||
return config # Return early if loading from file
|
||||
|
||||
# Network config
|
||||
network_override = args.get("network")
|
||||
if network_override:
|
||||
cur_network_config = network_config()
|
||||
GLOBAL_CONFIG["network"] = {**cur_network_config, **network_override}
|
||||
for key in ["network", "model", "logging", "developer", "embeddings"]:
|
||||
override = args.get(key)
|
||||
if override:
|
||||
if key == "logging":
|
||||
# Strip the "log_" prefix from logging keys if present
|
||||
override = {k.replace("log_", ""): v for k, v in override.items()}
|
||||
config[key] = override
|
||||
|
||||
# Model config
|
||||
model_override = args.get("model")
|
||||
if model_override:
|
||||
cur_model_config = model_config()
|
||||
GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override}
|
||||
return config
|
||||
|
||||
# Generation Logging config
|
||||
logging_override = args.get("logging")
|
||||
if logging_override:
|
||||
cur_logging_config = logging_config()
|
||||
GLOBAL_CONFIG["logging"] = {
|
||||
**cur_logging_config,
|
||||
**{k.replace("log_", ""): logging_override[k] for k in logging_override},
|
||||
}
|
||||
|
||||
developer_override = args.get("developer")
|
||||
if developer_override:
|
||||
cur_developer_config = developer_config()
|
||||
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
|
||||
def from_environment() -> dict[str, Any]:
|
||||
"""loads configuration from environment variables"""
|
||||
|
||||
embeddings_override = args.get("embeddings")
|
||||
if embeddings_override:
|
||||
cur_embeddings_config = embeddings_config()
|
||||
GLOBAL_CONFIG["embeddings"] = {**cur_embeddings_config, **embeddings_override}
|
||||
# TODO: load config from environment variables
|
||||
# this means that we can have host default to 0.0.0.0 in docker for example
|
||||
# this would also mean that docker containers no longer require a non
|
||||
# default config file to be used
|
||||
return {}
|
||||
|
||||
|
||||
def sampling_config():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue