From 42c0dbe795e8c31af76d5433fc24b91585d470fd Mon Sep 17 00:00:00 2001 From: kingbri Date: Sun, 10 Mar 2024 17:53:55 -0400 Subject: [PATCH] Generation: Explicitly release semaphore on disconnect This prevents any lockups when querying another request. Signed-off-by: kingbri --- common/generators.py | 4 ++++ main.py | 9 ++++++++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/common/generators.py b/common/generators.py index d919ea3..485ca63 100644 --- a/common/generators.py +++ b/common/generators.py @@ -8,6 +8,10 @@ from typing import AsyncGenerator, Generator, Union generate_semaphore = asyncio.Semaphore(1) +def release_semaphore(): + generate_semaphore.release() + + async def generate_with_semaphore(generator: Union[AsyncGenerator, Generator]): """Generate with a semaphore.""" diff --git a/main.py b/main.py index 6b5dbb7..1fc76a6 100644 --- a/main.py +++ b/main.py @@ -34,7 +34,11 @@ from common.config import ( get_lora_config, get_network_config, ) -from common.generators import call_with_semaphore, generate_with_semaphore +from common.generators import ( + call_with_semaphore, + generate_with_semaphore, + release_semaphore, +) from common.sampling import ( get_sampler_overrides, set_overrides_from_file, @@ -236,6 +240,7 @@ async def load_model(request: Request, data: ModelLoadRequest): for module, modules in load_status: # Get out if the request gets disconnected if await request.is_disconnected(): + release_semaphore() logger.error( "Model load cancelled by user. " "Please make sure to run unload to free up resources." @@ -522,6 +527,7 @@ async def generate_completion(request: Request, data: CompletionRequest): for generation in new_generation: # Get out if the request gets disconnected if await request.is_disconnected(): + release_semaphore() logger.error("Completion generation cancelled by user.") return @@ -620,6 +626,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest for generation in new_generation: # Get out if the request gets disconnected if await request.is_disconnected(): + release_semaphore() logger.error("Chat completion generation cancelled by user.") return