diff --git a/common/config.py b/common/config.py index d59e56b..2d1c980 100644 --- a/common/config.py +++ b/common/config.py @@ -72,43 +72,16 @@ def from_environment() -> dict[str, Any]: return {} -def sampling_config(): - """Returns the sampling parameter config from the global config""" - return unwrap(GLOBAL_CONFIG.get("sampling"), {}) +# refactor the get_config functions +def get_config(config: dict[str, any], topic: str) -> callable : + return lambda: unwrap(config.get(topic), {}) - -def model_config(): - """Returns the model config from the global config""" - return unwrap(GLOBAL_CONFIG.get("model"), {}) - - -def draft_model_config(): - """Returns the draft model config from the global config""" - model_config = unwrap(GLOBAL_CONFIG.get("model"), {}) - return unwrap(model_config.get("draft"), {}) - - -def lora_config(): - """Returns the lora config from the global config""" - model_config = unwrap(GLOBAL_CONFIG.get("model"), {}) - return unwrap(model_config.get("lora"), {}) - - -def network_config(): - """Returns the network config from the global config""" - return unwrap(GLOBAL_CONFIG.get("network"), {}) - - -def logging_config(): - """Returns the logging config from the global config""" - return unwrap(GLOBAL_CONFIG.get("logging"), {}) - - -def developer_config(): - """Returns the developer specific config from the global config""" - return unwrap(GLOBAL_CONFIG.get("developer"), {}) - - -def embeddings_config(): - """Returns the embeddings config from the global config""" - return unwrap(GLOBAL_CONFIG.get("embeddings"), {}) +# each of these is a function +model_config = get_config(GLOBAL_CONFIG, "model") +sampling_config = get_config(GLOBAL_CONFIG, "sampling") +draft_model_config = get_config(model_config(), "draft") +lora_config = get_config(model_config(), "lora") +network_config = get_config(GLOBAL_CONFIG, "network") +logging_config = get_config(GLOBAL_CONFIG, "logging") +developer_config = get_config(GLOBAL_CONFIG, "developer") +embeddings_config = get_config(GLOBAL_CONFIG, "embeddings")