OAI: Switch to background task for disconnect checks
Waiting for request disconnect takes some extra time and allows generation chunks to pile up, resulting in large payloads being sent at once not making up a smooth stream. Use the polling method in non-streaming requests by creating a background task and then check if the task is done, signifying that the request has been disconnected. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
660f9b8432
commit
d710a1b441
2 changed files with 11 additions and 2 deletions
|
|
@ -15,6 +15,7 @@ from common.networking import (
|
|||
get_generator_error,
|
||||
handle_request_disconnect,
|
||||
handle_request_error,
|
||||
request_disconnect_loop,
|
||||
)
|
||||
from common.utils import unwrap
|
||||
from endpoints.OAI.types.chat_completion import (
|
||||
|
|
@ -204,10 +205,13 @@ async def stream_generate_chat_completion(
|
|||
new_generation = model.container.generate_gen(
|
||||
prompt, abort_event, **data.to_gen_params()
|
||||
)
|
||||
# Create a background task to avoid blocking the loop
|
||||
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
|
||||
|
||||
async for generation in new_generation:
|
||||
# Sometimes this fires, and sometimes a CancelledError will fire
|
||||
# Keep both implementations in to avoid the headache
|
||||
if await request.is_disconnected():
|
||||
if disconnect_task.done():
|
||||
abort_event.set()
|
||||
handle_request_disconnect("Completion generation cancelled by user.")
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from common.networking import (
|
|||
get_generator_error,
|
||||
handle_request_disconnect,
|
||||
handle_request_error,
|
||||
request_disconnect_loop,
|
||||
)
|
||||
from common.utils import unwrap
|
||||
from endpoints.OAI.types.completion import (
|
||||
|
|
@ -72,10 +73,14 @@ async def stream_generate_completion(
|
|||
new_generation = model.container.generate_gen(
|
||||
data.prompt, abort_event, **data.to_gen_params()
|
||||
)
|
||||
|
||||
# Create a background task to avoid blocking the loop
|
||||
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
|
||||
|
||||
async for generation in new_generation:
|
||||
# Sometimes this fires, and sometimes a CancelledError will fire
|
||||
# Keep both implementations in to avoid the headache
|
||||
if await request.is_disconnected():
|
||||
if disconnect_task.done():
|
||||
abort_event.set()
|
||||
handle_request_disconnect("Completion generation cancelled by user.")
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue