Args: Add subcommands to run actions

Migrate OpenAPI and sample config export to subcommands "export-openapi"
and "export-config".

Also add a "download" subcommand that passes args to the TabbyAPI
downloader. This allows models to be downloaded via the API and
CLI args.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-02-10 23:14:22 -05:00
parent 30f02e5453
commit 30ab8e04b9
6 changed files with 112 additions and 31 deletions

View file

@ -48,8 +48,8 @@ jobs:
npm install @redocly/cli -g
- name: Export OpenAPI docs
run: |
python main.py --export-openapi true --openapi-export-path "openapi-kobold.json" --api-servers kobold
python main.py --export-openapi true --openapi-export-path "openapi-oai.json" --api-servers OAI
python main.py export-openapi --export-path "openapi-kobold.json"
python main.py export-openapi --export-path "openapi-oai.json"
- name: Build and store Redocly site
run: |
mkdir static

View file

@ -1,27 +1,51 @@
import argparse
import asyncio
import json
from loguru import logger
from common.tabby_config import config, generate_config_file
from common.downloader import hf_repo_download
from common.tabby_config import generate_config_file
from common.utils import unwrap
from endpoints.server import export_openapi
def branch_to_actions() -> bool:
"""Checks if a optional action needs to be run."""
def download_action(args: argparse.Namespace):
asyncio.run(
hf_repo_download(
repo_id=args.repo_id,
folder_name=args.folder_name,
revision=args.revision,
token=args.token,
include=args.include,
exclude=args.exclude,
)
)
if config.actions.export_openapi:
openapi_json = export_openapi()
with open(config.actions.openapi_export_path, "w") as f:
f.write(json.dumps(openapi_json))
logger.info(
"Successfully wrote OpenAPI spec to "
+ f"{config.actions.openapi_export_path}"
)
elif config.actions.export_config:
generate_config_file(filename=config.actions.config_export_path)
else:
# did not branch
return False
def config_export_action(args: argparse.Namespace):
export_path = unwrap(args.export_path, "config_sample.yml")
generate_config_file(filename=export_path)
# branched and ran an action
return True
def openapi_export_action(args: argparse.Namespace):
export_path = unwrap(args.export_path, "openapi.json")
openapi_json = export_openapi()
with open(export_path, "w") as f:
f.write(json.dumps(openapi_json))
logger.info("Successfully wrote OpenAPI spec to " + f"{export_path}")
def run_subcommand(args: argparse.Namespace) -> bool:
match args.actions:
case "download":
download_action(args)
return True
case "export-config":
config_export_action(args)
return True
case "export-openapi":
openapi_export_action(args)
return True
case _:
return False

View file

@ -37,6 +37,8 @@ def init_argparser(
existing_parser, argparse.ArgumentParser(description="TabbyAPI server")
)
add_subcommands(parser)
# Loop through each top-level field in the config
for field_name, field_info in TabbyConfigModel.model_fields.items():
field_type = unwrap_optional_type(field_info.annotation)
@ -59,6 +61,57 @@ def init_argparser(
return parser
def add_subcommands(parser: argparse.ArgumentParser):
"""Adds subcommands to an existing argparser"""
actions_subparsers = parser.add_subparsers(
dest="actions", help="Extra actions that can be run instead of the main server."
)
# Calls download action
download_parser = actions_subparsers.add_parser(
"download", help="Calls the model downloader"
)
download_parser.add_argument("repo_id", type=str, help="HuggingFace repo ID")
download_parser.add_argument(
"--folder-name",
type=str,
help="Folder name where the model should be downloaded",
)
download_parser.add_argument(
"--revision",
type=str,
help="Branch name in HuggingFace repo",
)
download_parser.add_argument(
"--token", type=str, help="HuggingFace access token for private repos"
)
download_parser.add_argument(
"--include", type=str, help="Glob pattern of files to include"
)
download_parser.add_argument(
"--exclude", type=str, help="Glob pattern of files to exclude"
)
# Calls openapi action
openapi_export_parser = actions_subparsers.add_parser(
"export-openapi", help="Exports an OpenAPI compliant JSON schema"
)
openapi_export_parser.add_argument(
"--export-path",
help="Path to export the generated OpenAPI JSON (default: openapi.json)",
)
# Calls config export action
config_export_parser = actions_subparsers.add_parser(
"export-config", help="Generates and exports a sample config YAML file"
)
config_export_parser.add_argument(
"--export-path",
help="Path to export the generated sample config (default: config_sample.yml)",
)
def convert_args_to_dict(
args: argparse.Namespace, parser: argparse.ArgumentParser
) -> dict:

View file

@ -98,10 +98,10 @@ async def hf_repo_download(
folder_name: Optional[str],
revision: Optional[str],
token: Optional[str],
chunk_limit: Optional[float],
include: Optional[List[str]],
exclude: Optional[List[str]],
timeout: Optional[int],
chunk_limit: Optional[float] = None,
timeout: Optional[int] = None,
repo_type: Optional[str] = "model",
):
"""Gets a repo's information from HuggingFace and downloads it locally."""

18
main.py
View file

@ -1,5 +1,6 @@
"""The main tabbyAPI module. Contains the FastAPI server and endpoints."""
import argparse
import asyncio
import os
import pathlib
@ -11,7 +12,7 @@ from typing import Optional
from common import gen_logging, sampling, model
from common.args import convert_args_to_dict, init_argparser
from common.auth import load_auth_keys
from common.actions import branch_to_actions
from common.actions import run_subcommand
from common.logger import setup_logger
from common.networking import is_port_in_use
from common.signals import signal_handler
@ -99,7 +100,10 @@ async def entrypoint_async():
await start_api(host, port)
def entrypoint(arguments: Optional[dict] = None):
def entrypoint(
args: Optional[argparse.Namespace] = None,
parser: Optional[argparse.ArgumentParser] = None,
):
setup_logger()
# Set up signal aborting
@ -115,15 +119,17 @@ def entrypoint(arguments: Optional[dict] = None):
install()
# Parse and override config from args
if arguments is None:
if args is None:
parser = init_argparser()
arguments = convert_args_to_dict(parser.parse_args(), parser)
args = parser.parse_args()
dict_args = convert_args_to_dict(args, parser)
# load config
config.load(arguments)
config.load(dict_args)
# branch to default paths if required
if branch_to_actions():
if run_subcommand(args):
return
# Check exllamav2 version and give a descriptive error if it's too old

View file

@ -275,8 +275,6 @@ if __name__ == "__main__":
from common.args import convert_args_to_dict
from main import entrypoint
converted_args = convert_args_to_dict(args, parser)
# Create a config if it doesn't exist
# This is not necessary to run TabbyAPI, but is new user proof
config_path = (
@ -292,7 +290,7 @@ if __name__ == "__main__":
)
print("Starting TabbyAPI...")
entrypoint(converted_args)
entrypoint(args, parser)
except (ModuleNotFoundError, ImportError):
traceback.print_exc()
print(