automate arg parse
- generate arg parser dynamically - remove legacy parser code
This commit is contained in:
parent
362b8d5818
commit
36e991c16e
1 changed file with 34 additions and 214 deletions
248
common/args.py
248
common/args.py
|
|
@ -1,7 +1,9 @@
|
|||
"""Argparser for overriding config values"""
|
||||
|
||||
import argparse
|
||||
|
||||
from typing import get_origin, get_args, Optional, Union, List
|
||||
from pydantic import BaseModel
|
||||
from common.tabby_config import config
|
||||
|
||||
def str_to_bool(value):
|
||||
"""Converts a string into a boolean value"""
|
||||
|
|
@ -32,24 +34,40 @@ def argument_with_auto(value):
|
|||
|
||||
|
||||
def init_argparser():
|
||||
"""Creates an argument parser that any function can use"""
|
||||
parser = argparse.ArgumentParser(description="TabbyAPI server")
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
epilog="NOTE: These args serve to override parts of the config. "
|
||||
+ "It's highly recommended to edit config.yml for all options and "
|
||||
+ "better descriptions!"
|
||||
)
|
||||
add_network_args(parser)
|
||||
add_model_args(parser)
|
||||
add_embeddings_args(parser)
|
||||
add_logging_args(parser)
|
||||
add_developer_args(parser)
|
||||
add_sampling_args(parser)
|
||||
add_config_args(parser)
|
||||
# Loop through the fields in the top-level model (ModelX in this case)
|
||||
for field_name, field_type in config.__annotations__.items():
|
||||
# Get the sub-model type (e.g., ModelA, ModelB)
|
||||
sub_model = field_type.__base__
|
||||
|
||||
# Create argument group for the sub-model
|
||||
group = parser.add_argument_group(field_name, description=f"Arguments for {field_name}")
|
||||
|
||||
# Loop through each field in the sub-model (e.g., ModelA, ModelB)
|
||||
for sub_field_name, sub_field_type in field_type.__annotations__.items():
|
||||
field = field_type.__fields__[sub_field_name]
|
||||
help_text = field.description if field.description else "No description available"
|
||||
|
||||
# Handle Optional types or other generic types
|
||||
origin = get_origin(sub_field_type)
|
||||
if origin is Union: # Check if the type is Union (which includes Optional)
|
||||
sub_field_type = next(t for t in get_args(sub_field_type) if t is not type(None))
|
||||
elif origin is List : sub_field_type = get_args(sub_field_type)[0]
|
||||
|
||||
|
||||
# Map Pydantic types to argparse types
|
||||
print(sub_field_type, type(sub_field_type))
|
||||
if isinstance(sub_field_type, type) and issubclass(sub_field_type, (int, float, str, bool)):
|
||||
arg_type = sub_field_type
|
||||
else:
|
||||
arg_type = str # Default to string for unknown types
|
||||
|
||||
# Add the argument for each field in the sub-model
|
||||
group.add_argument(f"--{sub_field_name}", type=arg_type, help=help_text)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentParser):
|
||||
"""Broad conversion of surface level arg groups to dictionaries"""
|
||||
|
||||
|
|
@ -63,202 +81,4 @@ def convert_args_to_dict(args: argparse.Namespace, parser: argparse.ArgumentPars
|
|||
|
||||
arg_groups[group.title] = group_dict
|
||||
|
||||
return arg_groups
|
||||
|
||||
|
||||
def add_config_args(parser: argparse.ArgumentParser):
|
||||
"""Adds config arguments"""
|
||||
|
||||
parser.add_argument(
|
||||
"--config", type=str, help="Path to an overriding config.yml file"
|
||||
)
|
||||
|
||||
|
||||
def add_network_args(parser: argparse.ArgumentParser):
|
||||
"""Adds networking arguments"""
|
||||
|
||||
network_group = parser.add_argument_group("network")
|
||||
network_group.add_argument("--host", type=str, help="The IP to host on")
|
||||
network_group.add_argument("--port", type=int, help="The port to host on")
|
||||
network_group.add_argument(
|
||||
"--disable-auth",
|
||||
type=str_to_bool,
|
||||
help="Disable HTTP token authenticaion with requests",
|
||||
)
|
||||
network_group.add_argument(
|
||||
"--send-tracebacks",
|
||||
type=str_to_bool,
|
||||
help="Decide whether to send error tracebacks over the API",
|
||||
)
|
||||
network_group.add_argument(
|
||||
"--api-servers",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="API servers to enable. Options: (OAI, Kobold)",
|
||||
)
|
||||
|
||||
|
||||
def add_model_args(parser: argparse.ArgumentParser):
|
||||
"""Adds model arguments"""
|
||||
|
||||
model_group = parser.add_argument_group("model")
|
||||
model_group.add_argument(
|
||||
"--model-dir", type=str, help="Overrides the directory to look for models"
|
||||
)
|
||||
model_group.add_argument("--model-name", type=str, help="An initial model to load")
|
||||
model_group.add_argument(
|
||||
"--use-dummy-models",
|
||||
type=str_to_bool,
|
||||
help="Add dummy OAI model names for API queries",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--use-as-default",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="Names of args to use as a default fallback for API load requests ",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--max-seq-len", type=int, help="Override the maximum model sequence length"
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--override-base-seq-len",
|
||||
type=str_to_bool,
|
||||
help="Overrides base model context length",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--tensor-parallel",
|
||||
type=str_to_bool,
|
||||
help="Use tensor parallelism to load models",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--gpu-split-auto",
|
||||
type=str_to_bool,
|
||||
help="Automatically allocate resources to GPUs",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--autosplit-reserve",
|
||||
type=int,
|
||||
nargs="+",
|
||||
help="Reserve VRAM used for autosplit loading (in MBs) ",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--gpu-split",
|
||||
type=float,
|
||||
nargs="+",
|
||||
help="An integer array of GBs of vram to split between GPUs. "
|
||||
+ "Ignored if gpu_split_auto is true",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--rope-scale", type=float, help="Sets rope_scale or compress_pos_emb"
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--rope-alpha",
|
||||
type=argument_with_auto,
|
||||
help="Sets rope_alpha for NTK",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--cache-mode",
|
||||
type=str,
|
||||
help="Set the quantization level of the K/V cache. Options: (FP16, Q8, Q6, Q4)",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--cache-size",
|
||||
type=int,
|
||||
help="The size of the prompt cache (in number of tokens) to allocate",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--chunk-size",
|
||||
type=int,
|
||||
help="Chunk size for prompt ingestion",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--max-batch-size",
|
||||
type=int,
|
||||
help="Maximum amount of prompts to process at one time",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--prompt-template",
|
||||
type=str,
|
||||
help="Set the jinja2 prompt template for chat completions",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--num-experts-per-token",
|
||||
type=int,
|
||||
help="Number of experts to use per token in MoE models",
|
||||
)
|
||||
model_group.add_argument(
|
||||
"--fasttensors",
|
||||
type=str_to_bool,
|
||||
help="Possibly increases model loading speeds",
|
||||
)
|
||||
|
||||
|
||||
def add_logging_args(parser: argparse.ArgumentParser):
|
||||
"""Adds logging arguments"""
|
||||
|
||||
logging_group = parser.add_argument_group("logging")
|
||||
logging_group.add_argument(
|
||||
"--log-prompt", type=str_to_bool, help="Enable prompt logging"
|
||||
)
|
||||
logging_group.add_argument(
|
||||
"--log-generation-params",
|
||||
type=str_to_bool,
|
||||
help="Enable generation parameter logging",
|
||||
)
|
||||
logging_group.add_argument(
|
||||
"--log-requests",
|
||||
type=str_to_bool,
|
||||
help="Enable request logging",
|
||||
)
|
||||
|
||||
|
||||
def add_developer_args(parser: argparse.ArgumentParser):
|
||||
"""Adds developer-specific arguments"""
|
||||
|
||||
developer_group = parser.add_argument_group("developer")
|
||||
developer_group.add_argument(
|
||||
"--unsafe-launch", type=str_to_bool, help="Skip Exllamav2 version check"
|
||||
)
|
||||
developer_group.add_argument(
|
||||
"--disable-request-streaming",
|
||||
type=str_to_bool,
|
||||
help="Disables API request streaming",
|
||||
)
|
||||
developer_group.add_argument(
|
||||
"--cuda-malloc-backend",
|
||||
type=str_to_bool,
|
||||
help="Runs with the pytorch CUDA malloc backend",
|
||||
)
|
||||
developer_group.add_argument(
|
||||
"--uvloop",
|
||||
type=str_to_bool,
|
||||
help="Run asyncio using Uvloop or Winloop",
|
||||
)
|
||||
|
||||
|
||||
def add_sampling_args(parser: argparse.ArgumentParser):
|
||||
"""Adds sampling-specific arguments"""
|
||||
|
||||
sampling_group = parser.add_argument_group("sampling")
|
||||
sampling_group.add_argument(
|
||||
"--override-preset", type=str, help="Select a sampler override preset"
|
||||
)
|
||||
|
||||
|
||||
def add_embeddings_args(parser: argparse.ArgumentParser):
|
||||
"""Adds arguments specific to embeddings"""
|
||||
|
||||
embeddings_group = parser.add_argument_group("embeddings")
|
||||
embeddings_group.add_argument(
|
||||
"--embedding-model-dir",
|
||||
type=str,
|
||||
help="Overrides the directory to look for models",
|
||||
)
|
||||
embeddings_group.add_argument(
|
||||
"--embedding-model-name", type=str, help="An initial model to load"
|
||||
)
|
||||
embeddings_group.add_argument(
|
||||
"--embeddings-device",
|
||||
type=str,
|
||||
help="Device to use for embeddings. Options: (cpu, auto, cuda)",
|
||||
)
|
||||
return arg_groups
|
||||
Loading…
Add table
Add a link
Reference in a new issue