diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 5233229..3ded71b 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -734,11 +734,15 @@ class ExllamaV2Container:
         Free all VRAM resources used by this model
         """
 
-        try:
-            await self.load_lock.acquire()
+        # Shutdown immediately unloads and bypasses all locks
+        do_shutdown = kwargs.get("shutdown")
 
-            # Wait for other jobs to finish
-            await self.wait_for_jobs(kwargs.get("skip_wait"))
+        try:
+            if not do_shutdown:
+                await self.load_lock.acquire()
+
+                # Wait for other jobs to finish
+                await self.wait_for_jobs(kwargs.get("skip_wait"))
 
             # Delete references held in the grammar module
             clear_grammar_func_cache()
@@ -778,10 +782,11 @@ class ExllamaV2Container:
             logger.info("Loras unloaded." if loras_only else "Model unloaded.")
 
         finally:
-            self.load_lock.release()
+            if not do_shutdown:
+                self.load_lock.release()
 
-            async with self.load_condition:
-                self.load_condition.notify_all()
+                async with self.load_condition:
+                    self.load_condition.notify_all()
 
     def encode_tokens(self, text: str, **kwargs):
         """Wrapper to encode tokens from a text string"""
diff --git a/common/model.py b/common/model.py
index feedc9f..0bfbab2 100644
--- a/common/model.py
+++ b/common/model.py
@@ -43,11 +43,11 @@ def load_progress(module, modules):
     yield module, modules
 
 
-async def unload_model(skip_wait: bool = False):
+async def unload_model(skip_wait: bool = False, shutdown: bool = False):
     """Unloads a model"""
 
     global container
 
-    await container.unload(skip_wait=skip_wait)
+    await container.unload(skip_wait=skip_wait, shutdown=shutdown)
     container = None
 
diff --git a/common/signals.py b/common/signals.py
index d4e144c..97f595b 100644
--- a/common/signals.py
+++ b/common/signals.py
@@ -29,7 +29,7 @@ async def signal_handler_async(*_):
     """Internal signal handler. Runs all async code to shut down the program."""
 
     if model.container:
-        await model.unload_model(skip_wait=True)
+        await model.unload_model(skip_wait=True, shutdown=True)
 
     if model.embeddings_container:
         await model.unload_embedding_model()