diff --git a/.github/workflows/pages.yml b/.github/workflows/pages.yml index a7b3327..fe54250 100644 --- a/.github/workflows/pages.yml +++ b/.github/workflows/pages.yml @@ -48,10 +48,8 @@ jobs: npm install @redocly/cli -g - name: Export OpenAPI docs run: | - EXPORT_OPENAPI=1 python main.py - mv openapi.json openapi-oai.json - EXPORT_OPENAPI=1 python main.py --api-servers kobold - mv openapi.json openapi-kobold.json + python main.py --export-openapi true --openapi-export-path "openapi-oai.json" --api-servers OAI + python main.py --export-openapi true --openapi-export-path "openapi-kobold.json" --api-servers kobold - name: Build and store Redocly site run: | mkdir static diff --git a/common/actions.py b/common/actions.py new file mode 100644 index 0000000..079a78d --- /dev/null +++ b/common/actions.py @@ -0,0 +1,27 @@ +import json +from loguru import logger +from common.tabby_config import config +from endpoints.server import export_openapi +from common.config_models import generate_config_file + + +def branch_to_actions() -> bool: + if config.actions.export_openapi: + openapi_json = export_openapi() + + with open(config.actions.openapi_export_path, "w") as f: + f.write(json.dumps(openapi_json)) + logger.info( + "Successfully wrote OpenAPI spec to " + + f"{config.actions.openapi_export_path}" + ) + + elif config.actions.export_config: + generate_config_file(config.actions.config_export_path) + + else: + # did not branch + return False + + # branched and ran an action + return True diff --git a/common/args.py b/common/args.py index 22c7681..bd9c67c 100644 --- a/common/args.py +++ b/common/args.py @@ -8,14 +8,28 @@ from pydantic import BaseModel from common.config_models import TabbyConfigModel +def is_list_type(type_hint): + if hasattr(type_hint, "__origin__") and type_hint.__origin__ is list: + return True + if hasattr(type_hint, "__args__"): + # Recursively check for lists inside type arguments + return any(is_list_type(arg) for arg in type_hint.__args__) + return False + + def add_field_to_group(group, field_name, field_type, field) -> None: """ Adds a Pydantic field to an argparse argument group. """ - help_text = field.description if field.description else "No description available" + kwargs = { + "help": field.description if field.description else "No description available", + } - group.add_argument(f"--{field_name}", help=help_text) + if is_list_type(field_type): + kwargs["nargs"] = "+" + + group.add_argument(f"--{field_name}", **kwargs) def init_argparser() -> argparse.ArgumentParser: diff --git a/common/config_models.py b/common/config_models.py index b400d5c..1b371d5 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -1,5 +1,6 @@ from pydantic import BaseModel, ConfigDict, Field from typing import List, Literal, Optional, Union +from pathlib import Path CACHE_SIZES = Literal["FP16", "Q8", "Q6", "Q4"] @@ -11,6 +12,22 @@ class ConfigOverrideConfig(BaseModel): ) +class UtilityActions(BaseModel): + export_config: Optional[str] = Field( + None, description="generate a template config file" + ) + config_export_path: Optional[Path] = Field( + "config_sample.yml", description="path to export configuration file to" + ) + + export_openapi: Optional[bool] = Field( + False, description="export openapi schema files" + ) + openapi_export_path: Optional[Path] = Field( + "openapi.json", description="path to export openapi schema to" + ) + + class NetworkConfig(BaseModel): host: Optional[str] = Field("127.0.0.1", description=("The IP to host on")) port: Optional[int] = Field(5000, description=("The port to host on")) @@ -308,6 +325,7 @@ class TabbyConfigModel(BaseModel): embeddings: EmbeddingsConfig = Field( default_factory=EmbeddingsConfig.model_construct ) + actions: UtilityActions = Field(default_factory=UtilityActions.model_construct) model_config = ConfigDict(validate_assignment=True, protected_namespaces=()) diff --git a/main.py b/main.py index bd70686..429c0e8 100644 --- a/main.py +++ b/main.py @@ -1,7 +1,6 @@ """The main tabbyAPI module. Contains the FastAPI server and endpoints.""" import asyncio -import json import os import pathlib import platform @@ -12,11 +11,12 @@ from typing import Optional from common import gen_logging, sampling, model from common.args import convert_args_to_dict, init_argparser from common.auth import load_auth_keys +from common.actions import branch_to_actions from common.logger import setup_logger from common.networking import is_port_in_use from common.signals import signal_handler from common.tabby_config import config -from endpoints.server import export_openapi, start_api +from endpoints.server import start_api from endpoints.utils import do_export_openapi if not do_export_openapi: @@ -112,13 +112,8 @@ def entrypoint(arguments: Optional[dict] = None): # load config config.load(arguments) - if do_export_openapi: - openapi_json = export_openapi() - - with open("openapi.json", "w") as f: - f.write(json.dumps(openapi_json)) - logger.info("Successfully wrote OpenAPI spec to openapi.json") - + # branch to default paths if required + if branch_to_actions(): return # Check exllamav2 version and give a descriptive error if it's too old