diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index 12c95a2..0b7c1a6 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -55,7 +55,14 @@ async def completion_request( """ if data.model: - await load_inline_model(data.model, request) + inline_load_task = asyncio.create_task(load_inline_model(data.model, request)) + + await run_with_request_disconnect( + request, + inline_load_task, + disconnect_message=f"Model switch for generation {request.state.id} " + + "cancelled by user.", + ) else: await check_model_container() diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index 2f51175..df4bf19 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -112,8 +112,12 @@ async def _stream_collector( async def load_inline_model(model_name: str, request: Request): """Load a model from the data.model parameter""" - # Return if the model container already exists - if model.container and model.container.model_dir.name == model_name: + # Return if the model container already exists and the model is fully loaded + if ( + model.container + and model.container.model_dir.name == model_name + and model.container.model_loaded + ): return # Inline model loading isn't enabled or the user isn't an admin