API: Handle request disconnect on non-streaming gens

Works the same way as streaming gens. If the request is cancelled,
it will log an error to the user and release the semaphore if it's
holding anything.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-03-21 23:12:59 -04:00
parent 44b7319710
commit 2961c5f3f9
2 changed files with 48 additions and 4 deletions

View file

@ -1,7 +1,9 @@
"""Common utility functions"""
import asyncio
import socket
import traceback
from fastapi import Request
from loguru import logger
from pydantic import BaseModel
from typing import Optional
@ -60,6 +62,34 @@ def handle_request_disconnect(message: str):
logger.error(message)
async def request_disconnect_loop(request: Request):
"""Polls for a starlette request disconnect."""
while not await request.is_disconnected():
await asyncio.sleep(0.5)
async def run_with_request_disconnect(
request: Request, call_task: asyncio.Task, disconnect_message: str
):
"""Utility function to cancel if a request is disconnected."""
_, unfinished = await asyncio.wait(
[
call_task,
asyncio.create_task(request_disconnect_loop(request)),
],
return_when=asyncio.FIRST_COMPLETED,
)
for task in unfinished:
task.cancel()
try:
return call_task.result()
except (asyncio.CancelledError, asyncio.InvalidStateError):
handle_request_disconnect(disconnect_message)
def unwrap(wrapped, default=None):
"""Unwrap function for Optionals."""
if wrapped is None:

View file

@ -1,3 +1,4 @@
import asyncio
import pathlib
import signal
import uvicorn
@ -25,6 +26,7 @@ from common.templating import (
from common.utils import (
coalesce,
handle_request_error,
run_with_request_disconnect,
unwrap,
)
from endpoints.OAI.types.auth import AuthPermissionResponse
@ -452,10 +454,15 @@ async def completion_request(request: Request, data: CompletionRequest):
ping=maxsize,
)
else:
response = await call_with_semaphore(
partial(generate_completion, data, model_path)
generate_task = asyncio.create_task(
call_with_semaphore(partial(generate_completion, data, model_path))
)
response = await run_with_request_disconnect(
request,
generate_task,
disconnect_message="Completion generation cancelled by user.",
)
return response
@ -494,10 +501,17 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
ping=maxsize,
)
else:
response = await call_with_semaphore(
partial(generate_chat_completion, prompt, data, model_path)
generate_task = asyncio.create_task(
call_with_semaphore(
partial(generate_chat_completion, prompt, data, model_path)
)
)
response = await run_with_request_disconnect(
request,
generate_task,
disconnect_message="Chat completion generation cancelled by user.",
)
return response