diff --git a/common/model.py b/common/model.py index 14d3cc0..f37c765 100644 --- a/common/model.py +++ b/common/model.py @@ -4,17 +4,20 @@ Manages the storage and utility of model containers. Containers exist as a common interface for backends. """ +import os import pathlib from loguru import logger from typing import Optional -from backends.exllamav2.model import ExllamaV2Container from common import config from common.logger import get_loading_progress_bar from common.utils import unwrap -# Global model container -container: Optional[ExllamaV2Container] = None +if not os.getenv("EXPORT_OPENAPI"): + from backends.exllamav2.model import ExllamaV2Container + + # Global model container + container: Optional[ExllamaV2Container] = None def load_progress(module, modules): diff --git a/endpoints/server.py b/endpoints/server.py index 2cdaa72..7ceb208 100644 --- a/endpoints/server.py +++ b/endpoints/server.py @@ -26,9 +26,18 @@ app.add_middleware( def setup_app(): + """Includes the correct routers for startup""" + app.include_router(OAIRouter) +def export_openapi(): + """Function to return the OpenAPI JSON from the API server""" + + setup_app() + return app.openapi() + + async def start_api(host: str, port: int): """Isolated function to start the API server""" diff --git a/generate_openapi.py b/generate_openapi.py deleted file mode 100644 index 0ef4f80..0000000 --- a/generate_openapi.py +++ /dev/null @@ -1,13 +0,0 @@ -import json -from endpoints import server - - -if __name__ == "__main__": - """Uses the FastAPI server to write an OpenAPI JSON documentation file.""" - - server.setup_app() - openapi_json = json.dumps(server.app.openapi()) - - # Write JSON to a file - with open("openapi.json", "w") as f: - f.write(openapi_json) diff --git a/main.py b/main.py index 54b44f9..fe17351 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,8 @@ """The main tabbyAPI module. Contains the FastAPI server and endpoints.""" import asyncio +import aiofiles +import json import os import pathlib import signal @@ -15,7 +17,7 @@ from common.logger import setup_logger from common.networking import is_port_in_use from common.signals import signal_handler from common.utils import unwrap -from endpoints.server import start_api +from endpoints.server import export_openapi, start_api async def entrypoint(args: Optional[dict] = None): @@ -27,6 +29,15 @@ async def entrypoint(args: Optional[dict] = None): signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) + if os.getenv("EXPORT_OPENAPI"): + openapi_json = export_openapi() + + async with aiofiles.open("openapi.json", "w") as f: + await f.write(json.dumps(openapi_json)) + logger.info("Successfully wrote OpenAPI spec to openapi.json") + + return + # Load from YAML config config.from_file(pathlib.Path("config.yml"))