From 56f9b1d1a842291eae4aec90b8db15c80af22d6c Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 30 Nov 2023 00:37:48 -0500 Subject: [PATCH] API: Add generator error handling If the generator errors, there's no proper handling to send an error packet and close the connection. This is especially important for unloading models if the load fails at any stage to reclaim a user's VRAM. Raising an exception caused the model_container object to lock and not get freed by the GC. This made sense to propegate SSE errors across all generator functions rather than relying on abort signals. Signed-off-by: kingbri --- main.py | 104 ++++++++++++++++++++++++++++++++----------------------- utils.py | 28 ++++++++++++++- 2 files changed, 88 insertions(+), 44 deletions(-) diff --git a/main.py b/main.py index e5d7d26..a76de8c 100644 --- a/main.py +++ b/main.py @@ -1,6 +1,6 @@ import uvicorn import yaml -import pathlib, os +import pathlib from auth import check_admin_key, check_api_key, load_auth_keys from fastapi import FastAPI, Request, HTTPException, Depends from fastapi.middleware.cors import CORSMiddleware @@ -24,7 +24,7 @@ from OAI.utils import ( create_chat_completion_stream_chunk ) from typing import Optional -from utils import load_progress +from utils import get_generator_error, load_progress from uuid import uuid4 app = FastAPI() @@ -102,38 +102,50 @@ async def load_model(data: ModelLoadRequest): model_container = ModelContainer(model_path.resolve(), False, **load_data) def generator(): + global model_container + + load_failed = False model_type = "draft" if model_container.draft_enabled else "model" load_status = model_container.load_gen(load_progress) - for (module, modules) in load_status: - if module == 0: - loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) - elif module == modules: - loading_bar.next() - loading_bar.finish() + # TODO: Maybe create an erroring generator as a common utility function + try: + for (module, modules) in load_status: + if module == 0: + loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules) + elif module == modules: + loading_bar.next() + loading_bar.finish() - response = ModelLoadResponse( - model_type=model_type, - module=module, - modules=modules, - status="finished" - ) + response = ModelLoadResponse( + model_type=model_type, + module=module, + modules=modules, + status="finished" + ) - yield response.json(ensure_ascii=False) + yield response.json(ensure_ascii=False) - if model_container.draft_enabled: - model_type = "model" - else: - loading_bar.next() - - response = ModelLoadResponse( - model_type=model_type, - module=module, - modules=modules, - status="processing" - ) + if model_container.draft_enabled: + model_type = "model" + else: + loading_bar.next() - yield response.json(ensure_ascii=False) + response = ModelLoadResponse( + model_type=model_type, + module=module, + modules=modules, + status="processing" + ) + + yield response.json(ensure_ascii=False) + except Exception as e: + yield get_generator_error(e) + load_failed = True + finally: + if load_failed: + model_container.unload() + model_container = None return EventSourceResponse(generator()) @@ -174,14 +186,17 @@ async def generate_completion(request: Request, data: CompletionRequest): if data.stream: async def generator(): - new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params()) - for part in new_generation: - if await request.is_disconnected(): - break + try: + new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params()) + for part in new_generation: + if await request.is_disconnected(): + break - response = create_completion_response(part, model_path.name) + response = create_completion_response(part, model_path.name) - yield response.json(ensure_ascii=False) + yield response.json(ensure_ascii=False) + except Exception as e: + yield get_generator_error(e) return EventSourceResponse(generator()) else: @@ -203,18 +218,21 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest if data.stream: const_id = f"chatcmpl-{uuid4().hex}" async def generator(): - new_generation = model_container.generate_gen(prompt, **data.to_gen_params()) - for part in new_generation: - if await request.is_disconnected(): - break + try: + new_generation = model_container.generate_gen(prompt, **data.to_gen_params()) + for part in new_generation: + if await request.is_disconnected(): + break - response = create_chat_completion_stream_chunk( - const_id, - part, - model_path.name - ) + response = create_chat_completion_stream_chunk( + const_id, + part, + model_path.name + ) - yield response.json(ensure_ascii=False) + yield response.json(ensure_ascii=False) + except Exception as e: + yield get_generator_error(e) return EventSourceResponse(generator()) else: diff --git a/utils.py b/utils.py index 1fa6283..dbfdb6b 100644 --- a/utils.py +++ b/utils.py @@ -1,3 +1,29 @@ +import traceback +from pydantic import BaseModel +from typing import Optional + # Wrapper callback for load progress def load_progress(module, modules): - yield module, modules \ No newline at end of file + yield module, modules + +# Common error types +class TabbyGeneratorErrorMessage(BaseModel): + message: str + trace: Optional[str] = None + +class TabbyGeneratorError(BaseModel): + error: TabbyGeneratorErrorMessage + +def get_generator_error(exception: Exception): + error_message = TabbyGeneratorErrorMessage( + message = str(exception), + trace = traceback.format_exc() + ) + + generator_error = TabbyGeneratorError( + error = error_message + ) + + # Log and send the exception + print(f"\n{generator_error.error.trace}") + return generator_error.json()