API: Fix load exception handling
Models do not fully unload if an exception is caught during load, so leave it to the client to unload on cancel. Also add handlers for the case where an SSE stream is cancelled: the resulting error packets cannot be sent back to the client, since the client has already severed the connection, so print them to the terminal instead.

Signed-off-by: kingbri <bdashore3@proton.me>
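Since the server no longer frees a half-loaded model when the client cancels, the caller is expected to issue the unload itself. Below is a minimal client-side sketch of that contract; it is not part of this commit, and the port, the /v1/model/unload path, the "name" field, and the x-admin-key header are assumptions about the surrounding API:

import httpx

ADMIN_HEADERS = {"x-admin-key": "example-admin-key"}  # assumed auth header
BASE_URL = "http://localhost:5000"                    # assumed default host/port

def load_model_with_cleanup(name: str):
    try:
        with httpx.stream(
            "POST",
            f"{BASE_URL}/v1/model/load",
            json={"name": name},
            headers=ADMIN_HEADERS,
            timeout=None,
        ) as response:
            # Each SSE packet reports per-module load progress.
            for line in response.iter_lines():
                if line:
                    print(line)
    except KeyboardInterrupt:
        # The server no longer unloads on cancel, so free the model explicitly.
        httpx.post(f"{BASE_URL}/v1/model/unload", headers=ADMIN_HEADERS)
        raise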
parent 7c92968558
commit 8ba3bfa6b3
2 changed files with 18 additions and 15 deletions
main.py (29 changes)
@@ -1,6 +1,7 @@
 import uvicorn
 import yaml
 import pathlib
+from asyncio import CancelledError
 from auth import check_admin_key, check_api_key, load_auth_keys
 from fastapi import FastAPI, Request, HTTPException, Depends
 from fastapi.middleware.cors import CORSMiddleware
@@ -77,7 +78,7 @@ async def get_current_model():
 
 # Load model endpoint
 @app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
-async def load_model(data: ModelLoadRequest):
+async def load_model(request: Request, data: ModelLoadRequest):
     global model_container
 
     if model_container and model_container.model:
@@ -104,18 +105,19 @@ async def load_model(data: ModelLoadRequest):
     model_container = ModelContainer(model_path.resolve(), False, **load_data)
 
-    def generator():
+    async def generator():
         global model_container
 
-        load_failed = False
         model_type = "draft" if model_container.draft_enabled else "model"
         load_status = model_container.load_gen(load_progress)
 
+        # TODO: Maybe create an erroring generator as a common utility function
         try:
             for (module, modules) in load_status:
+                if await request.is_disconnected():
+                    break
+
                 if module == 0:
                     loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)
                 elif module == modules:
                     loading_bar.next()
                     loading_bar.finish()
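The disconnect check above is why generator became an async def: Request.is_disconnected() is a coroutine, and StreamingResponse accepts async generators directly. A self-contained sketch of the same pattern, with an illustrative endpoint and payload that are not part of this commit:

from asyncio import sleep

from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse

app = FastAPI()

@app.post("/v1/demo/stream")
async def demo_stream(request: Request):
    async def generator():
        for step in range(100):
            # Stop doing work as soon as the client severs the connection.
            if await request.is_disconnected():
                break
            yield f"data: step {step}\n\n"
            await sleep(0.1)

    return StreamingResponse(generator(), media_type="text/event-stream")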
@@ -142,13 +144,10 @@ async def load_model(data: ModelLoadRequest):
             )
 
             yield get_sse_packet(response.json(ensure_ascii=False))
+        except CancelledError as e:
+            print("\nError: Model load cancelled by user. Please make sure to run unload to free up resources.")
         except Exception as e:
-            yield get_generator_error(e)
-            load_failed = True
-        finally:
-            if load_failed:
-                model_container.unload()
-                model_container = None
+            yield get_generator_error(str(e))
 
     return StreamingResponse(generator(), media_type = "text/event-stream")
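The CancelledError branch prints rather than yields because cancellation only occurs once the consumer is gone; an SSE error packet would have no reader. A small standalone sketch of that behavior (illustrative only, not code from the repository):

import asyncio

async def fake_load():
    try:
        for module in range(8):
            await asyncio.sleep(0.1)  # simulate loading one module
    except asyncio.CancelledError:
        # No one is listening anymore, so log server-side instead of yielding.
        print("\nError: load cancelled. Remember to unload to free resources.")
        raise  # let cancellation propagate, the usual asyncio idiom

async def main():
    task = asyncio.create_task(fake_load())
    await asyncio.sleep(0.25)
    task.cancel()  # mirrors the client dropping the SSE connection
    try:
        await task
    except asyncio.CancelledError:
        pass

asyncio.run(main())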
@@ -201,8 +200,10 @@ async def generate_completion(request: Request, data: CompletionRequest):
                 model_path.name)
 
             yield get_sse_packet(response.json(ensure_ascii=False))
+        except CancelledError:
+            print("Error: Completion request cancelled by user.")
         except Exception as e:
-            yield get_generator_error(e)
+            yield get_generator_error(str(e))
 
     return StreamingResponse(
         generate_with_semaphore(generator),
@@ -251,8 +252,10 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
             )
 
             yield get_sse_packet(finish_response.json(ensure_ascii=False))
+        except CancelledError:
+            print("Error: Chat completion cancelled by user.")
         except Exception as e:
-            yield get_generator_error(e)
+            yield get_generator_error(str(e))
 
     return StreamingResponse(
         generate_with_semaphore(generator),
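Both completion endpoints now share the same two-branch handler: log cancellations server-side, and stream every other failure to the client via get_generator_error. The get_sse_packet helper itself is not shown in this diff; a plausible shape, assuming standard Server-Sent Events framing:

def get_sse_packet(json_data: str) -> str:
    # "data: <payload>" terminated by a blank line is the SSE wire format.
    return f"data: {json_data}\n\n"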
utils.py (4 changes)
@@ -14,9 +14,9 @@ class TabbyGeneratorErrorMessage(BaseModel):
 class TabbyGeneratorError(BaseModel):
     error: TabbyGeneratorErrorMessage
 
-def get_generator_error(exception: Exception):
+def get_generator_error(message: str):
     error_message = TabbyGeneratorErrorMessage(
-        message = str(exception),
+        message = message,
         trace = traceback.format_exc()
     )
 
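A usage sketch for the revised helper, with a hypothetical call site: taking a plain string decouples the helper from the exception object, while traceback.format_exc() still picks up the trace as long as the helper runs inside an except block, which is how main.py calls it.

try:
    raise RuntimeError("model file is corrupt")  # stand-in failure
except Exception as e:
    # Called while the exception is being handled, so format_exc()
    # inside get_generator_error captures the full trace.
    error_packet = get_generator_error(str(e))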