OAI: Fix request cancellation behavior
Depending on the day of the week, Starlette signals a client disconnect either by raising a CancelledError or through await request.is_disconnected(). Handle both paths the same way and allow cancellation in either case. Streaming requests now set an event that cancels the batched job and breaks out of the generation loop.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 094c7b1734
commit 660f9b8432
4 changed files with 41 additions and 10 deletions
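The pattern the commit settles on, reduced to a minimal sketch (not the project's actual code; `generate` stands in for the model's generator and the surrounding wiring is assumed):

    import asyncio
    from asyncio import CancelledError

    from fastapi import Request


    async def stream_with_cancellation(request: Request, generate):
        """Both disconnect paths, mirroring the commit's approach."""
        abort_event = asyncio.Event()
        try:
            async for chunk in generate(abort_event):
                # Path 1: some days Starlette only reports the disconnect
                # when polled between chunks.
                if await request.is_disconnected():
                    abort_event.set()
                    break
                yield chunk
        except CancelledError:
            # Path 2: other days Starlette cancels the response task outright.
            abort_event.set()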
@@ -749,7 +749,9 @@ class ExllamaV2Container:
 
         return kwargs
 
-    async def generate_gen(self, prompt: str, **kwargs):
+    async def generate_gen(
+        self, prompt: str, abort_event: Optional[asyncio.Event] = None, **kwargs
+    ):
         """
         Create generator function for prompt completion.
 
@@ -1034,9 +1036,14 @@ class ExllamaV2Container:
         generated_tokens = 0
         full_response = ""
 
-        # Get the generation status once it's ready
         try:
+            # Get the generation status once it's ready
             async for result in job:
+                # Abort if the event is set while streaming
+                if abort_event and abort_event.is_set():
+                    await job.cancel()
+                    break
+
                 stage = result.get("stage")
                 result_id = result.get("identifier")
 
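For illustration, the generator side of that contract as a self-contained toy; FakeJob is a stand-in for the batched generation job, whose async-iterator-plus-cancel() interface is assumed from the hunk above:

    import asyncio
    from typing import Optional


    class FakeJob:
        """Stand-in for the batched generation job (assumed interface)."""

        def __aiter__(self):
            return self

        async def __anext__(self):
            await asyncio.sleep(0.05)
            return {"stage": "streaming", "text": "token"}

        async def cancel(self):
            print("job cancelled")


    async def generate_gen(prompt: str, abort_event: Optional[asyncio.Event] = None):
        job = FakeJob()
        async for result in job:
            # Abort if the event is set while streaming
            if abort_event and abort_event.is_set():
                await job.cancel()
                break
            yield result["text"]


    async def main():
        abort_event = asyncio.Event()
        count = 0
        async for _ in generate_gen("hi", abort_event):
            count += 1
            if count == 3:
                abort_event.set()  # what the API layer does on disconnect


    asyncio.run(main())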
@@ -469,7 +469,7 @@ async def completion_request(request: Request, data: CompletionRequest):
 
     if data.stream and not disable_request_streaming:
         return EventSourceResponse(
-            stream_generate_completion(data, model_path),
+            stream_generate_completion(data, request, model_path),
            ping=maxsize,
         )
     else:
@@ -516,7 +516,7 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
 
     if data.stream and not disable_request_streaming:
         return EventSourceResponse(
-            stream_generate_chat_completion(prompt, data, model_path),
+            stream_generate_chat_completion(prompt, data, request, model_path),
            ping=maxsize,
         )
     else:
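For context, a sketch of how a call site like this threads the Request through to the generator, assuming sse_starlette's EventSourceResponse (the endpoint and generator below are simplified stand-ins, not the project's code):

    from sys import maxsize

    from fastapi import FastAPI, Request
    from sse_starlette.sse import EventSourceResponse

    app = FastAPI()


    async def fake_stream(request: Request):
        """Stand-in for stream_generate_completion."""
        for i in range(3):
            if await request.is_disconnected():
                break
            yield f'{{"index": {i}}}'


    @app.post("/v1/completions")
    async def completion_request(request: Request):
        # The Request is passed through so the generator can poll
        # is_disconnected() between chunks.
        return EventSourceResponse(fake_stream(request), ping=maxsize)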
@@ -1,11 +1,12 @@
 """Chat completion utilities for OAI server."""
 
+import asyncio
 import pathlib
 from asyncio import CancelledError
 from typing import Optional
 from uuid import uuid4
 
-from fastapi import HTTPException
+from fastapi import HTTPException, Request
 from jinja2 import TemplateError
 from loguru import logger
 
@@ -192,14 +193,24 @@ def format_prompt_with_template(data: ChatCompletionRequest):
 
 
 async def stream_generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     """Generator for the generation process."""
+    abort_event = asyncio.Event()
+
     try:
         const_id = f"chatcmpl-{uuid4().hex}"
 
-        new_generation = model.container.generate_gen(prompt, **data.to_gen_params())
+        new_generation = model.container.generate_gen(
+            prompt, abort_event, **data.to_gen_params()
+        )
         async for generation in new_generation:
+            # Sometimes this fires, and sometimes a CancelledError will fire
+            # Keep both implementations in to avoid the headache
+            if await request.is_disconnected():
+                abort_event.set()
+                handle_request_disconnect("Completion generation cancelled by user.")
+
             response = _create_stream_chunk(const_id, generation, model_path.name)
 
             yield response.model_dump_json()
@@ -210,6 +221,7 @@ async def stream_generate_chat_completion(
     except CancelledError:
         # Get out if the request gets disconnected
+        abort_event.set()
         handle_request_disconnect("Chat completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(
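Why set the event in the except handler as well? When Starlette takes the CancelledError path, the exception surfaces at whatever await the generator is parked on, so the is_disconnected() poll never gets a chance to run. A runnable toy (pure asyncio, no project code) that shows this path:

    import asyncio


    async def stream(abort_event: asyncio.Event):
        """Toy generator; cancellation lands on the sleep await."""
        try:
            for i in range(10):
                await asyncio.sleep(0.1)
                yield f"chunk {i}"
        except asyncio.CancelledError:
            # Mirrors the commit: flag the backend job before unwinding
            abort_event.set()
            raise


    async def consume(gen):
        async for chunk in gen:
            pass


    async def main():
        abort_event = asyncio.Event()
        task = asyncio.create_task(consume(stream(abort_event)))
        await asyncio.sleep(0.25)
        task.cancel()  # what Starlette does when the client drops mid-stream
        try:
            await task
        except asyncio.CancelledError:
            pass
        assert abort_event.is_set()


    asyncio.run(main())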
@@ -1,8 +1,9 @@
 """Completion utilities for OAI server."""
 
+import asyncio
 import pathlib
 from asyncio import CancelledError
-from fastapi import HTTPException
+from fastapi import HTTPException, Request
 from typing import Optional
 
 from common import model
@@ -60,14 +61,24 @@ def _create_response(generation: dict, model_name: Optional[str]):
     return response
 
 
-async def stream_generate_completion(data: CompletionRequest, model_path: pathlib.Path):
+async def stream_generate_completion(
+    data: CompletionRequest, request: Request, model_path: pathlib.Path
+):
     """Streaming generation for completions."""
+    abort_event = asyncio.Event()
+
     try:
         new_generation = model.container.generate_gen(
-            data.prompt, **data.to_gen_params()
+            data.prompt, abort_event, **data.to_gen_params()
         )
         async for generation in new_generation:
+            # Sometimes this fires, and sometimes a CancelledError will fire
+            # Keep both implementations in to avoid the headache
+            if await request.is_disconnected():
+                abort_event.set()
+                handle_request_disconnect("Completion generation cancelled by user.")
+
             response = _create_response(generation, model_path.name)
             yield response.model_dump_json()
@@ -78,6 +89,7 @@ async def stream_generate_completion(data: CompletionRequest, model_path: pathli
     except CancelledError:
         # Get out if the request gets disconnected
+        abort_event.set()
         handle_request_disconnect("Completion generation cancelled by user.")
     except Exception:
         yield get_generator_error(