API: Add generator error handling

If a generator raises an error, there's no proper handling to send an error
packet to the client and close the connection.

This is especially important for unloading a model when a load fails at any
stage, so a user's VRAM can be reclaimed. Previously, raising an exception
caused the model_container object to stay locked and never get freed by the GC.

It therefore made sense to propagate SSE errors across all generator functions
rather than relying on abort signals.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-11-30 00:37:48 -05:00
parent 2bc3da0155
commit 56f9b1d1a8
2 changed files with 88 additions and 44 deletions
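
For context before the diffs: "propagating SSE errors" means the stream no longer silently dies when a generator raises. The client instead receives one final data event whose JSON carries an "error" object (the TabbyGeneratorError model added in utils.py below). A minimal consumer-side sketch, assuming a local server, a /v1/completions route, and the requests library, none of which are part of this commit:

import json
import requests  # illustrative; any HTTP client that can stream the response works

# Hypothetical payload and URL; the real route and request fields live in main.py
payload = {"prompt": "Hello", "stream": True}
with requests.post("http://localhost:8000/v1/completions", json=payload, stream=True) as resp:
    for raw in resp.iter_lines():
        if not raw or not raw.startswith(b"data:"):
            continue
        chunk = json.loads(raw[len(b"data:"):].strip())
        if "error" in chunk:
            # The generator hit an exception and sent a TabbyGeneratorError packet
            print("generation failed:", chunk["error"]["message"])
            break
        # ...otherwise handle a normal completion chunk...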

main.py

@@ -1,6 +1,6 @@
 import uvicorn
 import yaml
-import pathlib, os
+import pathlib
 from auth import check_admin_key, check_api_key, load_auth_keys
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
@@ -24,7 +24,7 @@ from OAI.utils import (
     create_chat_completion_stream_chunk
 )
 from typing import Optional
-from utils import load_progress
+from utils import get_generator_error, load_progress
 from uuid import uuid4
 
 app = FastAPI()
@@ -102,38 +102,50 @@ async def load_model(data: ModelLoadRequest):
     model_container = ModelContainer(model_path.resolve(), False, **load_data)
 
     def generator():
         global model_container
 
+        load_failed = False
         model_type = "draft" if model_container.draft_enabled else "model"
         load_status = model_container.load_gen(load_progress)
-        for (module, modules) in load_status:
-            if module == 0:
-                loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
-            elif module == modules:
-                loading_bar.next()
-                loading_bar.finish()
+
+        # TODO: Maybe create an erroring generator as a common utility function
+        try:
+            for (module, modules) in load_status:
+                if module == 0:
+                    loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
+                elif module == modules:
+                    loading_bar.next()
+                    loading_bar.finish()
 
-                response = ModelLoadResponse(
-                    model_type=model_type,
-                    module=module,
-                    modules=modules,
-                    status="finished"
-                )
+                    response = ModelLoadResponse(
+                        model_type=model_type,
+                        module=module,
+                        modules=modules,
+                        status="finished"
+                    )
 
-                yield response.json(ensure_ascii=False)
+                    yield response.json(ensure_ascii=False)
 
-                if model_container.draft_enabled:
-                    model_type = "model"
-            else:
-                loading_bar.next()
+                    if model_container.draft_enabled:
+                        model_type = "model"
+                else:
+                    loading_bar.next()
 
-                response = ModelLoadResponse(
-                    model_type=model_type,
-                    module=module,
-                    modules=modules,
-                    status="processing"
-                )
+                    response = ModelLoadResponse(
+                        model_type=model_type,
+                        module=module,
+                        modules=modules,
+                        status="processing"
+                    )
 
-                yield response.json(ensure_ascii=False)
+                    yield response.json(ensure_ascii=False)
+        except Exception as e:
+            yield get_generator_error(e)
+            load_failed = True
+        finally:
+            if load_failed:
+                model_container.unload()
+                model_container = None
 
     return EventSourceResponse(generator())
@@ -174,14 +186,17 @@ async def generate_completion(request: Request, data: CompletionRequest):
 
     if data.stream:
         async def generator():
-            new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
-            for part in new_generation:
-                if await request.is_disconnected():
-                    break
+            try:
+                new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
+                for part in new_generation:
+                    if await request.is_disconnected():
+                        break
 
-                response = create_completion_response(part, model_path.name)
+                    response = create_completion_response(part, model_path.name)
 
-                yield response.json(ensure_ascii=False)
+                    yield response.json(ensure_ascii=False)
+            except Exception as e:
+                yield get_generator_error(e)
 
         return EventSourceResponse(generator())
     else:
@@ -203,18 +218,21 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
     if data.stream:
         const_id = f"chatcmpl-{uuid4().hex}"
         async def generator():
-            new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
-            for part in new_generation:
-                if await request.is_disconnected():
-                    break
+            try:
+                new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
+                for part in new_generation:
+                    if await request.is_disconnected():
+                        break
 
-                response = create_chat_completion_stream_chunk(
-                    const_id,
-                    part,
-                    model_path.name
-                )
+                    response = create_chat_completion_stream_chunk(
+                        const_id,
+                        part,
+                        model_path.name
+                    )
 
-                yield response.json(ensure_ascii=False)
+                    yield response.json(ensure_ascii=False)
+            except Exception as e:
+                yield get_generator_error(e)
 
         return EventSourceResponse(generator())
     else:
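
The TODO in load_model suggests factoring this try/except/yield pattern into a common utility. A rough sketch of what that helper might look like for the synchronous case; the name generate_with_error_handling and the cleanup hook are assumptions, not code from this commit:

from utils import get_generator_error  # helper added in this commit

def generate_with_error_handling(inner, cleanup=None):
    """Wrap a synchronous generator so exceptions become SSE error packets."""
    failed = False
    try:
        for item in inner:
            yield item
    except Exception as exception:
        yield get_generator_error(exception)
        failed = True
    finally:
        # Optional cleanup hook, e.g. unloading a half-loaded model
        if failed and cleanup is not None:
            cleanup()

The completion endpoints would need an async counterpart of this wrapper, since their generators also await request.is_disconnected() inside the loop.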

utils.py

@@ -1,3 +1,29 @@
+import traceback
+from pydantic import BaseModel
+from typing import Optional
+
 # Wrapper callback for load progress
 def load_progress(module, modules):
     yield module, modules
+
+# Common error types
+class TabbyGeneratorErrorMessage(BaseModel):
+    message: str
+    trace: Optional[str] = None
+
+class TabbyGeneratorError(BaseModel):
+    error: TabbyGeneratorErrorMessage
+
+def get_generator_error(exception: Exception):
+    error_message = TabbyGeneratorErrorMessage(
+        message = str(exception),
+        trace = traceback.format_exc()
+    )
+
+    generator_error = TabbyGeneratorError(
+        error = error_message
+    )
+
+    # Log and send the exception
+    print(f"\n{generator_error.error.trace}")
+    return generator_error.json()
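
For reference, a small usage sketch (not part of the diff) of the string get_generator_error hands back to each generator; the sample exception is made up:

from utils import get_generator_error  # helper added in this commit

try:
    raise RuntimeError("Model failed to load")
except RuntimeError as exc:
    packet = get_generator_error(exc)
    # Prints the traceback to the console and returns a JSON string shaped like:
    # {"error": {"message": "Model failed to load",
    #            "trace": "Traceback (most recent call last): ..."}}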