API: Fix load exception handling

Models do not fully unload if an exception is caught during load. Therefore,
leave it to the client to unload on cancel.

Also add handlers for the case where an SSE stream is cancelled. These packets
can't be sent back to the client because the client has severed the
connection, so print them to the terminal instead.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri <bdashore3@proton.me>
Date:   2023-12-05 00:23:15 -05:00
Parent: 7c92968558
Commit: 8ba3bfa6b3
2 changed files with 18 additions and 15 deletions
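
For context on the approach, here is a minimal, self-contained sketch (not part of this commit; the endpoint path and payload are made up) of the pattern the diff below applies: when a client drops an SSE connection, the streaming task is cancelled and asyncio.CancelledError surfaces inside the async generator, so the handler prints to the terminal instead of yielding a packet the client can no longer receive.

from asyncio import CancelledError, sleep
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.get("/v1/demo/stream")  # hypothetical endpoint, illustration only
async def demo_stream():
    async def generator():
        try:
            for chunk in range(10):
                yield f"data: chunk {chunk}\n\n"
                await sleep(1)
        except CancelledError:
            # The client has severed the connection, so this can't be sent back;
            # print to the terminal instead, as the commit does.
            print("Error: Demo stream cancelled by user.")
            raise  # optional: let the cancellation keep propagating

    return StreamingResponse(generator(), media_type="text/event-stream")

(The commit itself swallows the CancelledError rather than re-raising; either way, nothing more is yielded to the client.)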

main.py (29 changed lines)

@@ -1,6 +1,7 @@
 import uvicorn
 import yaml
 import pathlib
+from asyncio import CancelledError
 from auth import check_admin_key, check_api_key, load_auth_keys
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
@@ -77,7 +78,7 @@ async def get_current_model():
 # Load model endpoint
 @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
-async def load_model(data: ModelLoadRequest):
+async def load_model(request: Request, data: ModelLoadRequest):
     global model_container
     if model_container and model_container.model:
@@ -104,18 +105,19 @@ async def load_model(data: ModelLoadRequest):
     model_container = ModelContainer(model_path.resolve(), False, **load_data)
-    def generator():
+    async def generator():
         global model_container
-        load_failed = False
         model_type = "draft" if model_container.draft_enabled else "model"
         load_status = model_container.load_gen(load_progress)
         # TODO: Maybe create an erroring generator as a common utility function
         try:
             for (module, modules) in load_status:
+                if await request.is_disconnected():
+                    break
                 if module == 0:
-                    loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
+                    loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
                 elif module == modules:
                     loading_bar.next()
                     loading_bar.finish()
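
The request.is_disconnected() check added in the hunk above is worth seeing in isolation. A rough sketch, assuming only that a FastAPI Request is injected into the endpoint (the names and endpoint path are placeholders, not tabbyAPI code): the long-running loop polls for a disconnect on every iteration and breaks out early, so a client that cancels the request stops driving the load generator.

from asyncio import sleep
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.post("/v1/demo/load")  # hypothetical endpoint, illustration only
async def demo_load(request: Request):
    async def generator():
        for module in range(40):  # stand-in for iterating load_gen(load_progress)
            if await request.is_disconnected():
                # Client is gone: stop feeding the loader. The model is NOT
                # unloaded here; per the commit, that is left to the client.
                break
            yield f"data: loaded module {module}\n\n"
            await sleep(0.1)

    return StreamingResponse(generator(), media_type="text/event-stream")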
@@ -142,13 +144,10 @@ async def load_model(data: ModelLoadRequest):
             )
             yield get_sse_packet(response.json(ensure_ascii=False))
+        except CancelledError as e:
+            print("\nError: Model load cancelled by user. Please make sure to run unload to free up resources.")
         except Exception as e:
-            yield get_generator_error(e)
-            load_failed = True
-        finally:
-            if load_failed:
-                model_container.unload()
-                model_container = None
+            yield get_generator_error(str(e))
     return StreamingResponse(generator(), media_type = "text/event-stream")
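
Because the finally/unload block above is removed, a cancelled or failed load now leaves the partially loaded model in place, and the client is expected to free it, as the new print message says. A hedged client-side sketch, assuming an unload endpoint at /v1/model/unload, an x-admin-key header, and a "name" field in the load body (none of these appear in this diff):

import requests

BASE_URL = "http://localhost:5000"              # assumed tabbyAPI address
HEADERS = {"x-admin-key": "example-admin-key"}  # assumed admin auth header

try:
    # Stream the load progress; Ctrl+C here severs the SSE connection mid-load.
    with requests.post(f"{BASE_URL}/v1/model/load",
                       json={"name": "my-model"},  # assumed request body
                       headers=HEADERS, stream=True) as resp:
        for line in resp.iter_lines():
            print(line.decode())
except KeyboardInterrupt:
    # The server no longer unloads on a cancelled load, so clean up explicitly.
    requests.post(f"{BASE_URL}/v1/model/unload", headers=HEADERS)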
@@ -201,8 +200,10 @@ async def generate_completion(request: Request, data: CompletionRequest):
                 model_path.name)
             yield get_sse_packet(response.json(ensure_ascii=False))
+        except CancelledError:
+            print("Error: Completion request cancelled by user.")
         except Exception as e:
-            yield get_generator_error(e)
+            yield get_generator_error(str(e))
     return StreamingResponse(
         generate_with_semaphore(generator),
@@ -251,8 +252,10 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
             )
             yield get_sse_packet(finish_response.json(ensure_ascii=False))
+        except CancelledError:
+            print("Error: Chat completion cancelled by user.")
         except Exception as e:
-            yield get_generator_error(e)
+            yield get_generator_error(str(e))
     return StreamingResponse(
         generate_with_semaphore(generator),

(second changed file, name not captured, 4 changed lines)

@@ -14,9 +14,9 @@ class TabbyGeneratorErrorMessage(BaseModel):
 class TabbyGeneratorError(BaseModel):
     error: TabbyGeneratorErrorMessage
-def get_generator_error(exception: Exception):
+def get_generator_error(message: str):
     error_message = TabbyGeneratorErrorMessage(
-        message = str(exception),
+        message = message,
         trace = traceback.format_exc()
     )
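
To show how the new signature is consumed, a small self-contained sketch mirroring the hunk above (the SSE wrapping at the end is an assumption about the surrounding helper, which this diff does not show): callers now stringify the exception themselves, and the helper only attaches the traceback.

import traceback
from pydantic import BaseModel

class TabbyGeneratorErrorMessage(BaseModel):
    message: str
    trace: str

class TabbyGeneratorError(BaseModel):
    error: TabbyGeneratorErrorMessage

def get_generator_error(message: str) -> str:
    error_message = TabbyGeneratorErrorMessage(
        message=message,
        trace=traceback.format_exc(),
    )
    # Assumed wrapping: emit the error model as a single SSE data packet.
    return f"data: {TabbyGeneratorError(error=error_message).json()}\n\n"

try:
    raise RuntimeError("model failed to load")
except Exception as e:
    # New call style: pass a plain string rather than the exception object.
    print(get_generator_error(str(e)))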