automate arg parse

- generate arg parser dynamically
- remove legavy parser code
This commit is contained in:
Jake 2024-09-06 00:27:53 +01:00
parent 362b8d5818
commit 36e991c16e

View file

@ -1,7 +1,9 @@
"""Argparser for overriding config values"""
import argparse
from typing import get_origin, get_args, Optional, Union, List
from pydantic import BaseModel
from common.tabby_config import config
def str_to_bool(value):
"""Converts a string into a boolean value"""
@ -32,24 +34,40 @@ def argument_with_auto(value):
def init_argparser():
"""Creates an argument parser that any function can use"""
parser = argparse.ArgumentParser(description="TabbyAPI server")
parser = argparse.ArgumentParser(
epilog="NOTE: These args serve to override parts of the config. "
+ "It's highly recommended to edit config.yml for all options and "
+ "better descriptions!"
)
add_network_args(parser)
add_model_args(parser)
add_embeddings_args(parser)
add_logging_args(parser)
add_developer_args(parser)
add_sampling_args(parser)
add_config_args(parser)
# Loop through the fields in the top-level model (ModelX in this case)
for field_name, field_type in config.__annotations__.items():
# Get the sub-model type (e.g., ModelA, ModelB)
sub_model = field_type.__base__
# Create argument group for the sub-model
group = parser.add_argument_group(field_name, description=f"Arguments for {field_name}")
# Loop through each field in the sub-model (e.g., ModelA, ModelB)
for sub_field_name, sub_field_type in field_type.__annotations__.items():
field = field_type.__fields__[sub_field_name]
help_text = field.description if field.description else "No description available"
# Handle Optional types or other generic types
origin = get_origin(sub_field_type)
if origin is Union: # Check if the type is Union (which includes Optional)
sub_field_type = next(t for t in get_args(sub_field_type) if t is not type(None))
elif origin is List : sub_field_type = get_args(sub_field_type)[0]
# Map Pydantic types to argparse types
print(sub_field_type, type(sub_field_type))
if isinstance(sub_field_type, type) and issubclass(sub_field_type, (int, float, str, bool)):
arg_type = sub_field_type
else:
arg_type = str # Default to string for unknown types
# Add the argument for each field in the sub-model
group.add_argument(f"--{sub_field_name}", type=arg_type, help=help_text)
return parser
def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
"""Broad conversion of surface level arg groups to dictionaries"""
@ -63,202 +81,4 @@ def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentPars
arg_groups[group.title] = group_dict
return arg_groups
def add_config_args(parser: argparse.ArgumentParser):
"""Adds config arguments"""
parser.add_argument(
"--config", type=str, help="Path to an overriding config.yml file"
)
def add_network_args(parser: argparse.ArgumentParser):
"""Adds networking arguments"""
network_group = parser.add_argument_group("network")
network_group.add_argument("--host", type=str, help="The IP to host on")
network_group.add_argument("--port", type=int, help="The port to host on")
network_group.add_argument(
"--disable-auth",
type=str_to_bool,
help="Disable HTTP token authenticaion with requests",
)
network_group.add_argument(
"--send-tracebacks",
type=str_to_bool,
help="Decide whether to send error tracebacks over the API",
)
network_group.add_argument(
"--api-servers",
type=str,
nargs="+",
help="API servers to enable. Options: (OAI, Kobold)",
)
def add_model_args(parser: argparse.ArgumentParser):
"""Adds model arguments"""
model_group = parser.add_argument_group("model")
model_group.add_argument(
"--model-dir", type=str, help="Overrides the directory to look for models"
)
model_group.add_argument("--model-name", type=str, help="An initial model to load")
model_group.add_argument(
"--use-dummy-models",
type=str_to_bool,
help="Add dummy OAI model names for API queries",
)
model_group.add_argument(
"--use-as-default",
type=str,
nargs="+",
help="Names of args to use as a default fallback for API load requests ",
)
model_group.add_argument(
"--max-seq-len", type=int, help="Override the maximum model sequence length"
)
model_group.add_argument(
"--override-base-seq-len",
type=str_to_bool,
help="Overrides base model context length",
)
model_group.add_argument(
"--tensor-parallel",
type=str_to_bool,
help="Use tensor parallelism to load models",
)
model_group.add_argument(
"--gpu-split-auto",
type=str_to_bool,
help="Automatically allocate resources to GPUs",
)
model_group.add_argument(
"--autosplit-reserve",
type=int,
nargs="+",
help="Reserve VRAM used for autosplit loading (in MBs) ",
)
model_group.add_argument(
"--gpu-split",
type=float,
nargs="+",
help="An integer array of GBs of vram to split between GPUs. "
+ "Ignored if gpu_split_auto is true",
)
model_group.add_argument(
"--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
)
model_group.add_argument(
"--rope-alpha",
type=argument_with_auto,
help="Sets rope_alpha for NTK",
)
model_group.add_argument(
"--cache-mode",
type=str,
help="Set the quantization level of the K/V cache. Options: (FP16, Q8, Q6, Q4)",
)
model_group.add_argument(
"--cache-size",
type=int,
help="The size of the prompt cache (in number of tokens) to allocate",
)
model_group.add_argument(
"--chunk-size",
type=int,
help="Chunk size for prompt ingestion",
)
model_group.add_argument(
"--max-batch-size",
type=int,
help="Maximum amount of prompts to process at one time",
)
model_group.add_argument(
"--prompt-template",
type=str,
help="Set the jinja2 prompt template for chat completions",
)
model_group.add_argument(
"--num-experts-per-token",
type=int,
help="Number of experts to use per token in MoE models",
)
model_group.add_argument(
"--fasttensors",
type=str_to_bool,
help="Possibly increases model loading speeds",
)
def add_logging_args(parser: argparse.ArgumentParser):
"""Adds logging arguments"""
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
"--log-prompt", type=str_to_bool, help="Enable prompt logging"
)
logging_group.add_argument(
"--log-generation-params",
type=str_to_bool,
help="Enable generation parameter logging",
)
logging_group.add_argument(
"--log-requests",
type=str_to_bool,
help="Enable request logging",
)
def add_developer_args(parser: argparse.ArgumentParser):
"""Adds developer-specific arguments"""
developer_group = parser.add_argument_group("developer")
developer_group.add_argument(
"--unsafe-launch", type=str_to_bool, help="Skip Exllamav2 version check"
)
developer_group.add_argument(
"--disable-request-streaming",
type=str_to_bool,
help="Disables API request streaming",
)
developer_group.add_argument(
"--cuda-malloc-backend",
type=str_to_bool,
help="Runs with the pytorch CUDA malloc backend",
)
developer_group.add_argument(
"--uvloop",
type=str_to_bool,
help="Run asyncio using Uvloop or Winloop",
)
def add_sampling_args(parser: argparse.ArgumentParser):
"""Adds sampling-specific arguments"""
sampling_group = parser.add_argument_group("sampling")
sampling_group.add_argument(
"--override-preset", type=str, help="Select a sampler override preset"
)
def add_embeddings_args(parser: argparse.ArgumentParser):
"""Adds arguments specific to embeddings"""
embeddings_group = parser.add_argument_group("embeddings")
embeddings_group.add_argument(
"--embedding-model-dir",
type=str,
help="Overrides the directory to look for models",
)
embeddings_group.add_argument(
"--embedding-model-name", type=str, help="An initial model to load"
)
embeddings_group.add_argument(
"--embeddings-device",
type=str,
help="Device to use for embeddings. Options: (cpu, auto, cuda)",
)
return arg_groups