API: Fix blocking iterator execution

Run blocking iterators on a background thread. The API spawns a
background thread as needed to run sync code without blocking the
event loop.

Use asyncio's to_thread function since it allows errors to be
propagated.
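
For reference, a minimal standalone sketch of the pattern (illustrative
only, not code from this commit): values are pulled from a blocking
generator with asyncio.to_thread, and StopIteration is re-wrapped so it
can cross back from the worker thread. The slow_counter generator and
the demo coroutine below are made-up stand-ins.

import asyncio
import time


class _StopIteration(Exception):
    """StopIteration itself cannot be raised into an asyncio Future, so wrap it."""


def _next(generator):
    """Runs on a worker thread; any other exception propagates back unchanged."""
    try:
        return next(generator)
    except StopIteration as e:
        raise _StopIteration from e


def slow_counter(n):
    """Blocking generator standing in for model loading or token generation."""
    for i in range(n):
        time.sleep(0.1)  # blocking work that would otherwise stall the event loop
        yield i


async def main():
    gen = slow_counter(3)
    while True:
        try:
            print(await asyncio.to_thread(_next, gen))
        except _StopIteration:
            break


asyncio.run(main())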

Signed-off-by: kingbri <bdashore3@proton.me>
Authored by kingbri on 2024-03-16 22:31:50 -04:00, committed by Brian Dashore
parent 7fded4f183
commit 2755fd1af0
4 changed files with 56 additions and 24 deletions


@ -1,8 +1,6 @@
"""The model container class for ExLlamaV2 models."""
import asyncio
import gc
from itertools import zip_longest
import pathlib
import time
@ -17,10 +15,12 @@ from exllamav2 import (
ExLlamaV2Lora,
)
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
from itertools import zip_longest
from loguru import logger
from typing import List, Optional, Union
from backends.exllamav2.grammar import ExLlamaV2Grammar
from common.concurrency import iterate_in_threadpool
from common.gen_logging import (
log_generation_params,
log_metrics,
@ -336,7 +336,7 @@ class ExllamaV2Container:
def progress(loaded_modules: int, total_modules: int)
"""
for _ in self.load_gen(progress_callback):
async for _ in self.load_gen(progress_callback):
pass
async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
@ -373,6 +373,13 @@ class ExllamaV2Container:
return {"success": success, "failure": failure}
async def load_gen(self, progress_callback=None):
"""Basic async wrapper around the loading generator"""
load_generator = self.load_gen_sync(progress_callback)
async for value in iterate_in_threadpool(load_generator):
yield value
def load_gen_sync(self, progress_callback=None):
"""
Load model, generator function
@ -407,8 +414,6 @@ class ExllamaV2Container:
last_id_only=True,
callback_gen=progress_callback,
):
# Manually suspend the task to allow for other stuff to run
await asyncio.sleep(0)
if value:
yield value
@ -429,8 +434,6 @@ class ExllamaV2Container:
self.gpu_split,
callback_gen=progress_callback,
):
# Manually suspend the task to allow for other stuff to run
await asyncio.sleep(0)
if value:
yield value
@ -459,8 +462,6 @@ class ExllamaV2Container:
last_id_only=True,
callback_gen=progress_callback,
):
# Manually suspend the task to allow for other stuff to run
await asyncio.sleep(0)
if value:
yield value
@ -627,6 +628,13 @@ class ExllamaV2Container:
return kwargs
async def generate_gen(self, prompt: str, **kwargs):
"""Basic async wrapper for completion generator"""
sync_generator = self.generate_gen_sync(prompt, **kwargs)
async for value in iterate_in_threadpool(sync_generator):
yield value
def generate_gen_sync(self, prompt: str, **kwargs):
"""
Create generator function for prompt completion
@ -899,9 +907,6 @@ class ExllamaV2Container:
chunk_tokens = 0
while True:
# Manually suspend the task to allow for other stuff to run
await asyncio.sleep(0)
# Ingest prompt
if chunk_tokens == 0:
ids = torch.cat((ids, save_tokens), dim=-1)
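
A hedged usage sketch of the two async wrappers above: the container
object, the prompt, and the surrounding coroutine are hypothetical;
only load_gen and generate_gen come from this file.

async def load_then_stream(container, prompt: str, **kwargs):
    # Both wrappers are plain async generators, so callers use `async for`
    # while the underlying sync generators run on a worker thread.
    async for _ in container.load_gen():
        pass  # drain progress values without blocking the event loop

    async for chunk in container.generate_gen(prompt, **kwargs):
        ...  # forward each chunk to the client, e.g. as an SSE event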


@ -2,12 +2,40 @@
import asyncio
import inspect
from fastapi.concurrency import run_in_threadpool # noqa
from functools import partialmethod
from typing import AsyncGenerator, Generator, Union
generate_semaphore = asyncio.Semaphore(1)
# Originally from https://github.com/encode/starlette/blob/master/starlette/concurrency.py
# Uses generators instead of generics
class _StopIteration(Exception):
"""Wrapper for StopIteration because it doesn't send across threads."""
pass
def gen_next(generator: Generator):
"""Threaded function to get the next value in an iterator."""
try:
return next(generator)
except StopIteration as e:
raise _StopIteration from e
async def iterate_in_threadpool(generator: Generator) -> AsyncGenerator:
"""Iterates a generator within a threadpool."""
while True:
try:
yield await asyncio.to_thread(gen_next, generator)
except _StopIteration:
break
def release_semaphore():
generate_semaphore.release()
@ -16,19 +44,18 @@ async def generate_with_semaphore(generator: Union[AsyncGenerator, Generator]):
"""Generate with a semaphore."""
async with generate_semaphore:
if inspect.isasyncgenfunction:
async for result in generator():
yield result
else:
for result in generator():
yield result
if not inspect.isasyncgenfunction(generator):
    generator = iterate_in_threadpool(generator())
else:
    generator = generator()
async for result in generator:
    yield result
async def call_with_semaphore(callback: partialmethod):
"""Call with a semaphore."""
async with generate_semaphore:
if inspect.iscoroutinefunction(callback):
return await callback()
else:
return callback()
if not inspect.iscoroutinefunction(callback):
    return await run_in_threadpool(callback)
return await callback()
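
A sketch of how these helpers might be called. The sync_tokens generator
and the handler coroutine are assumptions for illustration; the imports
from common.concurrency match this file.

import asyncio
from functools import partial

from common.concurrency import call_with_semaphore, generate_with_semaphore


def sync_tokens():
    """Hypothetical blocking generator, e.g. generate_gen_sync bound to a prompt."""
    yield from ("Hello", ", ", "world")


async def handler():
    # The sync generator function is iterated through the threadpool wrapper,
    # one request at a time thanks to the shared semaphore.
    async for token in generate_with_semaphore(sync_tokens):
        print(token, end="")

    # A plain (non-coroutine) callable is also pushed to the threadpool.
    total = await call_with_semaphore(partial(sum, [1, 2, 3]))
    print(f"\n{total}")


asyncio.run(handler())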


@ -72,7 +72,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):
async def load_model(model_path: pathlib.Path, **kwargs):
async for _, _, _ in load_model_gen(model_path, **kwargs):
async for _ in load_model_gen(model_path, **kwargs):
pass


@ -1,7 +1,7 @@
"""Completion utilities for OAI server."""
from asyncio import CancelledError
import pathlib
from asyncio import CancelledError
from fastapi import HTTPException
from typing import Optional