Config: Add override argparser

Add an argparser that casts over to dictionaries of subgroups to
integrate with the config.

This argparser doesn't contain everything in the config due to complexity
issues with CLI args, but will eventually progress to parity. In addition,
it's used to override the config.yml rather than replace it.

A config arg is also provided if the user wants to fully override the
config yaml with another file path.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-01-01 14:27:12 -05:00
parent 7176fa66f0
commit bb7a8e4614
4 changed files with 173 additions and 9 deletions

122
args.py Normal file
View file

@ -0,0 +1,122 @@
"""Argparser for overriding config values"""
import argparse
def str_to_bool(value):
"""Converts a string into a boolean value"""
if value.lower() in {"false", "f", "0", "no", "n"}:
return False
elif value.lower() in {"true", "t", "1", "yes", "y"}:
return True
raise ValueError(f"{value} is not a valid boolean value")
def init_argparser():
"""Creates an argument parser that any function can use"""
parser = argparse.ArgumentParser(
epilog="These args are only for a subset of the config. "
+ "Please edit config.yml for all options!"
)
add_network_args(parser)
add_model_args(parser)
add_logging_args(parser)
add_config_args(parser)
return parser
def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
"""Broad conversion of surface level arg groups to dictionaries"""
arg_groups = {}
for group in parser._action_groups:
group_dict = {}
for arg in group._group_actions:
value = getattr(args, arg.dest, None)
if value is not None:
group_dict[arg.dest] = value
arg_groups[group.title] = group_dict
return arg_groups
def add_config_args(parser: argparse.ArgumentParser):
"""Adds config arguments"""
parser.add_argument(
"--config", type=str, help="Path to an overriding config.yml file"
)
def add_network_args(parser: argparse.ArgumentParser):
"""Adds networking arguments"""
network_group = parser.add_argument_group("network")
network_group.add_argument("--host", type=str, help="The IP to host on")
network_group.add_argument("--port", type=int, help="The port to host on")
network_group.add_argument(
"--disable-auth",
type=str_to_bool,
help="Disable HTTP token authenticaion with requests",
)
def add_model_args(parser: argparse.ArgumentParser):
"""Adds model arguments"""
model_group = parser.add_argument_group("model")
model_group.add_argument(
"--model-dir", type=str, help="Overrides the directory to look for models"
)
model_group.add_argument("--model-name", type=str, help="An initial model to load")
model_group.add_argument(
"--max-seq-len", type=int, help="Override the maximum model sequence length"
)
model_group.add_argument(
"--override-base-seq-len",
type=str_to_bool,
help="Overrides base model context length",
)
model_group.add_argument(
"--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
)
model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK")
model_group.add_argument(
"--prompt-template",
type=str,
help="Set the prompt template for chat completions",
)
model_group.add_argument(
"--gpu-split-auto",
type=str_to_bool,
help="Automatically allocate resources to GPUs",
)
model_group.add_argument(
"--gpu-split",
type=float,
nargs="+",
help="An integer array of GBs of vram to split between GPUs. "
+ "Ignored if gpu_split_auto is true",
)
model_group.add_argument(
"--num-experts-per-token",
type=int,
help="Number of experts to use per token in MoE models",
)
def add_logging_args(parser: argparse.ArgumentParser):
"""Adds logging arguments"""
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
"--log-prompt", type=str_to_bool, help="Enable prompt logging"
)
logging_group.add_argument(
"--log-generation-params",
type=str_to_bool,
help="Enable generation parameter logging",
)

View file

@ -25,6 +25,37 @@ def read_config_from_file(config_path: pathlib.Path):
GLOBAL_CONFIG = {}
def override_config_from_args(args: dict):
"""Overrides the config based on a dict representation of args"""
config_override = unwrap(args.get("options", {}).get("config"))
if config_override:
logger.info("Attempting to override config.yml from args.")
read_config_from_file(pathlib.Path(config_override))
return
# Network config
network_override = args.get("network")
if network_override:
network_config = get_network_config()
GLOBAL_CONFIG["network"] = {**network_config, **network_override}
# Model config
model_override = args.get("model")
if model_override:
model_config = get_model_config()
GLOBAL_CONFIG["model"] = {**model_config, **model_override}
# Logging config
logging_override = args.get("logging")
if logging_override:
logging_config = get_gen_logging_config()
GLOBAL_CONFIG["logging"] = {
**logging_config,
**{k.replace("log_", ""): logging_override[k] for k in logging_override},
}
def get_model_config():
"""Returns the model config from the global config"""
return unwrap(GLOBAL_CONFIG.get("model"), {})

11
main.py
View file

@ -12,8 +12,10 @@ from functools import partial
from progress.bar import IncrementalBar
import gen_logging
from args import convert_args_to_dict, init_argparser
from auth import check_admin_key, check_api_key, load_auth_keys
from config import (
override_config_from_args,
read_config_from_file,
get_gen_logging_config,
get_model_config,
@ -493,13 +495,20 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
return response
def entrypoint():
def entrypoint(args: Optional[dict] = None):
"""Entry function for program startup"""
global MODEL_CONTAINER
# Load from YAML config
read_config_from_file(pathlib.Path("config.yml"))
# Parse and override config from args
if args is None:
parser = init_argparser()
args = convert_args_to_dict(parser.parse_args(), parser)
override_config_from_args(args)
network_config = get_network_config()
# Initialize auth keys

View file

@ -3,6 +3,7 @@ import argparse
import os
import pathlib
import subprocess
from args import convert_args_to_dict, init_argparser
def get_requirements_file():
@ -24,28 +25,29 @@ def get_requirements_file():
return requirements_name
def get_argparser():
"""Fetches the argparser for this script"""
parser = argparse.ArgumentParser()
parser.add_argument(
def add_start_args(parser: argparse.ArgumentParser):
"""Add start script args to the provided parser"""
start_group = parser.add_argument_group("start")
start_group.add_argument(
"-iu",
"--ignore-upgrade",
action="store_true",
help="Ignore requirements upgrade",
)
parser.add_argument(
start_group.add_argument(
"-nw",
"--nowheel",
action="store_true",
help="Don't upgrade wheel dependencies (exllamav2, torch)",
)
return parser
if __name__ == "__main__":
subprocess.run(["pip", "-V"])
parser = get_argparser()
# Create an argparser and add extra startup script args
parser = init_argparser()
add_start_args(parser)
args = parser.parse_args()
if args.ignore_upgrade:
@ -59,4 +61,4 @@ if __name__ == "__main__":
# Import entrypoint after installing all requirements
from main import entrypoint
entrypoint()
entrypoint(convert_args_to_dict(args, parser))