Generation: Explicitly release semaphore on disconnect

This prevents lockups when a new request is issued after a previous one was disconnected.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-03-10 17:53:55 -04:00
parent 2025a1c857
commit 42c0dbe795
2 changed files with 12 additions and 1 deletion

View file

@ -34,7 +34,11 @@ from common.config import (
get_lora_config,
get_network_config,
)
from common.generators import call_with_semaphore, generate_with_semaphore
from common.generators import (
call_with_semaphore,
generate_with_semaphore,
release_semaphore,
)
from common.sampling import (
get_sampler_overrides,
set_overrides_from_file,
@ -236,6 +240,7 @@ async def load_model(request: Request, data: ModelLoadRequest):
for module, modules in load_status:
# Get out if the request gets disconnected
if await request.is_disconnected():
release_semaphore()
logger.error(
"Model load cancelled by user. "
"Please make sure to run unload to free up resources."
@ -522,6 +527,7 @@ async def generate_completion(request: Request, data: CompletionRequest):
for generation in new_generation:
# Get out if the request gets disconnected
if await request.is_disconnected():
release_semaphore()
logger.error("Completion generation cancelled by user.")
return
@ -620,6 +626,7 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
for generation in new_generation:
# Get out if the request gets disconnected
if await request.is_disconnected():
release_semaphore()
logger.error("Chat completion generation cancelled by user.")
return