API: Fix disconnect handling on streaming responses
Starlette's StreamingResponse has an issue where it keeps yielding after a request has disconnected. An upstream Starlette bugfix will address this, but FastAPI currently pins starlette <= 0.36, so that fix isn't available, which isn't ideal. Therefore, switch back to sse-starlette, which handles these disconnects correctly.

Also, don't try to yield after the request has disconnected; just return out of the generator instead.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in: parent 6b4f100db2 · commit d45e847c7a
2 changed files with 21 additions and 26 deletions
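Before the diff, a minimal sketch of the pattern this commit adopts: an EventSourceResponse draining an async generator that returns, rather than breaks, once the client disconnects. This is an illustration under assumed names (the /demo/stream endpoint and token list are not tabbyAPI code).

```python
import asyncio

from fastapi import FastAPI, Request
from sse_starlette import EventSourceResponse

app = FastAPI()


@app.get("/demo/stream")  # illustrative endpoint, not part of tabbyAPI
async def demo_stream(request: Request):
    async def generator():
        for token in ("Hello", " ", "world"):
            # Return (not break) so nothing is yielded after a disconnect.
            if await request.is_disconnected():
                return
            # EventSourceResponse frames each yielded string as an SSE
            # "data:" event, so no manual packet formatting is needed.
            yield token
            await asyncio.sleep(0.1)
        yield "[DONE]"

    return EventSourceResponse(generator())
```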
main.py (40 lines changed)
@@ -5,6 +5,7 @@ import signal
 import sys
 import time
 import threading
+from sse_starlette import EventSourceResponse
 import uvicorn
 from asyncio import CancelledError
 from typing import Optional
@@ -13,7 +14,6 @@ from jinja2 import TemplateError
 from fastapi import FastAPI, Depends, HTTPException, Request
 from fastapi.concurrency import run_in_threadpool
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
 from functools import partial
 from loguru import logger
@@ -47,7 +47,6 @@ from common.templating import (
 )
 from common.utils import (
     get_generator_error,
-    get_sse_packet,
     handle_request_error,
     load_progress,
     unwrap,
@@ -235,12 +234,14 @@ async def load_model(request: Request, data: ModelLoadRequest):
         progress.start()
 
         for module, modules in load_status:
 
             # Get out if the request gets disconnected
             if await request.is_disconnected():
                 logger.error(
                     "Model load cancelled by user. "
                     "Please make sure to run unload to free up resources."
                 )
-                break
+                return
 
             if module == 0:
                 loading_task = progress.add_task(
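The break-to-return swap above is the core behavioral fix: break only exits the loop, so any yield placed after the loop still runs against a closed connection. A toy, runnable illustration with the disconnect check stubbed out (not tabbyAPI code):

```python
import asyncio


async def generator(is_disconnected):
    """Generator with the same shape as the ones in this commit."""
    for chunk in ("a", "b", "c"):
        if await is_disconnected():
            return  # with `break` here, the "[DONE]" below would still run
        yield chunk
    yield "[DONE]"  # only reached when the stream completes normally


async def main():
    async def gone():  # stand-in for request.is_disconnected
        return True

    # A disconnected client gets no trailing chunks; the generator just ends.
    print([chunk async for chunk in generator(gone)])  # prints []


asyncio.run(main())
```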
@@ -256,7 +257,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
                     status="processing",
                 )
 
-                yield get_sse_packet(response.model_dump_json())
+                yield response.model_dump_json()
 
             if module == modules:
                 response = ModelLoadResponse(
@@ -266,7 +267,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
                     status="finished",
                 )
 
-                yield get_sse_packet(response.model_dump_json())
+                yield response.model_dump_json()
 
             # Switch to model progress if the draft model is loaded
             if model_type == "draft":
@@ -294,7 +295,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
     else:
         generator_callback = partial(generate_with_semaphore, generator)
 
-    return StreamingResponse(generator_callback(), media_type="text/event-stream")
+    return EventSourceResponse(generator_callback())
 
 
 # Unload model endpoint
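generate_with_semaphore itself is not shown in this diff. Here is a plausible sketch of the idea, assuming it simply holds a global semaphore while draining the wrapped generator (the limit of 1 is an assumption, not from the diff). It also shows why partial is used: it defers the call so a fresh stream is created per request.

```python
import asyncio
from functools import partial

generate_semaphore = asyncio.Semaphore(1)  # assumed concurrency limit


async def generate_with_semaphore(generator):
    """Hold the semaphore for the whole lifetime of the wrapped stream."""
    async with generate_semaphore:
        async for item in generator():
            yield item


async def generator():
    yield "chunk"


# Deferred with partial; the endpoint then returns
# EventSourceResponse(generator_callback()).
generator_callback = partial(generate_with_semaphore, generator)


async def main():
    async for item in generator_callback():
        print(item)


asyncio.run(main())
```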
@@ -515,31 +516,30 @@ async def generate_completion(request: Request, data: CompletionRequest):
     if data.stream and not disable_request_streaming:
 
         async def generator():
             """Generator for the generation process."""
             try:
                 new_generation = MODEL_CONTAINER.generate_gen(
                     data.prompt, **data.to_gen_params()
                 )
                 for generation in new_generation:
 
                     # Get out if the request gets disconnected
                     if await request.is_disconnected():
                         logger.error("Completion generation cancelled by user.")
-                        break
+                        return
 
                     response = create_completion_response(generation, model_path.name)
 
-                    yield get_sse_packet(response.model_dump_json())
+                    yield response.model_dump_json()
 
                 # Yield a finish response on successful generation
-                yield get_sse_packet("[DONE]")
+                yield "[DONE]"
             except Exception:
                 yield get_generator_error(
                     "Completion aborted. Please check the server console."
                 )
                 print("Finished generation")
 
-        return StreamingResponse(
-            generate_with_semaphore(generator),
-            media_type="text/event-stream",
-        )
+        return EventSourceResponse(generate_with_semaphore(generator))
 
     try:
         generation = await call_with_semaphore(
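The dropped get_sse_packet helper is also not shown here; judging from its call sites it did the data: framing by hand, which EventSourceResponse now performs. A hedged reconstruction for contrast:

```python
# Likely shape of the removed helper (reconstructed from its call sites,
# not copied from tabbyAPI): wrap a JSON payload as a raw SSE frame.
def get_sse_packet(json_data: str) -> str:
    return f"data: {json_data}\n\n"


# With sse-starlette the generator yields the bare payload instead, and
# EventSourceResponse adds the "data:" framing plus periodic pings.
print(get_sse_packet('{"text": "hi"}'))
```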
@@ -620,30 +620,30 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                     prompt, **data.to_gen_params()
                 )
                 for generation in new_generation:
 
                     # Get out if the request gets disconnected
                     if await request.is_disconnected():
                         logger.error("Chat completion generation cancelled by user.")
-                        break
+                        return
 
                     response = create_chat_completion_stream_chunk(
                         const_id, generation, model_path.name
                     )
 
-                    yield get_sse_packet(response.model_dump_json())
+                    yield response.model_dump_json()
 
                 # Yield a finish response on successful generation
                 finish_response = create_chat_completion_stream_chunk(
                     const_id, finish_reason="stop"
                 )
 
-                yield get_sse_packet(finish_response.model_dump_json())
+                yield finish_response.model_dump_json()
             except Exception:
                 yield get_generator_error(
                     "Chat completion aborted. Please check the server console."
                 )
 
-        return StreamingResponse(
-            generate_with_semaphore(generator), media_type="text/event-stream"
-        )
+        return EventSourceResponse(generate_with_semaphore(generator))
 
     try:
         generation = await call_with_semaphore(
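For completeness, a client-side sketch of the exact case this commit handles: a consumer that stops reading mid-stream, disconnecting the request. The URL assumes tabbyAPI's default port and OpenAI-style route; adjust both for a real deployment.

```python
import asyncio

import httpx


async def main():
    payload = {
        "messages": [{"role": "user", "content": "Hi"}],
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://127.0.0.1:5000/v1/chat/completions", json=payload
        ) as response:
            async for line in response.aiter_lines():
                if line.startswith("data: "):
                    print(line.removeprefix("data: "))
                    # Dropping out early closes the connection; the server
                    # should then return out of its generator, not yield.
                    break


asyncio.run(main())
```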