API: Add generator error handling
If the generator errors, there's no proper handling to send an error packet and close the connection. This is especially important when a model load fails at any stage, so the model can be unloaded and the user's VRAM reclaimed. Raising an exception caused the model_container object to lock and never get freed by the GC. It therefore made sense to propagate SSE errors across all generator functions rather than relying on abort signals.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 2bc3da0155
commit 56f9b1d1a8

2 changed files with 88 additions and 44 deletions
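For context, the error packet introduced by this commit is just another SSE data event whose JSON body carries an "error" key (see the TabbyGeneratorError model in utils.py below). A minimal sketch of how a client might detect it, assuming an httpx dependency and a hypothetical endpoint path and payload shape (none of which are defined by this commit):

    import json
    import httpx

    # Stream a completion and surface the server-side error packet locally.
    # URL, port, and payload fields are assumptions for illustration only.
    def stream_completion(prompt: str):
        payload = {"prompt": prompt, "stream": True}
        with httpx.stream("POST", "http://localhost:8000/v1/completions", json=payload) as response:
            for line in response.iter_lines():
                if not line.startswith("data: "):
                    continue
                packet = json.loads(line[len("data: "):])
                if "error" in packet:
                    # Matches TabbyGeneratorError: {"error": {"message": ..., "trace": ...}}
                    raise RuntimeError(packet["error"]["message"])
                yield packet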
main.py (104 changed lines)
--- a/main.py
+++ b/main.py
@@ -1,6 +1,6 @@
 import uvicorn
 import yaml
-import pathlib, os
+import pathlib
 from auth import check_admin_key, check_api_key, load_auth_keys
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
@@ -24,7 +24,7 @@ from OAI.utils import (
     create_chat_completion_stream_chunk
 )
 from typing import Optional
-from utils import load_progress
+from utils import get_generator_error, load_progress
 from uuid import uuid4

 app = FastAPI()
@@ -102,38 +102,50 @@ async def load_model(data: ModelLoadRequest):
     model_container = ModelContainer(model_path.resolve(), False, **load_data)

     def generator():
         global model_container

+        load_failed = False
         model_type = "draft" if model_container.draft_enabled else "model"
         load_status = model_container.load_gen(load_progress)

-        for (module, modules) in load_status:
-            if module == 0:
-                loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
-            elif module == modules:
-                loading_bar.next()
-                loading_bar.finish()
+        # TODO: Maybe create an erroring generator as a common utility function
+        try:
+            for (module, modules) in load_status:
+                if module == 0:
+                    loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
+                elif module == modules:
+                    loading_bar.next()
+                    loading_bar.finish()

-                response = ModelLoadResponse(
-                    model_type=model_type,
-                    module=module,
-                    modules=modules,
-                    status="finished"
-                )
+                    response = ModelLoadResponse(
+                        model_type=model_type,
+                        module=module,
+                        modules=modules,
+                        status="finished"
+                    )

-                yield response.json(ensure_ascii=False)
+                    yield response.json(ensure_ascii=False)

-                if model_container.draft_enabled:
-                    model_type = "model"
-            else:
-                loading_bar.next()
+                    if model_container.draft_enabled:
+                        model_type = "model"
+                else:
+                    loading_bar.next()

-                response = ModelLoadResponse(
-                    model_type=model_type,
-                    module=module,
-                    modules=modules,
-                    status="processing"
-                )
+                    response = ModelLoadResponse(
+                        model_type=model_type,
+                        module=module,
+                        modules=modules,
+                        status="processing"
+                    )

-                yield response.json(ensure_ascii=False)
+                    yield response.json(ensure_ascii=False)
+        except Exception as e:
+            yield get_generator_error(e)
+            load_failed = True
+        finally:
+            if load_failed:
+                model_container.unload()
+                model_container = None

     return EventSourceResponse(generator())
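Design note: the cleanup yields an error packet and sets a flag rather than re-raising, since raising inside the generator is what locked model_container and kept the GC from freeing it. A standalone sketch of the same pattern, with illustrative names that are not part of the codebase:

    import json

    # Convert an exception into one final SSE payload, then release the
    # resource in finally; mirrors the load_model generator above.
    def loading_generator(steps):
        resource = {"modules": 0}  # stand-in for the model container
        failed = False
        try:
            for index, _ in enumerate(steps):
                resource["modules"] += 1
                yield json.dumps({"status": "processing", "module": index})
        except Exception as exc:
            # The connection stays usable long enough to send one error packet
            yield json.dumps({"error": {"message": str(exc)}})
            failed = True
        finally:
            if failed:
                resource = None  # stands in for model_container.unload()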
@@ -174,14 +186,17 @@ async def generate_completion(request: Request, data: CompletionRequest):

     if data.stream:
         async def generator():
-            new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
-            for part in new_generation:
-                if await request.is_disconnected():
-                    break
+            try:
+                new_generation = model_container.generate_gen(data.prompt, **data.to_gen_params())
+                for part in new_generation:
+                    if await request.is_disconnected():
+                        break

-                response = create_completion_response(part, model_path.name)
+                    response = create_completion_response(part, model_path.name)

-                yield response.json(ensure_ascii=False)
+                    yield response.json(ensure_ascii=False)
+            except Exception as e:
+                yield get_generator_error(e)

         return EventSourceResponse(generator())
     else:
@@ -203,18 +218,21 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
     if data.stream:
         const_id = f"chatcmpl-{uuid4().hex}"
         async def generator():
-            new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
-            for part in new_generation:
-                if await request.is_disconnected():
-                    break
+            try:
+                new_generation = model_container.generate_gen(prompt, **data.to_gen_params())
+                for part in new_generation:
+                    if await request.is_disconnected():
+                        break

-                response = create_chat_completion_stream_chunk(
-                    const_id,
-                    part,
-                    model_path.name
-                )
+                    response = create_chat_completion_stream_chunk(
+                        const_id,
+                        part,
+                        model_path.name
+                    )

-                yield response.json(ensure_ascii=False)
+                    yield response.json(ensure_ascii=False)
+            except Exception as e:
+                yield get_generator_error(e)

         return EventSourceResponse(generator())
     else:
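The TODO in load_model hints at factoring this repeated try/except shape out. A hypothetical helper, not part of this commit, could wrap any synchronous generator; the async generators above would need an async variant, and the name is illustrative:

    from utils import get_generator_error

    # Hypothetical wrapper: exceptions from the inner generator become
    # one final error packet instead of propagating.
    def handle_generator_errors(inner):
        try:
            yield from inner
        except Exception as exc:
            yield get_generator_error(exc)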
utils.py (28 changed lines)
--- a/utils.py
+++ b/utils.py
@@ -1,3 +1,29 @@
+import traceback
+from pydantic import BaseModel
+from typing import Optional
+
 # Wrapper callback for load progress
 def load_progress(module, modules):
     yield module, modules
+
+# Common error types
+class TabbyGeneratorErrorMessage(BaseModel):
+    message: str
+    trace: Optional[str] = None
+
+class TabbyGeneratorError(BaseModel):
+    error: TabbyGeneratorErrorMessage
+
+def get_generator_error(exception: Exception):
+    error_message = TabbyGeneratorErrorMessage(
+        message = str(exception),
+        trace = traceback.format_exc()
+    )
+
+    generator_error = TabbyGeneratorError(
+        error = error_message
+    )
+
+    # Log and send the exception
+    print(f"\n{generator_error.error.trace}")
+    return generator_error.json()
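What a caller gets back from get_generator_error, shown with an illustrative harness (the try/except below is not part of the commit). Note it also prints the traceback to the server console as a side effect:

    from utils import get_generator_error

    try:
        raise ValueError("model file not found")
    except ValueError as exc:
        packet = get_generator_error(exc)
        # packet is a JSON string shaped like:
        # {"error": {"message": "model file not found",
        #            "trace": "Traceback (most recent call last): ..."}}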