From ac4d9bba1c799434d7b7510df9ac4b5ba7186fc3 Mon Sep 17 00:00:00 2001 From: Jake <84923604+SecretiveShell@users.noreply.github.com> Date: Wed, 4 Sep 2024 12:49:22 +0100 Subject: [PATCH] refactor config functions - improve DRY --- common/config.py | 51 ++++++++++++------------------------------------ 1 file changed, 12 insertions(+), 39 deletions(-) diff --git a/common/config.py b/common/config.py index d59e56b..2d1c980 100644 --- a/common/config.py +++ b/common/config.py @@ -72,43 +72,16 @@ def from_environment() -> dict[str, Any]: return {} -def sampling_config(): - """Returns the sampling parameter config from the global config""" - return unwrap(GLOBAL_CONFIG.get("sampling"), {}) +# refactor the get_config functions +def get_config(config: dict[str, any], topic: str) -> callable : + return lambda: unwrap(config.get(topic), {}) - -def model_config(): - """Returns the model config from the global config""" - return unwrap(GLOBAL_CONFIG.get("model"), {}) - - -def draft_model_config(): - """Returns the draft model config from the global config""" - model_config = unwrap(GLOBAL_CONFIG.get("model"), {}) - return unwrap(model_config.get("draft"), {}) - - -def lora_config(): - """Returns the lora config from the global config""" - model_config = unwrap(GLOBAL_CONFIG.get("model"), {}) - return unwrap(model_config.get("lora"), {}) - - -def network_config(): - """Returns the network config from the global config""" - return unwrap(GLOBAL_CONFIG.get("network"), {}) - - -def logging_config(): - """Returns the logging config from the global config""" - return unwrap(GLOBAL_CONFIG.get("logging"), {}) - - -def developer_config(): - """Returns the developer specific config from the global config""" - return unwrap(GLOBAL_CONFIG.get("developer"), {}) - - -def embeddings_config(): - """Returns the embeddings config from the global config""" - return unwrap(GLOBAL_CONFIG.get("embeddings"), {}) +# each of these is a function +model_config = get_config(GLOBAL_CONFIG, "model") +sampling_config = get_config(GLOBAL_CONFIG, "sampling") +draft_model_config = get_config(model_config(), "draft") +lora_config = get_config(model_config(), "lora") +network_config = get_config(GLOBAL_CONFIG, "network") +logging_config = get_config(GLOBAL_CONFIG, "logging") +developer_config = get_config(GLOBAL_CONFIG, "developer") +embeddings_config = get_config(GLOBAL_CONFIG, "embeddings")