From 933268f7e276b31fffac4879ba63a0ede0e508c9 Mon Sep 17 00:00:00 2001 From: kingbri Date: Mon, 8 Jul 2024 12:34:32 -0400 Subject: [PATCH] API: Integrate OpenAPI export script Move OpenAPI export as an env var within the main function. This allows for easy export by running main. In addition, an env variable provides global and explicit state to disable conditional wheel imports (ex. Exl2 and torch) which caused errors at first. Signed-off-by: kingbri --- common/model.py | 9 ++++++--- endpoints/server.py | 9 +++++++++ generate_openapi.py | 13 ------------- main.py | 13 ++++++++++++- 4 files changed, 27 insertions(+), 17 deletions(-) delete mode 100644 generate_openapi.py diff --git a/common/model.py b/common/model.py index 14d3cc0..f37c765 100644 --- a/common/model.py +++ b/common/model.py @@ -4,17 +4,20 @@ Manages the storage and utility of model containers. Containers exist as a common interface for backends. """ +import os import pathlib from loguru import logger from typing import Optional -from backends.exllamav2.model import ExllamaV2Container from common import config from common.logger import get_loading_progress_bar from common.utils import unwrap -# Global model container -container: Optional[ExllamaV2Container] = None +if not os.getenv("EXPORT_OPENAPI"): + from backends.exllamav2.model import ExllamaV2Container + + # Global model container + container: Optional[ExllamaV2Container] = None def load_progress(module, modules): diff --git a/endpoints/server.py b/endpoints/server.py index 2cdaa72..7ceb208 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -26,9 +26,18 @@ app.add_middleware( def setup_app(): + """Includes the correct routers for startup""" + app.include_router(OAIRouter) +def export_openapi(): + """Function to return the OpenAPI JSON from the API server""" + + setup_app() + return app.openapi() + + async def start_api(host: str, port: int): """Isolated function to start the API server""" diff --git a/generate_openapi.py b/generate_openapi.py deleted file mode 100644 index 0ef4f80..0000000 --- a/generate_openapi.py +++ /dev/null @@ -1,13 +0,0 @@ -import json -from endpoints import server - - -if __name__ == "__main__": - """Uses the FastAPI server to write an OpenAPI JSON documentation file.""" - - server.setup_app() - openapi_json = json.dumps(server.app.openapi()) - - # Write JSON to a file - with open("openapi.json", "w") as f: - f.write(openapi_json) diff --git a/main.py b/main.py index 54b44f9..fe17351 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,8 @@ """The main tabbyAPI module. Contains the FastAPI server and endpoints.""" import asyncio +import aiofiles +import json import os import pathlib import signal @@ -15,7 +17,7 @@ from common.logger import setup_logger from common.networking import is_port_in_use from common.signals import signal_handler from common.utils import unwrap -from endpoints.server import start_api +from endpoints.server import export_openapi, start_api async def entrypoint(args: Optional[dict] = None): @@ -27,6 +29,15 @@ async def entrypoint(args: Optional[dict] = None): signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) + if os.getenv("EXPORT_OPENAPI"): + openapi_json = export_openapi() + + async with aiofiles.open("openapi.json", "w") as f: + await f.write(json.dumps(openapi_json)) + logger.info("Successfully wrote OpenAPI spec to openapi.json") + + return + # Load from YAML config config.from_file(pathlib.Path("config.yml"))