API: Fix disconnect handling on streaming responses

Starlette's StreamingResponse has an issue where it keeps yielding after
a request has disconnected. An upstream bugfix to Starlette will resolve
this, but FastAPI currently requires starlette <= 0.36, so that fix
can't be used yet.

Therefore, switch back to sse-starlette, which handles these disconnects
correctly.

Also, don't try to yield after the request is disconnected; just return
out of the generator instead.
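
In practice, the streaming endpoints now follow this shape. A minimal
sketch, assuming a trivial event generator (the real handlers in the
diff below also wrap the generator in a semaphore and emit JSON chunks):

    from fastapi import FastAPI, Request
    from sse_starlette import EventSourceResponse

    app = FastAPI()

    @app.get("/stream")
    async def stream(request: Request):
        async def generator():
            for i in range(100):
                # Return (don't yield) once the client is gone; yielding
                # after a disconnect is what trips up StreamingResponse.
                if await request.is_disconnected():
                    return
                # Each yielded string is framed as an SSE "data:" event,
                # so no manual "data: ...\n\n" packing is needed.
                yield f"tick {i}"

        return EventSourceResponse(generator())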

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-03-10 17:31:47 -04:00
parent 6b4f100db2
commit d45e847c7a
2 changed files with 21 additions and 26 deletions

common/utils.py

@@ -29,7 +29,7 @@ def get_generator_error(message: str, exc_info: bool = True):
     generator_error = handle_request_error(message)

-    return get_sse_packet(generator_error.model_dump_json())
+    return generator_error.model_dump_json()


 def handle_request_error(message: str, exc_info: bool = True):
@@ -50,11 +50,6 @@ def handle_request_error(message: str, exc_info: bool = True):
     return request_error


-def get_sse_packet(json_data: str):
-    """Get an SSE packet."""
-    return f"data: {json_data}\n\n"
-
-
 def unwrap(wrapped, default=None):
     """Unwrap function for Optionals."""
     if wrapped is None:

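get_sse_packet goes away because sse-starlette now does the wire
framing: each bare string a handler yields is serialized as an SSE data
event. Roughly (simplified; the library also supports event names, ids,
and retry hints, and uses "\r\n" separators):

    from sse_starlette import ServerSentEvent

    # get_sse_packet built the packet by hand: f"data: {json_data}\n\n"
    # sse-starlette builds the equivalent event from a yielded string:
    print(ServerSentEvent(data='{"text": "hi"}').encode())
    # roughly b'data: {"text": "hi"}\r\n\r\n'
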
main.py

@@ -5,6 +5,7 @@ import signal
 import sys
 import time
 import threading
+from sse_starlette import EventSourceResponse
 import uvicorn
 from asyncio import CancelledError
 from typing import Optional
@@ -13,7 +14,6 @@ from jinja2 import TemplateError
 from fastapi import FastAPI, Depends, HTTPException, Request
 from fastapi.concurrency import run_in_threadpool
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
 from functools import partial
 from loguru import logger
@@ -47,7 +47,6 @@ from common.templating import (
 )
 from common.utils import (
     get_generator_error,
-    get_sse_packet,
     handle_request_error,
     load_progress,
     unwrap,
@@ -235,12 +234,14 @@ async def load_model(request: Request, data: ModelLoadRequest):
         progress.start()

         for module, modules in load_status:
             # Get out if the request gets disconnected
             if await request.is_disconnected():
                 logger.error(
                     "Model load cancelled by user. "
                     "Please make sure to run unload to free up resources."
                 )
-                break
+                return

             if module == 0:
                 loading_task = progress.add_task(
@@ -256,7 +257,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
                     status="processing",
                 )

-                yield get_sse_packet(response.model_dump_json())
+                yield response.model_dump_json()

             if module == modules:
                 response = ModelLoadResponse(
@@ -266,7 +267,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
                     status="finished",
                 )

-                yield get_sse_packet(response.model_dump_json())
+                yield response.model_dump_json()

                 # Switch to model progress if the draft model is loaded
                 if model_type == "draft":
@@ -294,7 +295,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
     else:
         generator_callback = partial(generate_with_semaphore, generator)

-        return StreamingResponse(generator_callback(), media_type="text/event-stream")
+        return EventSourceResponse(generator_callback())


 # Unload model endpoint
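
For context on why EventSourceResponse behaves better here: while
streaming, it also watches the ASGI receive channel and tears the
response down on http.disconnect instead of pushing chunks to a dead
socket. A simplified sketch of that pattern (hypothetical helper, not
sse-starlette's actual code):

    import anyio

    async def stream_until_disconnect(receive, send_events):
        # Run the event stream and a disconnect listener concurrently;
        # whichever finishes first cancels the other.
        async with anyio.create_task_group() as tg:

            async def listen_for_disconnect():
                while True:
                    message = await receive()
                    if message["type"] == "http.disconnect":
                        tg.cancel_scope.cancel()  # client left: stop streaming
                        return

            tg.start_soon(listen_for_disconnect)
            await send_events()       # stream all events to the client
            tg.cancel_scope.cancel()  # finished normally: stop the listener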
@@ -515,31 +516,30 @@ async def generate_completion(request: Request, data: CompletionRequest):
     if data.stream and not disable_request_streaming:

         async def generator():
            """Generator for the generation process."""

            try:
                new_generation = MODEL_CONTAINER.generate_gen(
                    data.prompt, **data.to_gen_params()
                )

                for generation in new_generation:
                    # Get out if the request gets disconnected
                    if await request.is_disconnected():
                        logger.error("Completion generation cancelled by user.")
-                        break
+                        return

                    response = create_completion_response(generation, model_path.name)
-                    yield get_sse_packet(response.model_dump_json())
+                    yield response.model_dump_json()

                # Yield a finish response on successful generation
-                yield get_sse_packet("[DONE]")
+                yield "[DONE]"
            except Exception:
                yield get_generator_error(
                    "Completion aborted. Please check the server console."
                )
print("Finished generation")
return StreamingResponse(
generate_with_semaphore(generator),
media_type="text/event-stream",
)
return EventSourceResponse(generate_with_semaphore(generator))
try:
generation = await call_with_semaphore(
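
To verify the new framing end to end, a stream consumer can be sketched
like this (hypothetical client: the URL, port, and payload fields are
assumptions based on the handler above):

    import asyncio
    import httpx

    async def consume(prompt: str):
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream(
                "POST",
                "http://127.0.0.1:5000/v1/completions",  # assumed host/port
                json={"prompt": prompt, "stream": True},
            ) as response:
                async for line in response.aiter_lines():
                    if not line.startswith("data: "):
                        continue  # skip blank lines and SSE comments
                    payload = line[len("data: "):]
                    if payload == "[DONE]":
                        break  # server's end-of-stream marker
                    print(payload)  # one JSON-encoded completion chunk

    asyncio.run(consume("Hello there"))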
@@ -620,30 +620,30 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                    prompt, **data.to_gen_params()
                )

                for generation in new_generation:
                    # Get out if the request gets disconnected
                    if await request.is_disconnected():
                        logger.error("Chat completion generation cancelled by user.")
-                        break
+                        return

                    response = create_chat_completion_stream_chunk(
                        const_id, generation, model_path.name
                    )
-                    yield get_sse_packet(response.model_dump_json())
+                    yield response.model_dump_json()

                # Yield a finish response on successful generation
                finish_response = create_chat_completion_stream_chunk(
                    const_id, finish_reason="stop"
                )
-                yield get_sse_packet(finish_response.model_dump_json())
+                yield finish_response.model_dump_json()
            except Exception:
                yield get_generator_error(
                    "Chat completion aborted. Please check the server console."
                )

-        return StreamingResponse(
-            generate_with_semaphore(generator), media_type="text/event-stream"
-        )
+        return EventSourceResponse(generate_with_semaphore(generator))

     try:
         generation = await call_with_semaphore(