diff --git a/common/signals.py b/common/signals.py
new file mode 100644
index 0000000..07d7564
--- /dev/null
+++ b/common/signals.py
@@ -0,0 +1,36 @@
+import signal
+import sys
+from loguru import logger
+from types import FrameType
+from typing import Optional
+
+
+def signal_handler(*_):
+    """Signal handler for main function. Run before uvicorn starts."""
+
+    logger.warning("Shutdown signal called. Exiting gracefully.")
+    sys.exit(0)
+
+
+def uvicorn_signal_handler(signal_event: signal.Signals):
+    """Wrap the handler currently installed for ``signal_event``.
+
+    The wrapper logs a shutdown warning and then defers to whatever
+    handler was registered at the time this function was called.
+    """
+
+    default_signal_handler = signal.getsignal(signal_event)
+
+    def wrapped_handler(signum: int, frame: Optional[FrameType] = None):
+        logger.warning("Shutdown signal called. Exiting gracefully.")
+
+        # getsignal() may return SIG_DFL, SIG_IGN or None, none of which
+        # are callable; only invoke real Python-level handlers.
+        if callable(default_signal_handler):
+            default_signal_handler(signum, frame)
+        elif default_signal_handler == signal.SIG_DFL:
+            # Restore the default disposition and re-deliver the signal
+            signal.signal(signum, signal.SIG_DFL)
+            signal.raise_signal(signum)
+
+    signal.signal(signal_event, wrapped_handler)
diff --git a/endpoints/OAI/app.py b/endpoints/OAI/app.py
index 0571874..9c648b4 100644
--- a/endpoints/OAI/app.py
+++ b/endpoints/OAI/app.py
@@ -1,5 +1,7 @@
 import pathlib
+import signal
 import uvicorn
+from contextlib import asynccontextmanager
 from fastapi import FastAPI, Depends, HTTPException, Request
 from fastapi.middleware.cors import CORSMiddleware
 from functools import partial
@@ -14,6 +16,7 @@ from common.concurrency import (
     generate_with_semaphore,
 )
 from common.logger import UVICORN_LOG_CONFIG
+from common.signals import uvicorn_signal_handler
 from common.templating import (
     get_all_templates,
     get_template_from_file,
@@ -55,6 +58,14 @@ from endpoints.OAI.utils.completion import (
 from endpoints.OAI.utils.model import get_model_list, stream_model_load
 from endpoints.OAI.utils.lora import get_lora_list
 
+
+@asynccontextmanager
+async def lifespan(_: FastAPI):
+    uvicorn_signal_handler(signal.SIGINT)
+    uvicorn_signal_handler(signal.SIGTERM)
+    yield
+
+
 app = FastAPI(
     title="TabbyAPI",
     summary="An OAI compatible exllamav2 API that's both lightweight and fast",
@@ -62,6 +73,7 @@ app = FastAPI(
         "This docs page is not meant to send requests! Please use a service "
         "like Postman or a frontend UI."
     ),
+    lifespan=lifespan,
 )
 
 # ALlow CORS requests
diff --git a/main.py b/main.py
index a3d7c30..f6fc52a 100644
--- a/main.py
+++ b/main.py
@@ -4,7 +4,6 @@ import asyncio
 import os
 import pathlib
 import signal
-import sys
 from loguru import logger
 from typing import Optional
 
@@ -13,15 +12,11 @@ from common import config, gen_logging, sampling, model
 from common.args import convert_args_to_dict, init_argparser
 from common.auth import load_auth_keys
 from common.logger import setup_logger
+from common.signals import signal_handler
 from common.utils import is_port_in_use, unwrap
 from endpoints.OAI.app import start_api
 
 
-def signal_handler(*_):
-    logger.warning("Shutdown signal called. Exiting gracefully.")
-    sys.exit(0)
-
-
 async def entrypoint(args: Optional[dict] = None):
     """Entry function for program startup"""
 