From bee26a2f2cf8182c3d1bba3f39e1ef76297ddfb8 Mon Sep 17 00:00:00 2001 From: kingbri Date: Wed, 21 Feb 2024 23:00:11 -0500 Subject: [PATCH] API: Auto-unload on a load request Automatically unload the existing model when calling /load. This was requested many times, and does make more sense in the long run. Signed-off-by: kingbri --- backends/exllamav2/model.py | 2 ++ main.py | 16 ++++++++++++---- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index bf8b678..1629392 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -464,6 +464,8 @@ class ExllamaV2Container: gc.collect() torch.cuda.empty_cache() + logger.info("Model unloaded.") + def encode_tokens(self, text: str, **kwargs): """Wrapper to encode tokens from a text string""" diff --git a/main.py b/main.py index 85907ef..4b4e512 100644 --- a/main.py +++ b/main.py @@ -172,11 +172,19 @@ async def load_model(request: Request, data: ModelLoadRequest): """Loads a model into the model container.""" global MODEL_CONTAINER - if MODEL_CONTAINER and MODEL_CONTAINER.model: - raise HTTPException(400, "A model is already loaded! Please unload it first.") - if not data.name: - raise HTTPException(400, "model_name not found.") + raise HTTPException(400, "A model name was not provided.") + + # Unload the existing model + if MODEL_CONTAINER and MODEL_CONTAINER.model: + loaded_model_name = MODEL_CONTAINER.get_model_path().name + + if loaded_model_name == data.name: + raise HTTPException( + 400, f"Model \"{loaded_model_name}\" is already loaded! Aborting." + ) + else: + MODEL_CONTAINER.unload() model_path = pathlib.Path(unwrap(get_model_config().get("model_dir"), "models")) model_path = model_path / data.name