refactor config functions
- improve DRY
This commit is contained in:
parent
fa6404a95a
commit
ac4d9bba1c
1 changed files with 12 additions and 39 deletions
|
|
@ -72,43 +72,16 @@ def from_environment() -> dict[str, Any]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
def sampling_config():
|
# refactor the get_config functions
|
||||||
"""Returns the sampling parameter config from the global config"""
|
def get_config(config: dict[str, any], topic: str) -> callable :
|
||||||
return unwrap(GLOBAL_CONFIG.get("sampling"), {})
|
return lambda: unwrap(config.get(topic), {})
|
||||||
|
|
||||||
|
# each of these is a function
|
||||||
def model_config():
|
model_config = get_config(GLOBAL_CONFIG, "model")
|
||||||
"""Returns the model config from the global config"""
|
sampling_config = get_config(GLOBAL_CONFIG, "sampling")
|
||||||
return unwrap(GLOBAL_CONFIG.get("model"), {})
|
draft_model_config = get_config(model_config(), "draft")
|
||||||
|
lora_config = get_config(model_config(), "lora")
|
||||||
|
network_config = get_config(GLOBAL_CONFIG, "network")
|
||||||
def draft_model_config():
|
logging_config = get_config(GLOBAL_CONFIG, "logging")
|
||||||
"""Returns the draft model config from the global config"""
|
developer_config = get_config(GLOBAL_CONFIG, "developer")
|
||||||
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
|
embeddings_config = get_config(GLOBAL_CONFIG, "embeddings")
|
||||||
return unwrap(model_config.get("draft"), {})
|
|
||||||
|
|
||||||
|
|
||||||
def lora_config():
|
|
||||||
"""Returns the lora config from the global config"""
|
|
||||||
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
|
|
||||||
return unwrap(model_config.get("lora"), {})
|
|
||||||
|
|
||||||
|
|
||||||
def network_config():
|
|
||||||
"""Returns the network config from the global config"""
|
|
||||||
return unwrap(GLOBAL_CONFIG.get("network"), {})
|
|
||||||
|
|
||||||
|
|
||||||
def logging_config():
|
|
||||||
"""Returns the logging config from the global config"""
|
|
||||||
return unwrap(GLOBAL_CONFIG.get("logging"), {})
|
|
||||||
|
|
||||||
|
|
||||||
def developer_config():
|
|
||||||
"""Returns the developer specific config from the global config"""
|
|
||||||
return unwrap(GLOBAL_CONFIG.get("developer"), {})
|
|
||||||
|
|
||||||
|
|
||||||
def embeddings_config():
|
|
||||||
"""Returns the embeddings config from the global config"""
|
|
||||||
return unwrap(GLOBAL_CONFIG.get("embeddings"), {})
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue