From 2a33ebbf2909d6d51fee0782b01acaee37bf2ad3 Mon Sep 17 00:00:00 2001 From: kingbri Date: Sat, 3 Aug 2024 16:05:34 -0400 Subject: [PATCH] Model: Bypass lock checks when shutting down Previously, when a SIGINT was emitted and a model load is running, the API didn't shut down until the load finished due to waitng for the lock. However, when shutting down, the lock doesn't matter since the process is being killed anyway. Signed-off-by: kingbri --- backends/exllamav2/model.py | 19 ++++++++++++------- common/model.py | 4 ++-- common/signals.py | 2 +- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 5233229..3ded71b 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -734,11 +734,15 @@ class ExllamaV2Container: Free all VRAM resources used by this model """ - try: - await self.load_lock.acquire() + # Shutdown immediately unloads and bypasses all locks + do_shutdown = kwargs.get("shutdown") - # Wait for other jobs to finish - await self.wait_for_jobs(kwargs.get("skip_wait")) + try: + if not do_shutdown: + await self.load_lock.acquire() + + # Wait for other jobs to finish + await self.wait_for_jobs(kwargs.get("skip_wait")) # Delete references held in the grammar module clear_grammar_func_cache() @@ -778,10 +782,11 @@ class ExllamaV2Container: logger.info("Loras unloaded." if loras_only else "Model unloaded.") finally: - self.load_lock.release() + if not do_shutdown: + self.load_lock.release() - async with self.load_condition: - self.load_condition.notify_all() + async with self.load_condition: + self.load_condition.notify_all() def encode_tokens(self, text: str, **kwargs): """Wrapper to encode tokens from a text string""" diff --git a/common/model.py b/common/model.py index feedc9f..0bfbab2 100644 --- a/common/model.py +++ b/common/model.py @@ -43,11 +43,11 @@ def load_progress(module, modules): yield module, modules -async def unload_model(skip_wait: bool = False): +async def unload_model(skip_wait: bool = False, shutdown: bool = False): """Unloads a model""" global container - await container.unload(skip_wait=skip_wait) + await container.unload(skip_wait=skip_wait, shutdown=shutdown) container = None diff --git a/common/signals.py b/common/signals.py index d4e144c..97f595b 100644 --- a/common/signals.py +++ b/common/signals.py @@ -29,7 +29,7 @@ async def signal_handler_async(*_): """Internal signal handler. Runs all async code to shut down the program.""" if model.container: - await model.unload_model(skip_wait=True) + await model.unload_model(skip_wait=True, shutdown=True) if model.embeddings_container: await model.unload_embedding_model()