Args: Add subcommands to run actions
Migrate OpenAPI and sample config export to subcommands "export-openapi" and "export-config". Also add a "download" subcommand that passes args to the TabbyAPI downloader. This allows models to be downloaded via the API and CLI args. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
30f02e5453
commit
30ab8e04b9
6 changed files with 112 additions and 31 deletions
|
|
@ -1,27 +1,51 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
from loguru import logger
|
||||
|
||||
from common.tabby_config import config, generate_config_file
|
||||
from common.downloader import hf_repo_download
|
||||
from common.tabby_config import generate_config_file
|
||||
from common.utils import unwrap
|
||||
from endpoints.server import export_openapi
|
||||
|
||||
|
||||
def branch_to_actions() -> bool:
|
||||
"""Checks if a optional action needs to be run."""
|
||||
def download_action(args: argparse.Namespace):
|
||||
asyncio.run(
|
||||
hf_repo_download(
|
||||
repo_id=args.repo_id,
|
||||
folder_name=args.folder_name,
|
||||
revision=args.revision,
|
||||
token=args.token,
|
||||
include=args.include,
|
||||
exclude=args.exclude,
|
||||
)
|
||||
)
|
||||
|
||||
if config.actions.export_openapi:
|
||||
openapi_json = export_openapi()
|
||||
|
||||
with open(config.actions.openapi_export_path, "w") as f:
|
||||
f.write(json.dumps(openapi_json))
|
||||
logger.info(
|
||||
"Successfully wrote OpenAPI spec to "
|
||||
+ f"{config.actions.openapi_export_path}"
|
||||
)
|
||||
elif config.actions.export_config:
|
||||
generate_config_file(filename=config.actions.config_export_path)
|
||||
else:
|
||||
# did not branch
|
||||
return False
|
||||
def config_export_action(args: argparse.Namespace):
|
||||
export_path = unwrap(args.export_path, "config_sample.yml")
|
||||
generate_config_file(filename=export_path)
|
||||
|
||||
# branched and ran an action
|
||||
return True
|
||||
|
||||
def openapi_export_action(args: argparse.Namespace):
|
||||
export_path = unwrap(args.export_path, "openapi.json")
|
||||
openapi_json = export_openapi()
|
||||
|
||||
with open(export_path, "w") as f:
|
||||
f.write(json.dumps(openapi_json))
|
||||
logger.info("Successfully wrote OpenAPI spec to " + f"{export_path}")
|
||||
|
||||
|
||||
def run_subcommand(args: argparse.Namespace) -> bool:
|
||||
match args.actions:
|
||||
case "download":
|
||||
download_action(args)
|
||||
return True
|
||||
case "export-config":
|
||||
config_export_action(args)
|
||||
return True
|
||||
case "export-openapi":
|
||||
openapi_export_action(args)
|
||||
return True
|
||||
case _:
|
||||
return False
|
||||
|
|
|
|||
|
|
@ -37,6 +37,8 @@ def init_argparser(
|
|||
existing_parser, argparse.ArgumentParser(description="TabbyAPI server")
|
||||
)
|
||||
|
||||
add_subcommands(parser)
|
||||
|
||||
# Loop through each top-level field in the config
|
||||
for field_name, field_info in TabbyConfigModel.model_fields.items():
|
||||
field_type = unwrap_optional_type(field_info.annotation)
|
||||
|
|
@ -59,6 +61,57 @@ def init_argparser(
|
|||
return parser
|
||||
|
||||
|
||||
def add_subcommands(parser: argparse.ArgumentParser):
|
||||
"""Adds subcommands to an existing argparser"""
|
||||
|
||||
actions_subparsers = parser.add_subparsers(
|
||||
dest="actions", help="Extra actions that can be run instead of the main server."
|
||||
)
|
||||
|
||||
# Calls download action
|
||||
download_parser = actions_subparsers.add_parser(
|
||||
"download", help="Calls the model downloader"
|
||||
)
|
||||
download_parser.add_argument("repo_id", type=str, help="HuggingFace repo ID")
|
||||
download_parser.add_argument(
|
||||
"--folder-name",
|
||||
type=str,
|
||||
help="Folder name where the model should be downloaded",
|
||||
)
|
||||
download_parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
help="Branch name in HuggingFace repo",
|
||||
)
|
||||
download_parser.add_argument(
|
||||
"--token", type=str, help="HuggingFace access token for private repos"
|
||||
)
|
||||
download_parser.add_argument(
|
||||
"--include", type=str, help="Glob pattern of files to include"
|
||||
)
|
||||
download_parser.add_argument(
|
||||
"--exclude", type=str, help="Glob pattern of files to exclude"
|
||||
)
|
||||
|
||||
# Calls openapi action
|
||||
openapi_export_parser = actions_subparsers.add_parser(
|
||||
"export-openapi", help="Exports an OpenAPI compliant JSON schema"
|
||||
)
|
||||
openapi_export_parser.add_argument(
|
||||
"--export-path",
|
||||
help="Path to export the generated OpenAPI JSON (default: openapi.json)",
|
||||
)
|
||||
|
||||
# Calls config export action
|
||||
config_export_parser = actions_subparsers.add_parser(
|
||||
"export-config", help="Generates and exports a sample config YAML file"
|
||||
)
|
||||
config_export_parser.add_argument(
|
||||
"--export-path",
|
||||
help="Path to export the generated sample config (default: config_sample.yml)",
|
||||
)
|
||||
|
||||
|
||||
def convert_args_to_dict(
|
||||
args: argparse.Namespace, parser: argparse.ArgumentParser
|
||||
) -> dict:
|
||||
|
|
|
|||
|
|
@ -98,10 +98,10 @@ async def hf_repo_download(
|
|||
folder_name: Optional[str],
|
||||
revision: Optional[str],
|
||||
token: Optional[str],
|
||||
chunk_limit: Optional[float],
|
||||
include: Optional[List[str]],
|
||||
exclude: Optional[List[str]],
|
||||
timeout: Optional[int],
|
||||
chunk_limit: Optional[float] = None,
|
||||
timeout: Optional[int] = None,
|
||||
repo_type: Optional[str] = "model",
|
||||
):
|
||||
"""Gets a repo's information from HuggingFace and downloads it locally."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue