Args: Add subcommands to run actions

Migrate OpenAPI and sample config export to subcommands "export-openapi"
and "export-config".

Also add a "download" subcommand that passes args to the TabbyAPI
downloader. This allows models to be downloaded via the API and
CLI args.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-02-10 23:14:22 -05:00
parent 30f02e5453
commit 30ab8e04b9
6 changed files with 112 additions and 31 deletions

View file

@ -1,27 +1,51 @@
import argparse
import asyncio
import json
from loguru import logger
from common.tabby_config import config, generate_config_file
from common.downloader import hf_repo_download
from common.tabby_config import generate_config_file
from common.utils import unwrap
from endpoints.server import export_openapi
def branch_to_actions() -> bool:
"""Checks if a optional action needs to be run."""
def download_action(args: argparse.Namespace):
asyncio.run(
hf_repo_download(
repo_id=args.repo_id,
folder_name=args.folder_name,
revision=args.revision,
token=args.token,
include=args.include,
exclude=args.exclude,
)
)
if config.actions.export_openapi:
openapi_json = export_openapi()
with open(config.actions.openapi_export_path, "w") as f:
f.write(json.dumps(openapi_json))
logger.info(
"Successfully wrote OpenAPI spec to "
+ f"{config.actions.openapi_export_path}"
)
elif config.actions.export_config:
generate_config_file(filename=config.actions.config_export_path)
else:
# did not branch
return False
def config_export_action(args: argparse.Namespace):
export_path = unwrap(args.export_path, "config_sample.yml")
generate_config_file(filename=export_path)
# branched and ran an action
return True
def openapi_export_action(args: argparse.Namespace):
export_path = unwrap(args.export_path, "openapi.json")
openapi_json = export_openapi()
with open(export_path, "w") as f:
f.write(json.dumps(openapi_json))
logger.info("Successfully wrote OpenAPI spec to " + f"{export_path}")
def run_subcommand(args: argparse.Namespace) -> bool:
match args.actions:
case "download":
download_action(args)
return True
case "export-config":
config_export_action(args)
return True
case "export-openapi":
openapi_export_action(args)
return True
case _:
return False

View file

@ -37,6 +37,8 @@ def init_argparser(
existing_parser, argparse.ArgumentParser(description="TabbyAPI server")
)
add_subcommands(parser)
# Loop through each top-level field in the config
for field_name, field_info in TabbyConfigModel.model_fields.items():
field_type = unwrap_optional_type(field_info.annotation)
@ -59,6 +61,57 @@ def init_argparser(
return parser
def add_subcommands(parser: argparse.ArgumentParser):
"""Adds subcommands to an existing argparser"""
actions_subparsers = parser.add_subparsers(
dest="actions", help="Extra actions that can be run instead of the main server."
)
# Calls download action
download_parser = actions_subparsers.add_parser(
"download", help="Calls the model downloader"
)
download_parser.add_argument("repo_id", type=str, help="HuggingFace repo ID")
download_parser.add_argument(
"--folder-name",
type=str,
help="Folder name where the model should be downloaded",
)
download_parser.add_argument(
"--revision",
type=str,
help="Branch name in HuggingFace repo",
)
download_parser.add_argument(
"--token", type=str, help="HuggingFace access token for private repos"
)
download_parser.add_argument(
"--include", type=str, help="Glob pattern of files to include"
)
download_parser.add_argument(
"--exclude", type=str, help="Glob pattern of files to exclude"
)
# Calls openapi action
openapi_export_parser = actions_subparsers.add_parser(
"export-openapi", help="Exports an OpenAPI compliant JSON schema"
)
openapi_export_parser.add_argument(
"--export-path",
help="Path to export the generated OpenAPI JSON (default: openapi.json)",
)
# Calls config export action
config_export_parser = actions_subparsers.add_parser(
"export-config", help="Generates and exports a sample config YAML file"
)
config_export_parser.add_argument(
"--export-path",
help="Path to export the generated sample config (default: config_sample.yml)",
)
def convert_args_to_dict(
args: argparse.Namespace, parser: argparse.ArgumentParser
) -> dict:

View file

@ -98,10 +98,10 @@ async def hf_repo_download(
folder_name: Optional[str],
revision: Optional[str],
token: Optional[str],
chunk_limit: Optional[float],
include: Optional[List[str]],
exclude: Optional[List[str]],
timeout: Optional[int],
chunk_limit: Optional[float] = None,
timeout: Optional[int] = None,
repo_type: Optional[str] = "model",
):
"""Gets a repo's information from HuggingFace and downloads it locally."""