API + Model: Add blocks and checks for various load requests

Add a sequential lock and wait until jobs are completed before executing
any loading requests that directly alter the model. However, we also
need to block any new requests that come in until the load is finished,
so add a condition that triggers once the lock is free.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-05-25 18:24:11 -04:00 committed by Brian Dashore
parent 408c66a1f2
commit 43cd7f57e8
5 changed files with 268 additions and 249 deletions

View file

@ -1,12 +1,8 @@
"""Concurrency handling"""
import asyncio
import inspect
from fastapi.concurrency import run_in_threadpool # noqa
from functools import partialmethod
from typing import AsyncGenerator, Generator, Union
generate_semaphore = asyncio.Semaphore(1)
from typing import AsyncGenerator, Generator
# Originally from https://github.com/encode/starlette/blob/master/starlette/concurrency.py
@ -34,24 +30,3 @@ async def iterate_in_threadpool(generator: Generator) -> AsyncGenerator:
yield await asyncio.to_thread(gen_next, generator)
except _StopIteration:
break
async def generate_with_semaphore(generator: Union[AsyncGenerator, Generator]):
    """Yield every item produced by ``generator`` while holding the semaphore.

    ``generator`` is a zero-argument callable (for example a
    ``functools.partial``) that returns either an async generator or a
    plain generator when called. The module-level ``generate_semaphore``
    is held for the whole iteration, so only one generation runs at a time.

    Yields:
        Each item produced by the underlying generator, in order.
    """
    async with generate_semaphore:
        stream = generator()

        # The previous check, `not inspect.isasyncgenfunction`, tested the
        # truthiness of the function object itself and so was never true;
        # plain generators were never wrapped. Dispatch on the object that
        # was actually produced instead.
        if not hasattr(stream, "__aiter__"):
            # Plain generators are stepped in a worker thread so they do not
            # block the event loop.
            stream = iterate_in_threadpool(stream)

        async for result in stream:
            yield result
async def call_with_semaphore(callback: partialmethod):
    """Invoke ``callback`` while holding the semaphore and return its result.

    ``callback`` is a zero-argument callable. Coroutine functions are
    awaited directly; any other callable is executed through
    ``run_in_threadpool``. The module-level ``generate_semaphore`` is held
    until the call has finished.

    Returns:
        Whatever ``callback`` returns (after awaiting it when applicable).
    """
    async with generate_semaphore:
        # The previous check, `not inspect.iscoroutinefunction`, tested the
        # truthiness of the function object itself and so was never true;
        # pass the callback so the check is meaningful.
        if inspect.iscoroutinefunction(callback):
            return await callback()

        result = await run_in_threadpool(callback)

        # A callable that is not detected as a coroutine function may still
        # hand back an awaitable; keep supporting that, as the old code did.
        if inspect.isawaitable(result):
            result = await result

        return result

View file

@ -20,11 +20,11 @@ def load_progress(module, modules):
yield module, modules
# NOTE(review): this span is a diff hunk rendered without +/- markers, so two
# versions of unload_model are interleaved. Presumably the parameterless
# signature and the bare `container.unload()` call are the removed lines, and
# the `skip_wait` signature with the awaited call are their replacements.
async def unload_model():
async def unload_model(skip_wait: bool = False):
"""Unloads a model"""
global container
container.unload()
# `skip_wait` is forwarded unchanged to the container's unload call.
await container.unload(skip_wait=skip_wait)
# Drop the module-level reference once the unload call has returned.
container = None
@ -49,7 +49,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
container = ExllamaV2Container(model_path.resolve(), False, **kwargs)
model_type = "draft" if container.draft_config else "model"
load_status = container.load_gen(load_progress)
load_status = container.load_gen(load_progress, **kwargs)
progress = get_loading_progress_bar()
progress.start()
@ -81,12 +81,12 @@ async def load_model(model_path: pathlib.Path, **kwargs):
# NOTE(review): this span is a diff hunk rendered without +/- markers, so two
# versions of the "unload first" check are interleaved. Presumably the
# `container.active_loras` / un-awaited `unload_loras()` pair is the removed
# code and the `container.get_loras()` / awaited pair is its replacement.
async def load_loras(lora_dir, **kwargs):
"""Wrapper to load loras."""
# Any loras already present on the container are unloaded before loading.
if len(container.active_loras) > 0:
unload_loras()
if len(container.get_loras()) > 0:
await unload_loras()
# `lora_dir` and all keyword arguments are passed straight to the container.
return await container.load_loras(lora_dir, **kwargs)
# NOTE(review): this span is a diff hunk rendered without +/- markers, so two
# versions of unload_loras are interleaved. Presumably the plain `def` with
# the un-awaited call is the removed code and the `async def` with the
# awaited call is its replacement.
def unload_loras():
async def unload_loras():
"""Wrapper to unload loras"""
# `loras_only=True` is the only argument passed to the container's unload.
container.unload(loras_only=True)
await container.unload(loras_only=True)