diff --git a/common/utils.py b/common/utils.py
index 5e0ef78..0f207f4 100644
--- a/common/utils.py
+++ b/common/utils.py
@@ -29,7 +29,7 @@ def get_generator_error(message: str, exc_info: bool = True):
     generator_error = handle_request_error(message)
 
-    return get_sse_packet(generator_error.model_dump_json())
+    return generator_error.model_dump_json()
 
 
 def handle_request_error(message: str, exc_info: bool = True):
@@ -50,11 +50,6 @@ def handle_request_error(message: str, exc_info: bool = True):
     return request_error
 
 
-def get_sse_packet(json_data: str):
-    """Get an SSE packet."""
-    return f"data: {json_data}\n\n"
-
-
 def unwrap(wrapped, default=None):
     """Unwrap function for Optionals."""
     if wrapped is None:
diff --git a/main.py b/main.py
index 916d32d..0aae42f 100644
--- a/main.py
+++ b/main.py
@@ -5,6 +5,7 @@ import signal
 import sys
 import time
 import threading
+from sse_starlette import EventSourceResponse
 import uvicorn
 from asyncio import CancelledError
 from typing import Optional
@@ -13,7 +14,6 @@ from jinja2 import TemplateError
 from fastapi import FastAPI, Depends, HTTPException, Request
 from fastapi.concurrency import run_in_threadpool
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import StreamingResponse
 from functools import partial
 from loguru import logger
 
@@ -47,7 +47,6 @@ from common.templating import (
 )
 from common.utils import (
     get_generator_error,
-    get_sse_packet,
     handle_request_error,
     load_progress,
     unwrap,
@@ -235,12 +234,14 @@ async def load_model(request: Request, data: ModelLoadRequest):
         progress.start()
 
         for module, modules in load_status:
+
+            # Get out if the request gets disconnected
             if await request.is_disconnected():
                 logger.error(
                     "Model load cancelled by user. "
                     "Please make sure to run unload to free up resources."
                 )
-                break
+                return
 
             if module == 0:
                 loading_task = progress.add_task(
@@ -256,7 +257,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
                 status="processing",
             )
 
-            yield get_sse_packet(response.model_dump_json())
+            yield response.model_dump_json()
 
             if module == modules:
                 response = ModelLoadResponse(
@@ -266,7 +267,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
                     status="finished",
                 )
 
-                yield get_sse_packet(response.model_dump_json())
+                yield response.model_dump_json()
 
                 # Switch to model progress if the draft model is loaded
                 if model_type == "draft":
@@ -294,7 +295,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
     else:
         generator_callback = partial(generate_with_semaphore, generator)
 
-    return StreamingResponse(generator_callback(), media_type="text/event-stream")
+    return EventSourceResponse(generator_callback())
 
 
 # Unload model endpoint
@@ -515,31 +516,30 @@ async def generate_completion(request: Request, data: CompletionRequest):
     if data.stream and not disable_request_streaming:
 
         async def generator():
-            """Generator for the generation process."""
             try:
                 new_generation = MODEL_CONTAINER.generate_gen(
                     data.prompt, **data.to_gen_params()
                 )
                 for generation in new_generation:
+
+                    # Get out if the request gets disconnected
                     if await request.is_disconnected():
                         logger.error("Completion generation cancelled by user.")
-                        break
+                        return
 
                     response = create_completion_response(generation, model_path.name)
 
-                    yield get_sse_packet(response.model_dump_json())
+                    yield response.model_dump_json()
 
                 # Yield a finish response on successful generation
-                yield get_sse_packet("[DONE]")
+                yield "[DONE]"
 
             except Exception:
                 yield get_generator_error(
                     "Completion aborted. Please check the server console."
                 )
 
+            print("Finished generation")
 
-        return StreamingResponse(
-            generate_with_semaphore(generator),
-            media_type="text/event-stream",
-        )
+        return EventSourceResponse(generate_with_semaphore(generator))
 
     try:
         generation = await call_with_semaphore(
@@ -620,30 +620,30 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                 prompt, **data.to_gen_params()
             )
             for generation in new_generation:
+
+                # Get out if the request gets disconnected
                 if await request.is_disconnected():
                     logger.error("Chat completion generation cancelled by user.")
-                    break
+                    return
 
                 response = create_chat_completion_stream_chunk(
                     const_id, generation, model_path.name
                 )
 
-                yield get_sse_packet(response.model_dump_json())
+                yield response.model_dump_json()
 
             # Yield a finish response on successful generation
             finish_response = create_chat_completion_stream_chunk(
                 const_id, finish_reason="stop"
             )
 
-            yield get_sse_packet(finish_response.model_dump_json())
+            yield finish_response.model_dump_json()
 
         except Exception:
             yield get_generator_error(
                 "Chat completion aborted. Please check the server console."
            )
 
-        return StreamingResponse(
-            generate_with_semaphore(generator), media_type="text/event-stream"
-        )
+        return EventSourceResponse(generate_with_semaphore(generator))
 
     try:
         generation = await call_with_semaphore(
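
Note on the streaming change: the removed get_sse_packet helper manually framed each payload as a raw "data: {json}\n\n" string for FastAPI's StreamingResponse, whereas sse_starlette's EventSourceResponse applies the SSE framing itself, so the generators above now yield bare JSON strings (plus the literal "[DONE]" sentinel). A minimal sketch of the new pattern follows; the /stream route and stream_example names are hypothetical and not part of this PR.

from fastapi import FastAPI
from sse_starlette import EventSourceResponse

app = FastAPI()


@app.get("/stream")
async def stream_example():
    # Hypothetical endpoint for illustration only; the real endpoints in this
    # PR wrap their generators with generate_with_semaphore before returning.
    async def generator():
        # Yield plain strings; EventSourceResponse adds the "data: ...\n\n" framing.
        for chunk in ("one", "two", "three"):
            yield chunk
        yield "[DONE]"

    return EventSourceResponse(generator())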