OAI: Switch to background task for disconnect checks

Waiting for request disconnect takes some extra time and allows
generation chunks to pile up, resulting in large payloads being sent
at once not making up a smooth stream.

Use the polling method in non-streaming requests by creating a background
task and then check if the task is done, signifying that the request
has been disconnected.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-05-26 13:52:20 -04:00
parent 660f9b8432
commit d710a1b441
2 changed files with 11 additions and 2 deletions

View file

@ -15,6 +15,7 @@ from common.networking import (
get_generator_error,
handle_request_disconnect,
handle_request_error,
request_disconnect_loop,
)
from common.utils import unwrap
from endpoints.OAI.types.chat_completion import (
@ -204,10 +205,13 @@ async def stream_generate_chat_completion(
new_generation = model.container.generate_gen(
prompt, abort_event, **data.to_gen_params()
)
# Create a background task to avoid blocking the loop
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
async for generation in new_generation:
# Sometimes this fires, and sometimes a CancelledError will fire
# Keep both implementations in to avoid the headache
if await request.is_disconnected():
if disconnect_task.done():
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")

View file

@ -11,6 +11,7 @@ from common.networking import (
get_generator_error,
handle_request_disconnect,
handle_request_error,
request_disconnect_loop,
)
from common.utils import unwrap
from endpoints.OAI.types.completion import (
@ -72,10 +73,14 @@ async def stream_generate_completion(
new_generation = model.container.generate_gen(
data.prompt, abort_event, **data.to_gen_params()
)
# Create a background task to avoid blocking the loop
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
async for generation in new_generation:
# Sometimes this fires, and sometimes a CancelledError will fire
# Keep both implementations in to avoid the headache
if await request.is_disconnected():
if disconnect_task.done():
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")