API: Handle request disconnect on non-streaming gens
Works the same way as streaming gens. If the request is cancelled, it will log an error to the user and release the semaphore if it's holding anything. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
44b7319710
commit
2961c5f3f9
2 changed files with 48 additions and 4 deletions
|
|
@ -1,7 +1,9 @@
|
|||
"""Common utility functions"""
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import traceback
|
||||
from fastapi import Request
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
|
@ -60,6 +62,34 @@ def handle_request_disconnect(message: str):
|
|||
logger.error(message)
|
||||
|
||||
|
||||
async def request_disconnect_loop(request: Request):
|
||||
"""Polls for a starlette request disconnect."""
|
||||
|
||||
while not await request.is_disconnected():
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
|
||||
async def run_with_request_disconnect(
|
||||
request: Request, call_task: asyncio.Task, disconnect_message: str
|
||||
):
|
||||
"""Utility function to cancel if a request is disconnected."""
|
||||
|
||||
_, unfinished = await asyncio.wait(
|
||||
[
|
||||
call_task,
|
||||
asyncio.create_task(request_disconnect_loop(request)),
|
||||
],
|
||||
return_when=asyncio.FIRST_COMPLETED,
|
||||
)
|
||||
for task in unfinished:
|
||||
task.cancel()
|
||||
|
||||
try:
|
||||
return call_task.result()
|
||||
except (asyncio.CancelledError, asyncio.InvalidStateError):
|
||||
handle_request_disconnect(disconnect_message)
|
||||
|
||||
|
||||
def unwrap(wrapped, default=None):
|
||||
"""Unwrap function for Optionals."""
|
||||
if wrapped is None:
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue