Signal: Fix signal handlers for uvicorn

Add the ability to override uvicorn's signal handler in addition
to using main's signal handler for any SIGINTs before the API server
starts.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-03-16 23:18:27 -04:00 committed by Brian Dashore
parent 95e44c20d6
commit 14d8ec2007
3 changed files with 36 additions and 6 deletions

23
common/signals.py Normal file
View file

@ -0,0 +1,23 @@
import signal
import sys
from loguru import logger
from types import FrameType
def signal_handler(*_):
"""Signal handler for main function. Run before uvicorn starts."""
logger.warning("Shutdown signal called. Exiting gracefully.")
sys.exit(0)
def uvicorn_signal_handler(signal_event: signal.Signals):
"""Overrides uvicorn's signal handler."""
default_signal_handler = signal.getsignal(signal_event)
def wrapped_handler(signum: int, frame: FrameType = None):
logger.warning("Shutdown signal called. Exiting gracefully.")
default_signal_handler(signum, frame)
signal.signal(signal_event, wrapped_handler)

View file

@ -1,5 +1,7 @@
import pathlib
import signal
import uvicorn
from contextlib import asynccontextmanager
from fastapi import FastAPI, Depends, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from functools import partial
@ -14,6 +16,7 @@ from common.concurrency import (
generate_with_semaphore,
)
from common.logger import UVICORN_LOG_CONFIG
from common.signals import uvicorn_signal_handler
from common.templating import (
get_all_templates,
get_template_from_file,
@ -55,6 +58,14 @@ from endpoints.OAI.utils.completion import (
from endpoints.OAI.utils.model import get_model_list, stream_model_load
from endpoints.OAI.utils.lora import get_lora_list
@asynccontextmanager
async def lifespan(_: FastAPI):
uvicorn_signal_handler(signal.SIGINT)
uvicorn_signal_handler(signal.SIGTERM)
yield
app = FastAPI(
title="TabbyAPI",
summary="An OAI compatible exllamav2 API that's both lightweight and fast",
@ -62,6 +73,7 @@ app = FastAPI(
"This docs page is not meant to send requests! Please use a service "
"like Postman or a frontend UI."
),
lifespan=lifespan,
)
# ALlow CORS requests

View file

@ -4,7 +4,6 @@ import asyncio
import os
import pathlib
import signal
import sys
from loguru import logger
from typing import Optional
@ -13,15 +12,11 @@ from common import config, gen_logging, sampling, model
from common.args import convert_args_to_dict, init_argparser
from common.auth import load_auth_keys
from common.logger import setup_logger
from common.signals import signal_handler
from common.utils import is_port_in_use, unwrap
from endpoints.OAI.app import start_api
def signal_handler(*_):
logger.warning("Shutdown signal called. Exiting gracefully.")
sys.exit(0)
async def entrypoint(args: Optional[dict] = None):
"""Entry function for program startup"""