From 5defb1b0b4e974fa317382974bd362d88ff35fe1 Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 16 Nov 2023 11:41:03 -0500 Subject: [PATCH] Config: Fix errors when stuff doesn't exist Add safe fallbacks if any part of the config tree doesn't exist. This prevents random internal server errors from showing up. Signed-off-by: kingbri --- main.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/main.py b/main.py index 81022d8..45b161e 100644 --- a/main.py +++ b/main.py @@ -25,13 +25,18 @@ app = FastAPI() # Globally scoped variables. Undefined until initalized in main model_container: Optional[ModelContainer] = None -config: Optional[dict] = None +config: dict = {} @app.get("/v1/models", dependencies=[Depends(check_api_key)]) @app.get("/v1/model/list", dependencies=[Depends(check_api_key)]) async def list_models(): - model_config = config["model"] - models = get_model_list(pathlib.Path(model_config["model_dir"] or "models")) + model_config = config.get("model", {}) + if "model_dir" in model_config: + model_path = pathlib.Path(model_config["model_dir"]) + else: + model_path = pathlib.Path("models") + + models = get_model_list(model_path) return models.json() @@ -50,8 +55,13 @@ async def load_model(data: ModelLoadRequest): def generator(): global model_container - model_config = config["model"] - model_path = pathlib.Path(model_config["model_dir"] or "models") + + model_config = config.get("model", {}) + if "model_dir" in model_config: + model_path = pathlib.Path(model_config["model_dir"]) + else: + model_path = pathlib.Path("models") + model_path = model_path / data.name model_container = ModelContainer(model_path, False, **data.dict()) @@ -165,13 +175,17 @@ if __name__ == "__main__": load_auth_keys() # Load from YAML config. Possibly add a config -> kwargs conversion function - with open('config.yml', 'r') as config_file: - config = yaml.safe_load(config_file) + try: + with open('config.yml', 'r') as config_file: + config = yaml.safe_load(config_file) or {} + except: + config = {} # If an initial model name is specified, create a container and load the model - model_config = config["model"] - if model_config["model_name"]: - model_path = pathlib.Path(model_config["model_dir"] or "models") + + model_config = config.get("model", {}) + if "model_name" in model_config: + model_path = pathlib.Path(model_config.get("model", "models")) model_path = model_path / model_config["model_name"] model_container = ModelContainer(model_path, False, **model_config) @@ -185,5 +199,5 @@ if __name__ == "__main__": else: loading_bar.next() - network_config = config["network"] - uvicorn.run(app, host=network_config["host"] or "127.0.0.1", port=network_config["port"] or 8012, log_level="debug") + network_config = config.get("network", {}) + uvicorn.run(app, host=network_config.get("host", "127.0.0.1"), port=network_config.get("port", 5000), log_level="debug")