87 lines
2.9 KiB
Python
87 lines
2.9 KiB
Python
import pathlib
from collections.abc import Callable
from typing import Any

import yaml
from loguru import logger
from mergedeep import merge, Strategy

from common.utils import unwrap
|
|
|
|
# Module-wide configuration mapping; starts out empty until load() is called
GLOBAL_CONFIG: dict[str, Any] = {}
|
|
|
|
|
|
def load(arguments: dict[str, Any]):
    """Load the global application config into GLOBAL_CONFIG.

    The config is assembled from three sources, applied in this order:
    the ``config.yml`` file, environment variables, and the provided
    arguments.

    Args:
        arguments: Argument mapping forwarded to ``from_args``.
    """

    # config is applied in order of items in the list
    configs = [
        from_file(pathlib.Path("config.yml")),
        from_environment(),
        from_args(arguments),
    ]

    merged = merge({}, *configs, strategy=Strategy.REPLACE)

    # Update the existing dict in place instead of rebinding the name.
    # Anything that grabbed a reference to GLOBAL_CONFIG before load() ran
    # (at import time, or via `from ... import GLOBAL_CONFIG`) would
    # otherwise keep reading the stale, empty dict.
    GLOBAL_CONFIG.clear()
    GLOBAL_CONFIG.update(merged)
|
|
|
|
def from_file(config_path: pathlib.Path) -> dict[str, Any]:
    """Load config from a YAML file at the given path.

    Loading is best-effort: a missing file, a read/parse failure, or a
    document whose root is not a mapping is logged and results in an
    empty config rather than an exception.

    Args:
        config_path: Location of the YAML config file.

    Returns:
        The parsed config mapping, or an empty dict if nothing usable
        could be loaded.
    """

    # try loading from file
    try:
        with config_path.open("r", encoding="utf8") as config_file:
            # an empty document parses to None; unwrap is given {} as fallback
            loaded = unwrap(yaml.safe_load(config_file), {})
    except FileNotFoundError:
        # report the path actually requested (it may be an override,
        # not necessarily config.yml)
        logger.info(f"The {config_path} file cannot be found")
    except Exception as exc:
        logger.error(
            f"The YAML config couldn't load because of the following error:\n\n{exc}"
        )
    else:
        if isinstance(loaded, dict):
            return loaded

        # a scalar or list at the root cannot be merged with the other
        # config sources, so treat it like any other unusable file
        logger.error(
            f"The YAML config in {config_path} must be a mapping at the top "
            f"level, but a {type(loaded).__name__} was found. Ignoring it."
        )

    # if no config file was loaded
    return {}
|
|
|
|
|
|
def from_args(args: dict[str, Any]) -> dict[str, Any]:
    """Load config from the provided arguments.

    If ``args["options"]["config"]`` names a config file, that file is
    loaded and returned as-is; the remaining arguments are not consulted.
    Otherwise every non-empty section among network, model, logging,
    developer and embeddings is copied into the result.

    Args:
        args: Argument mapping, keyed by section name.

    Returns:
        The config derived from the arguments (possibly empty).
    """

    # "options" may be absent or explicitly None; fall back to {} either way
    options = unwrap(args.get("options"), {})
    config_override = unwrap(options.get("config"))
    if config_override:
        logger.info("Config file override detected in args.")
        # Return early if loading from file
        return from_file(pathlib.Path(config_override))

    config = {}
    for key in ("network", "model", "logging", "developer", "embeddings"):
        override = args.get(key)
        if not override:
            continue

        if key == "logging":
            # Strip only a leading "log_" from logging keys; occurrences
            # elsewhere in the key are part of the name and must survive
            override = {k.removeprefix("log_"): v for k, v in override.items()}

        config[key] = override

    return config
|
|
|
|
|
|
def from_environment() -> dict[str, Any]:
    """Build the config layer that comes from environment variables.

    Not implemented yet: a fresh, empty mapping is returned on every call.
    """

    # TODO: load config from environment variables
    # this means that we can have host default to 0.0.0.0 in docker for example
    # this would also mean that docker containers no longer require a non
    # default config file to be used
    environment_config: dict[str, Any] = {}
    return environment_config
|
|
|
|
|
|
# refactor the get_config functions
def get_config(
    # quoted so the union is not evaluated at runtime on older interpreters
    config: "dict[str, Any] | Callable[[], dict[str, Any]]",
    topic: str,
) -> Callable[[], Any]:
    """Create an accessor for one topic of a config mapping.

    Args:
        config: Either the mapping to read from, or a zero-argument
            callable returning that mapping. A callable is resolved on
            every accessor call, so the accessor always sees the current
            mapping rather than the one that existed when it was created.
        topic: Key to look up in the mapping.

    Returns:
        A zero-argument function returning the value stored under
        ``topic``, with ``{}`` passed to ``unwrap`` as the fallback.
    """

    def getter() -> Any:
        source = config() if callable(config) else config
        return unwrap(source.get(topic), {})

    return getter


# each of these is a function
# Top-level sections look GLOBAL_CONFIG up at call time, so they keep working
# whether the global is mutated in place or rebound to a new dict.
model_config = get_config(lambda: GLOBAL_CONFIG, "model")
sampling_config = get_config(lambda: GLOBAL_CONFIG, "sampling")
# Nested sections pass the parent accessor itself (not its result), so the
# model section is re-read on every call instead of being frozen at import.
draft_model_config = get_config(model_config, "draft")
lora_config = get_config(model_config, "lora")
network_config = get_config(lambda: GLOBAL_CONFIG, "network")
logging_config = get_config(lambda: GLOBAL_CONFIG, "logging")
developer_config = get_config(lambda: GLOBAL_CONFIG, "developer")
embeddings_config = get_config(lambda: GLOBAL_CONFIG, "embeddings")
|