OAI: Fix request cancellation behavior

Depending on the day of the week, Starlette signals a client disconnect
either by raising a CancelledError into the generator or by returning True
from await request.is_disconnected(). Run the same cancellation behavior in
both cases.

Streaming requests now set an event to cancel the batched job and break
out of the generation loop.
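
In isolation, the pattern boils down to something like this (a minimal
sketch, not code from this commit; fake_token_stream and stream_tokens are
made-up names):

import asyncio

from fastapi import Request


async def fake_token_stream(abort_event: asyncio.Event):
    """Stand-in for the model generator; stops once the abort event is set."""
    for token in ["Hello", ",", " world"]:
        if abort_event.is_set():
            break  # Mirrors cancelling the job and breaking out of the loop
        yield token
        await asyncio.sleep(0.1)


async def stream_tokens(request: Request):
    """Streaming generator that handles both disconnect signals the same way."""
    abort_event = asyncio.Event()

    try:
        async for token in fake_token_stream(abort_event):
            # Case 1: the disconnect shows up on an explicit poll
            if await request.is_disconnected():
                abort_event.set()
                break
            yield token
    except asyncio.CancelledError:
        # Case 2: Starlette cancels the coroutine outright
        abort_event.set()
        raise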

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-05-26 13:00:33 -04:00
parent 094c7b1734
commit 660f9b8432
4 changed files with 41 additions and 10 deletions


@@ -749,7 +749,9 @@ class ExllamaV2Container:
         return kwargs

-    async def generate_gen(self, prompt: str, **kwargs):
+    async def generate_gen(
+        self, prompt: str, abort_event: Optional[asyncio.Event] = None, **kwargs
+    ):
         """
         Create generator function for prompt completion.
@@ -1034,9 +1036,14 @@ class ExllamaV2Container:
         generated_tokens = 0
         full_response = ""

-        # Get the generation status once it's ready
-        async for result in job:
-            stage = result.get("stage")
-            result_id = result.get("identifier")
+        try:
+            # Get the generation status once it's ready
+            async for result in job:
+                # Abort if the event is set while streaming
+                if abort_event and abort_event.is_set():
+                    await job.cancel()
+                    break
+
+                stage = result.get("stage")
+                result_id = result.get("identifier")
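
The check runs inside the container's generator because the batched job
produces results independently of the HTTP layer, so the asyncio.Event is the
only bridge between the two. A stripped-down sketch of that relationship
(fake_job and generate_gen_sketch are illustrative stand-ins):

import asyncio
from typing import AsyncIterator, Optional


async def fake_job() -> AsyncIterator[dict]:
    """Stand-in for the batched ExllamaV2 job's async result stream."""
    for i in range(100):
        await asyncio.sleep(0.05)
        yield {"stage": "streaming", "identifier": i}


async def generate_gen_sketch(
    prompt: str, abort_event: Optional[asyncio.Event] = None
):
    """Poll an externally owned abort event between results."""
    async for result in fake_job():
        # The endpoint sets this event when the client disconnects; the
        # real code also calls `await job.cancel()` to free the batch slot.
        if abort_event and abort_event.is_set():
            break
        yield result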


@@ -469,7 +469,7 @@ async def completion_request(request: Request, data: CompletionRequest):
     if data.stream and not disable_request_streaming:
         return EventSourceResponse(
-            stream_generate_completion(data, model_path),
+            stream_generate_completion(data, request, model_path),
             ping=maxsize,
         )
     else:
@@ -516,7 +516,7 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest):
     if data.stream and not disable_request_streaming:
         return EventSourceResponse(
-            stream_generate_chat_completion(prompt, data, model_path),
+            stream_generate_chat_completion(prompt, data, request, model_path),
             ping=maxsize,
         )
     else:
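
Wiring-wise, the endpoints now thread the FastAPI Request into the SSE
generator. A runnable approximation (the route path and placeholder generator
are assumptions; EventSourceResponse comes from the sse-starlette package):

from sys import maxsize

from fastapi import FastAPI, Request
from sse_starlette.sse import EventSourceResponse

app = FastAPI()


async def stream_placeholder(request: Request):
    """Trivial stand-in for stream_generate_completion."""
    for chunk in ["one", "two", "three"]:
        if await request.is_disconnected():
            break
        yield chunk


@app.post("/v1/completions")
async def completion_request(request: Request):
    # Passing the Request down lets the generator poll
    # request.is_disconnected() between chunks.
    return EventSourceResponse(
        stream_placeholder(request),
        ping=maxsize,  # Effectively disables keep-alive pings, as in the diff
    )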


@@ -1,11 +1,12 @@
 """Chat completion utilities for OAI server."""

+import asyncio
 import pathlib
 from asyncio import CancelledError
 from typing import Optional
 from uuid import uuid4

-from fastapi import HTTPException
+from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger
@@ -192,14 +193,24 @@ def format_prompt_with_template(data: ChatCompletionRequest):

 async def stream_generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     """Generator for the generation process."""
+    abort_event = asyncio.Event()
+
     try:
         const_id = f"chatcmpl-{uuid4().hex}"
-        new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
+        new_generation = model.container.generate_gen(
+            prompt, abort_event, **data.to_gen_params()
+        )

         async for generation in new_generation:
+            # Sometimes this fires, and sometimes a CancelledError will fire
+            # Keep both implementations in to avoid the headache
+            if await request.is_disconnected():
+                abort_event.set()
+                handle_request_disconnect("Completion generation cancelled by user.")
+
             response = _create_stream_chunk(const_id, generation, model_path.name)
             yield response.model_dump_json()
@@ -210,6 +221,7 @@ async def stream_generate_chat_completion(
     except CancelledError:
         # Get out if the request gets disconnected
+        abort_event.set()
         handle_request_disconnect("Chat completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(


@@ -1,8 +1,9 @@
 """Completion utilities for OAI server."""

+import asyncio
 import pathlib
 from asyncio import CancelledError
-from fastapi import HTTPException
+from fastapi import HTTPException, Request
 from typing import Optional

 from common import model
@@ -60,14 +61,24 @@ def _create_response(generation: dict, model_name: Optional[str]):
     return response

-async def stream_generate_completion(data: CompletionRequest, model_path: pathlib.Path):
+async def stream_generate_completion(
+    data: CompletionRequest, request: Request, model_path: pathlib.Path
+):
     """Streaming generation for completions."""
+    abort_event = asyncio.Event()
+
     try:
         new_generation = model.container.generate_gen(
-            data.prompt, **data.to_gen_params()
+            data.prompt, abort_event, **data.to_gen_params()
         )

         async for generation in new_generation:
+            # Sometimes this fires, and sometimes a CancelledError will fire
+            # Keep both implementations in to avoid the headache
+            if await request.is_disconnected():
+                abort_event.set()
+                handle_request_disconnect("Completion generation cancelled by user.")
+
             response = _create_response(generation, model_path.name)
             yield response.model_dump_json()
@@ -78,6 +89,7 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathlib.Path):
     except CancelledError:
         # Get out if the request gets disconnected
+        abort_event.set()
         handle_request_disconnect("Completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
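
One way to exercise the new path from the client side (the base URL, port,
and payload here are assumptions, not part of this commit):

import httpx


def cancel_mid_stream():
    """Open a streaming completion and drop it early to trigger cancellation."""
    payload = {"prompt": "Write a long story.", "stream": True, "max_tokens": 512}

    with httpx.Client(base_url="http://127.0.0.1:5000") as client:
        with client.stream("POST", "/v1/completions", json=payload) as response:
            for i, line in enumerate(response.iter_lines()):
                print(line)
                if i >= 3:
                    break  # Leaving the context manager closes the connection

    # Server side, this should hit either the is_disconnected() poll or the
    # CancelledError handler, set abort_event, and cancel the batched job.


if __name__ == "__main__":
    cancel_mid_stream()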