diff --git a/common/args.py b/common/args.py index 0548eaf..a0f19c2 100644 --- a/common/args.py +++ b/common/args.py @@ -72,6 +72,12 @@ def add_network_args(parser: argparse.ArgumentParser): type=str_to_bool, help="Decide whether to send error tracebacks over the API", ) + network_group.add_argument( + "--api-servers", + type=str, + nargs="+", + help="API servers to enable. Options: (OAI, Kobold)", + ) def add_model_args(parser: argparse.ArgumentParser): diff --git a/main.py b/main.py index 5ed20f3..b0c5108 100644 --- a/main.py +++ b/main.py @@ -110,15 +110,6 @@ def entrypoint(arguments: Optional[dict] = None): signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) - if do_export_openapi: - openapi_json = export_openapi() - - with open("openapi.json", "w") as f: - f.write(json.dumps(openapi_json)) - logger.info("Successfully wrote OpenAPI spec to openapi.json") - - return - # Load from YAML config config.from_file(pathlib.Path("config.yml")) @@ -128,6 +119,16 @@ def entrypoint(arguments: Optional[dict] = None): arguments = convert_args_to_dict(parser.parse_args(), parser) config.from_args(arguments) + + if do_export_openapi: + openapi_json = export_openapi() + + with open("openapi.json", "w") as f: + f.write(json.dumps(openapi_json)) + logger.info("Successfully wrote OpenAPI spec to openapi.json") + + return + developer_config = config.developer_config() # Check exllamav2 version and give a descriptive error if it's too old