tabbyAPI-ollama/common/config.py
kingbri 5a2de30066 Tree: Update to cleanup globals
Use the module singleton pattern to share global state. This can also
be a modified version of the Global Object Pattern. The main reason
this pattern is used is for ease of use when handling global state
rather than adding extra dependencies for a DI parameter.

Signed-off-by: kingbri <bdashore3@proton.me>
2024-03-12 23:59:30 -04:00

100 lines
3.1 KiB
Python

import yaml
import pathlib
from loguru import logger
from common.utils import unwrap
# Global config dictionary constant
GLOBAL_CONFIG: dict = {}
def from_file(config_path: pathlib.Path):
"""Sets the global config from a given file path"""
global GLOBAL_CONFIG
try:
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
GLOBAL_CONFIG = unwrap(yaml.safe_load(config_file), {})
except Exception as exc:
logger.error(
"The YAML config couldn't load because of the following error: "
f"\n\n{exc}"
"\n\nTabbyAPI will start anyway and not parse this config file."
)
GLOBAL_CONFIG = {}
def from_args(args: dict):
"""Overrides the config based on a dict representation of args"""
config_override = unwrap(args.get("options", {}).get("config"))
if config_override:
logger.info("Attempting to override config.yml from args.")
from_file(pathlib.Path(config_override))
return
# Network config
network_override = args.get("network")
if network_override:
cur_network_config = network_config()
GLOBAL_CONFIG["network"] = {**cur_network_config, **network_override}
# Model config
model_override = args.get("model")
if model_override:
cur_model_config = model_config()
GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override}
# Generation Logging config
gen_logging_override = args.get("logging")
if gen_logging_override:
cur_gen_logging_config = gen_logging_config()
GLOBAL_CONFIG["logging"] = {
**cur_gen_logging_config,
**{
k.replace("log_", ""): gen_logging_override[k]
for k in gen_logging_override
},
}
developer_override = args.get("developer")
if developer_override:
cur_developer_config = developer_config()
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
def sampling_config():
"""Returns the sampling parameter config from the global config"""
return unwrap(GLOBAL_CONFIG.get("sampling"), {})
def model_config():
"""Returns the model config from the global config"""
return unwrap(GLOBAL_CONFIG.get("model"), {})
def draft_model_config():
"""Returns the draft model config from the global config"""
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
return unwrap(model_config.get("draft"), {})
def lora_config():
"""Returns the lora config from the global config"""
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
return unwrap(model_config.get("lora"), {})
def network_config():
"""Returns the network config from the global config"""
return unwrap(GLOBAL_CONFIG.get("network"), {})
def gen_logging_config():
"""Returns the generation logging config from the global config"""
return unwrap(GLOBAL_CONFIG.get("logging"), {})
def developer_config():
"""Returns the developer specific config from the global config"""
return unwrap(GLOBAL_CONFIG.get("developer"), {})