Config: Add override argparser
Add an argparser that casts over to dictionaries of subgroups to integrate with the config. This argparser doesn't contain everything in the config due to complexity issues with CLI args, but will eventually progress to parity. In addition, it's used to override the config.yml rather than replace it. A config arg is also provided if the user wants to fully override the config yaml with another file path. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
7176fa66f0
commit
bb7a8e4614
4 changed files with 173 additions and 9 deletions
122
args.py
Normal file
122
args.py
Normal file
|
|
@ -0,0 +1,122 @@
|
|||
"""Argparser for overriding config values"""
|
||||
import argparse
|
||||
|
||||
|
||||
def str_to_bool(value):
|
||||
"""Converts a string into a boolean value"""
|
||||
|
||||
if value.lower() in {"false", "f", "0", "no", "n"}:
|
||||
return False
|
||||
elif value.lower() in {"true", "t", "1", "yes", "y"}:
|
||||
return True
|
||||
raise ValueError(f"{value} is not a valid boolean value")
|
||||
|
||||
|
||||
def init_argparser():
|
||||
"""Creates an argument parser that any function can use"""
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
epilog="These args are only for a subset of the config. "
|
||||
+ "Please edit config.yml for all options!"
|
||||
)
|
||||
add_network_args(parser)
|
||||
add_model_args(parser)
|
||||
add_logging_args(parser)
|
||||
add_config_args(parser)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
"""Broad conversion of surface level arg groups to dictionaries"""
|
||||
|
||||
arg_groups = {}
|
||||
for group in parser._action_groups:
|
||||
group_dict = {}
|
||||
for arg in group._group_actions:
|
||||
value = getattr(args, arg.dest, None)
|
||||
if value is not None:
|
||||
group_dict[arg.dest] = value
|
||||
|
||||
arg_groups[group.title] = group_dict
|
||||
|
||||
return arg_groups
|
||||
|
||||
|
||||
def add_config_args(parser: argparse.ArgumentParser):
|
||||
"""Adds config arguments"""
|
||||
|
||||
parser.add_argument(
|
||||
"--config", type=str, help="Path to an overriding config.yml file"
|
||||
)
|
||||
|
||||
|
||||
def add_network_args(parser: argparse.ArgumentParser):
|
||||
"""Adds networking arguments"""
|
||||
|
||||
network_group = parser.add_argument_group("network")
|
||||
network_group.add_argument("--host", type=str, help="The IP to host on")
|
||||
network_group.add_argument("--port", type=int, help="The port to host on")
|
||||
network_group.add_argument(
|
||||
"--disable-auth",
|
||||
type=str_to_bool,
|
||||
help="Disable HTTP token authenticaion with requests",
|
||||
)
|
||||
|
||||
|
||||
def add_model_args(parser: argparse.ArgumentParser):
|
||||
"""Adds model arguments"""
|
||||
|
||||
model_group = parser.add_argument_group("model")
|
||||
model_group.add_argument(
|
||||
"--model-dir", type=str, help="Overrides the directory to look for models"
|
||||
)
|
||||
model_group.add_argument("--model-name", type=str, help="An initial model to load")
|
||||
model_group.add_argument(
|
||||
"--max-seq-len", type=int, help="Override the maximum model sequence length"
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--override-base-seq-len",
|
||||
type=str_to_bool,
|
||||
help="Overrides base model context length",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
|
||||
)
|
||||
model_group.add_argument("--rope-alpha", type=float, help="Sets rope_alpha for NTK")
|
||||
model_group.add_argument(
|
||||
"--prompt-template",
|
||||
type=str,
|
||||
help="Set the prompt template for chat completions",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--gpu-split-auto",
|
||||
type=str_to_bool,
|
||||
help="Automatically allocate resources to GPUs",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--gpu-split",
|
||||
type=float,
|
||||
nargs="+",
|
||||
help="An integer array of GBs of vram to split between GPUs. "
|
||||
+ "Ignored if gpu_split_auto is true",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--num-experts-per-token",
|
||||
type=int,
|
||||
help="Number of experts to use per token in MoE models",
|
||||
)
|
||||
|
||||
|
||||
def add_logging_args(parser: argparse.ArgumentParser):
|
||||
"""Adds logging arguments"""
|
||||
|
||||
logging_group = parser.add_argument_group("logging")
|
||||
logging_group.add_argument(
|
||||
"--log-prompt", type=str_to_bool, help="Enable prompt logging"
|
||||
)
|
||||
logging_group.add_argument(
|
||||
"--log-generation-params",
|
||||
type=str_to_bool,
|
||||
help="Enable generation parameter logging",
|
||||
)
|
||||
31
config.py
31
config.py
|
|
@ -25,6 +25,37 @@ def read_config_from_file(config_path: pathlib.Path):
|
|||
GLOBAL_CONFIG = {}
|
||||
|
||||
|
||||
def override_config_from_args(args: dict):
|
||||
"""Overrides the config based on a dict representation of args"""
|
||||
|
||||
config_override = unwrap(args.get("options", {}).get("config"))
|
||||
if config_override:
|
||||
logger.info("Attempting to override config.yml from args.")
|
||||
read_config_from_file(pathlib.Path(config_override))
|
||||
return
|
||||
|
||||
# Network config
|
||||
network_override = args.get("network")
|
||||
if network_override:
|
||||
network_config = get_network_config()
|
||||
GLOBAL_CONFIG["network"] = {**network_config, **network_override}
|
||||
|
||||
# Model config
|
||||
model_override = args.get("model")
|
||||
if model_override:
|
||||
model_config = get_model_config()
|
||||
GLOBAL_CONFIG["model"] = {**model_config, **model_override}
|
||||
|
||||
# Logging config
|
||||
logging_override = args.get("logging")
|
||||
if logging_override:
|
||||
logging_config = get_gen_logging_config()
|
||||
GLOBAL_CONFIG["logging"] = {
|
||||
**logging_config,
|
||||
**{k.replace("log_", ""): logging_override[k] for k in logging_override},
|
||||
}
|
||||
|
||||
|
||||
def get_model_config():
|
||||
"""Returns the model config from the global config"""
|
||||
return unwrap(GLOBAL_CONFIG.get("model"), {})
|
||||
|
|
|
|||
11
main.py
11
main.py
|
|
@ -12,8 +12,10 @@ from functools import partial
|
|||
from progress.bar import IncrementalBar
|
||||
|
||||
import gen_logging
|
||||
from args import convert_args_to_dict, init_argparser
|
||||
from auth import check_admin_key, check_api_key, load_auth_keys
|
||||
from config import (
|
||||
override_config_from_args,
|
||||
read_config_from_file,
|
||||
get_gen_logging_config,
|
||||
get_model_config,
|
||||
|
|
@ -493,13 +495,20 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
|
|||
return response
|
||||
|
||||
|
||||
def entrypoint():
|
||||
def entrypoint(args: Optional[dict] = None):
|
||||
"""Entry function for program startup"""
|
||||
global MODEL_CONTAINER
|
||||
|
||||
# Load from YAML config
|
||||
read_config_from_file(pathlib.Path("config.yml"))
|
||||
|
||||
# Parse and override config from args
|
||||
if args is None:
|
||||
parser = init_argparser()
|
||||
args = convert_args_to_dict(parser.parse_args(), parser)
|
||||
|
||||
override_config_from_args(args)
|
||||
|
||||
network_config = get_network_config()
|
||||
|
||||
# Initialize auth keys
|
||||
|
|
|
|||
18
start.py
18
start.py
|
|
@ -3,6 +3,7 @@ import argparse
|
|||
import os
|
||||
import pathlib
|
||||
import subprocess
|
||||
from args import convert_args_to_dict, init_argparser
|
||||
|
||||
|
||||
def get_requirements_file():
|
||||
|
|
@ -24,28 +25,29 @@ def get_requirements_file():
|
|||
return requirements_name
|
||||
|
||||
|
||||
def get_argparser():
|
||||
"""Fetches the argparser for this script"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
def add_start_args(parser: argparse.ArgumentParser):
|
||||
"""Add start script args to the provided parser"""
|
||||
start_group = parser.add_argument_group("start")
|
||||
start_group.add_argument(
|
||||
"-iu",
|
||||
"--ignore-upgrade",
|
||||
action="store_true",
|
||||
help="Ignore requirements upgrade",
|
||||
)
|
||||
parser.add_argument(
|
||||
start_group.add_argument(
|
||||
"-nw",
|
||||
"--nowheel",
|
||||
action="store_true",
|
||||
help="Don't upgrade wheel dependencies (exllamav2, torch)",
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
subprocess.run(["pip", "-V"])
|
||||
|
||||
parser = get_argparser()
|
||||
# Create an argparser and add extra startup script args
|
||||
parser = init_argparser()
|
||||
add_start_args(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.ignore_upgrade:
|
||||
|
|
@ -59,4 +61,4 @@ if __name__ == "__main__":
|
|||
# Import entrypoint after installing all requirements
|
||||
from main import entrypoint
|
||||
|
||||
entrypoint()
|
||||
entrypoint(convert_args_to_dict(args, parser))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue