diff --git a/args.py b/args.py new file mode 100644 index 0000000..1aa2531 --- /dev/null +++ b/args.py @@ -0,0 +1,122 @@ +"""Argparser for overriding config values""" +import argparse + + +def str_to_bool(value): + """Converts a string into a boolean value""" + + if value.lower() in {"false", "f", "0", "no", "n"}: + return False + elif value.lower() in {"true", "t", "1", "yes", "y"}: + return True + raise ValueError(f"{value} is not a valid boolean value") + + +def init_argparser(): + """Creates an argument parser that any function can use""" + + parser = argparse.ArgumentParser( + epilog="These args are only for a subset of the config. " + + "Please edit config.yml for all options!" + ) + add_network_args(parser) + add_model_args(parser) + add_logging_args(parser) + add_config_args(parser) + + return parser + + +def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser): + """Broad conversion of surface level arg groups to dictionaries""" + + arg_groups = {} + for group in parser._action_groups: + group_dict = {} + for arg in group._group_actions: + value = getattr(args, arg.dest, None) + if value is not None: + group_dict[arg.dest] = value + + arg_groups[group.title] = group_dict + + return arg_groups + + +def add_config_args(parser: argparse.ArgumentParser): + """Adds config arguments""" + + parser.add_argument( + "--config", type=str, help="Path to an overriding config.yml file" + ) + + +def add_network_args(parser: argparse.ArgumentParser): + """Adds networking arguments""" + + network_group = parser.add_argument_group("network") + network_group.add_argument("--host", type=str, help="The IP to host on") + network_group.add_argument("--port", type=int, help="The port to host on") + network_group.add_argument( + "--disable-auth", + type=str_to_bool, + help="Disable HTTP token authenticaion with requests", + ) + + +def add_model_args(parser: argparse.ArgumentParser): + """Adds model arguments""" + + model_group = parser.add_argument_group("model") + model_group.add_argument( + "--model-dir", type=str, help="Overrides the directory to look for models" + ) + model_group.add_argument("--model-name", type=str, help="An initial model to load") + model_group.add_argument( + "--max-seq-len", type=int, help="Override the maximum model sequence length" + ) + model_group.add_argument( + "--override-base-seq-len", + type=str_to_bool, + help="Overrides base model context length", + ) + model_group.add_argument( + "--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb" + ) + model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK") + model_group.add_argument( + "--prompt-template", + type=str, + help="Set the prompt template for chat completions", + ) + model_group.add_argument( + "--gpu-split-auto", + type=str_to_bool, + help="Automatically allocate resources to GPUs", + ) + model_group.add_argument( + "--gpu-split", + type=float, + nargs="+", + help="An integer array of GBs of vram to split between GPUs. " + + "Ignored if gpu_split_auto is true", + ) + model_group.add_argument( + "--num-experts-per-token", + type=int, + help="Number of experts to use per token in MoE models", + ) + + +def add_logging_args(parser: argparse.ArgumentParser): + """Adds logging arguments""" + + logging_group = parser.add_argument_group("logging") + logging_group.add_argument( + "--log-prompt", type=str_to_bool, help="Enable prompt logging" + ) + logging_group.add_argument( + "--log-generation-params", + type=str_to_bool, + help="Enable generation parameter logging", + ) diff --git a/config.py b/config.py index e65edfa..178977b 100644 --- a/config.py +++ b/config.py @@ -25,6 +25,37 @@ def read_config_from_file(config_path: pathlib.Path): GLOBAL_CONFIG = {} +def override_config_from_args(args: dict): + """Overrides the config based on a dict representation of args""" + + config_override = unwrap(args.get("options", {}).get("config")) + if config_override: + logger.info("Attempting to override config.yml from args.") + read_config_from_file(pathlib.Path(config_override)) + return + + # Network config + network_override = args.get("network") + if network_override: + network_config = get_network_config() + GLOBAL_CONFIG["network"] = {**network_config, **network_override} + + # Model config + model_override = args.get("model") + if model_override: + model_config = get_model_config() + GLOBAL_CONFIG["model"] = {**model_config, **model_override} + + # Logging config + logging_override = args.get("logging") + if logging_override: + logging_config = get_gen_logging_config() + GLOBAL_CONFIG["logging"] = { + **logging_config, + **{k.replace("log_", ""): logging_override[k] for k in logging_override}, + } + + def get_model_config(): """Returns the model config from the global config""" return unwrap(GLOBAL_CONFIG.get("model"), {}) diff --git a/main.py b/main.py index fc7c5d3..9b0e0c7 100644 --- a/main.py +++ b/main.py @@ -12,8 +12,10 @@ from functools import partial from progress.bar import IncrementalBar import gen_logging +from args import convert_args_to_dict, init_argparser from auth import check_admin_key, check_api_key, load_auth_keys from config import ( + override_config_from_args, read_config_from_file, get_gen_logging_config, get_model_config, @@ -493,13 +495,20 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest return response -def entrypoint(): +def entrypoint(args: Optional[dict] = None): """Entry function for program startup""" global MODEL_CONTAINER # Load from YAML config read_config_from_file(pathlib.Path("config.yml")) + # Parse and override config from args + if args is None: + parser = init_argparser() + args = convert_args_to_dict(parser.parse_args(), parser) + + override_config_from_args(args) + network_config = get_network_config() # Initialize auth keys diff --git a/start.py b/start.py index 832210f..8fe34fa 100644 --- a/start.py +++ b/start.py @@ -3,6 +3,7 @@ import argparse import os import pathlib import subprocess +from args import convert_args_to_dict, init_argparser def get_requirements_file(): @@ -24,28 +25,29 @@ def get_requirements_file(): return requirements_name -def get_argparser(): - """Fetches the argparser for this script""" - parser = argparse.ArgumentParser() - parser.add_argument( +def add_start_args(parser: argparse.ArgumentParser): + """Add start script args to the provided parser""" + start_group = parser.add_argument_group("start") + start_group.add_argument( "-iu", "--ignore-upgrade", action="store_true", help="Ignore requirements upgrade", ) - parser.add_argument( + start_group.add_argument( "-nw", "--nowheel", action="store_true", help="Don't upgrade wheel dependencies (exllamav2, torch)", ) - return parser if __name__ == "__main__": subprocess.run(["pip", "-V"]) - parser = get_argparser() + # Create an argparser and add extra startup script args + parser = init_argparser() + add_start_args(parser) args = parser.parse_args() if args.ignore_upgrade: @@ -59,4 +61,4 @@ if __name__ == "__main__": # Import entrypoint after installing all requirements from main import entrypoint - entrypoint() + entrypoint(convert_args_to_dict(args, parser))