Config: Fix errors when stuff doesn't exist

Add safe fallbacks if any part of the config tree doesn't exist. This
prevents random internal server errors from showing up.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2023-11-16 11:41:03 -05:00
parent 03f45cb0a3
commit 5defb1b0b4

38
main.py
View file

@ -25,13 +25,18 @@ app = FastAPI()
# Globally scoped variables. Undefined until initalized in main
model_container: Optional[ModelContainer] = None
config: Optional[dict] = None
config: dict = {}
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
model_config = config["model"]
models = get_model_list(pathlib.Path(model_config["model_dir"] or "models"))
model_config = config.get("model", {})
if "model_dir" in model_config:
model_path = pathlib.Path(model_config["model_dir"])
else:
model_path = pathlib.Path("models")
models = get_model_list(model_path)
return models.json()
@ -50,8 +55,13 @@ async def load_model(data: ModelLoadRequest):
def generator():
global model_container
model_config = config["model"]
model_path = pathlib.Path(model_config["model_dir"] or "models")
model_config = config.get("model", {})
if "model_dir" in model_config:
model_path = pathlib.Path(model_config["model_dir"])
else:
model_path = pathlib.Path("models")
model_path = model_path / data.name
model_container = ModelContainer(model_path, False, **data.dict())
@ -165,13 +175,17 @@ if __name__ == "__main__":
load_auth_keys()
# Load from YAML config. Possibly add a config -> kwargs conversion function
with open('config.yml', 'r') as config_file:
config = yaml.safe_load(config_file)
try:
with open('config.yml', 'r') as config_file:
config = yaml.safe_load(config_file) or {}
except:
config = {}
# If an initial model name is specified, create a container and load the model
model_config = config["model"]
if model_config["model_name"]:
model_path = pathlib.Path(model_config["model_dir"] or "models")
model_config = config.get("model", {})
if "model_name" in model_config:
model_path = pathlib.Path(model_config.get("model", "models"))
model_path = model_path / model_config["model_name"]
model_container = ModelContainer(model_path, False, **model_config)
@ -185,5 +199,5 @@ if __name__ == "__main__":
else:
loading_bar.next()
network_config = config["network"]
uvicorn.run(app, host=network_config["host"] or "127.0.0.1", port=network_config["port"] or 8012, log_level="debug")
network_config = config.get("network", {})
uvicorn.run(app, host=network_config.get("host", "127.0.0.1"), port=network_config.get("port", 5000), log_level="debug")