Use the module singleton pattern to share global state. This can also be a modified version of the Global Object Pattern. The main reason this pattern is used is for ease of use when handling global state rather than adding extra dependencies for a DI parameter. Signed-off-by: kingbri <bdashore3@proton.me>
100 lines
3.1 KiB
Python
100 lines
3.1 KiB
Python
import yaml
|
|
import pathlib
|
|
from loguru import logger
|
|
|
|
from common.utils import unwrap
|
|
|
|
# Global config dictionary constant
|
|
GLOBAL_CONFIG: dict = {}
|
|
|
|
|
|
def from_file(config_path: pathlib.Path):
|
|
"""Sets the global config from a given file path"""
|
|
global GLOBAL_CONFIG
|
|
|
|
try:
|
|
with open(str(config_path.resolve()), "r", encoding="utf8") as config_file:
|
|
GLOBAL_CONFIG = unwrap(yaml.safe_load(config_file), {})
|
|
except Exception as exc:
|
|
logger.error(
|
|
"The YAML config couldn't load because of the following error: "
|
|
f"\n\n{exc}"
|
|
"\n\nTabbyAPI will start anyway and not parse this config file."
|
|
)
|
|
GLOBAL_CONFIG = {}
|
|
|
|
|
|
def from_args(args: dict):
|
|
"""Overrides the config based on a dict representation of args"""
|
|
|
|
config_override = unwrap(args.get("options", {}).get("config"))
|
|
if config_override:
|
|
logger.info("Attempting to override config.yml from args.")
|
|
from_file(pathlib.Path(config_override))
|
|
return
|
|
|
|
# Network config
|
|
network_override = args.get("network")
|
|
if network_override:
|
|
cur_network_config = network_config()
|
|
GLOBAL_CONFIG["network"] = {**cur_network_config, **network_override}
|
|
|
|
# Model config
|
|
model_override = args.get("model")
|
|
if model_override:
|
|
cur_model_config = model_config()
|
|
GLOBAL_CONFIG["model"] = {**cur_model_config, **model_override}
|
|
|
|
# Generation Logging config
|
|
gen_logging_override = args.get("logging")
|
|
if gen_logging_override:
|
|
cur_gen_logging_config = gen_logging_config()
|
|
GLOBAL_CONFIG["logging"] = {
|
|
**cur_gen_logging_config,
|
|
**{
|
|
k.replace("log_", ""): gen_logging_override[k]
|
|
for k in gen_logging_override
|
|
},
|
|
}
|
|
|
|
developer_override = args.get("developer")
|
|
if developer_override:
|
|
cur_developer_config = developer_config()
|
|
GLOBAL_CONFIG["developer"] = {**cur_developer_config, **developer_override}
|
|
|
|
|
|
def sampling_config():
|
|
"""Returns the sampling parameter config from the global config"""
|
|
return unwrap(GLOBAL_CONFIG.get("sampling"), {})
|
|
|
|
|
|
def model_config():
|
|
"""Returns the model config from the global config"""
|
|
return unwrap(GLOBAL_CONFIG.get("model"), {})
|
|
|
|
|
|
def draft_model_config():
|
|
"""Returns the draft model config from the global config"""
|
|
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
|
|
return unwrap(model_config.get("draft"), {})
|
|
|
|
|
|
def lora_config():
|
|
"""Returns the lora config from the global config"""
|
|
model_config = unwrap(GLOBAL_CONFIG.get("model"), {})
|
|
return unwrap(model_config.get("lora"), {})
|
|
|
|
|
|
def network_config():
|
|
"""Returns the network config from the global config"""
|
|
return unwrap(GLOBAL_CONFIG.get("network"), {})
|
|
|
|
|
|
def gen_logging_config():
|
|
"""Returns the generation logging config from the global config"""
|
|
return unwrap(GLOBAL_CONFIG.get("logging"), {})
|
|
|
|
|
|
def developer_config():
|
|
"""Returns the developer specific config from the global config"""
|
|
return unwrap(GLOBAL_CONFIG.get("developer"), {})
|