API: Fix error points and exceptions

On /v1/model/load, some internal server errors weren't being sent,
so migrate directory checking out and also add a check to make sure
the proposed model path exists.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2023-11-25 00:27:02 -05:00
parent d47c39da54
commit d929e0c826

14
main.py
View file

@ -1,6 +1,6 @@
import uvicorn
import yaml
import pathlib
import pathlib, os
from auth import check_admin_key, check_api_key, load_auth_keys
from fastapi import FastAPI, Request, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
@ -75,6 +75,8 @@ async def get_current_model():
# Load model endpoint
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(data: ModelLoadRequest):
global model_container
if model_container and model_container.model:
raise HTTPException(400, "A model is already loaded! Please unload it first.")
@ -94,13 +96,15 @@ async def load_model(data: ModelLoadRequest):
load_data["draft_model_dir"] = draft_config.get("draft_model_dir") or "models"
if not model_path.exists():
raise HTTPException(400, "model_path does not exist. Check model_name?")
model_container = ModelContainer(model_path.resolve(), False, **load_data)
def generator():
global model_container
model_container = ModelContainer(model_path.resolve(), False, **load_data)
model_type = "draft" if model_container.draft_enabled else "model"
load_status = model_container.load_gen(load_progress)
for (module, modules) in load_status:
if module == 0:
loading_bar: IncrementalBar = IncrementalBar("Modules", max = modules)