API: Fix blocking iterator execution
Run these iterators on a background thread. On startup, the API spawns a background thread as needed to run sync code without blocking the event loop. Use asyncio's to_thread function, since it allows errors to be propagated back to the caller.

Signed-off-by: kingbri <bdashore3@proton.me>
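For context, a minimal sketch (not part of this commit, Python 3.9+) of the property the message relies on: an exception raised inside a function executed via asyncio.to_thread is re-raised in the awaiting coroutine, so errors from the worker thread surface on the event loop side.

```python
import asyncio


def blocking_step(n: int) -> int:
    # Runs in a worker thread; a raised exception is stored on the
    # wrapping future and re-raised where the coroutine awaits it.
    if n < 0:
        raise ValueError("negative input")
    return n * 2


async def main():
    print(await asyncio.to_thread(blocking_step, 21))  # 42
    try:
        await asyncio.to_thread(blocking_step, -1)
    except ValueError as err:
        # The worker thread's error propagates back to the caller here.
        print(f"propagated: {err}")


asyncio.run(main())
```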
parent 7fded4f183
commit 2755fd1af0
4 changed files with 56 additions and 24 deletions
@@ -1,8 +1,6 @@
 """The model container class for ExLlamaV2 models."""

-import asyncio
 import gc
-from itertools import zip_longest
 import pathlib
 import time

@@ -17,10 +15,12 @@ from exllamav2 import (
     ExLlamaV2Lora,
 )
 from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
+from itertools import zip_longest
 from loguru import logger
 from typing import List, Optional, Union

 from backends.exllamav2.grammar import ExLlamaV2Grammar
+from common.concurrency import iterate_in_threadpool
 from common.gen_logging import (
     log_generation_params,
     log_metrics,
@@ -336,7 +336,7 @@ class ExllamaV2Container:
                 def progress(loaded_modules: int, total_modules: int)
         """

-        for _ in self.load_gen(progress_callback):
+        async for _ in self.load_gen(progress_callback):
             pass

     async def load_loras(self, lora_directory: pathlib.Path, **kwargs):
@@ -373,6 +373,13 @@ class ExllamaV2Container:
         return {"success": success, "failure": failure}

     async def load_gen(self, progress_callback=None):
+        """Basic async wrapper around the loading generator"""
+
+        load_generator = self.load_gen_sync(progress_callback)
+        async for value in iterate_in_threadpool(load_generator):
+            yield value
+
+    def load_gen_sync(self, progress_callback=None):
         """
         Load model, generator function

@@ -407,8 +414,6 @@ class ExllamaV2Container:
             last_id_only=True,
             callback_gen=progress_callback,
         ):
-            # Manually suspend the task to allow for other stuff to run
-            await asyncio.sleep(0)
             if value:
                 yield value

@@ -429,8 +434,6 @@ class ExllamaV2Container:
             self.gpu_split,
             callback_gen=progress_callback,
         ):
-            # Manually suspend the task to allow for other stuff to run
-            await asyncio.sleep(0)
             if value:
                 yield value

@@ -459,8 +462,6 @@ class ExllamaV2Container:
             last_id_only=True,
             callback_gen=progress_callback,
         ):
-            # Manually suspend the task to allow for other stuff to run
-            await asyncio.sleep(0)
             if value:
                 yield value

@@ -627,6 +628,13 @@ class ExllamaV2Container:
         return kwargs

     async def generate_gen(self, prompt: str, **kwargs):
+        """Basic async wrapper for completion generator"""
+
+        sync_generator = self.generate_gen_sync(prompt, **kwargs)
+        async for value in iterate_in_threadpool(sync_generator):
+            yield value
+
+    def generate_gen_sync(self, prompt: str, **kwargs):
         """
         Create generator function for prompt completion

@@ -899,9 +907,6 @@ class ExllamaV2Container:
         chunk_tokens = 0

         while True:
-            # Manually suspend the task to allow for other stuff to run
-            await asyncio.sleep(0)
-
             # Ingest prompt
             if chunk_tokens == 0:
                 ids = torch.cat((ids, save_tokens), dim=-1)

@@ -2,12 +2,40 @@
 import asyncio
 import inspect
+from fastapi.concurrency import run_in_threadpool  # noqa
 from functools import partialmethod
 from typing import AsyncGenerator, Generator, Union

 generate_semaphore = asyncio.Semaphore(1)


+# Originally from https://github.com/encode/starlette/blob/master/starlette/concurrency.py
+# Uses generators instead of generics
+class _StopIteration(Exception):
+    """Wrapper for StopIteration because it doesn't send across threads."""
+
+    pass
+
+
+def gen_next(generator: Generator):
+    """Threaded function to get the next value in an iterator."""
+
+    try:
+        return next(generator)
+    except StopIteration as e:
+        raise _StopIteration from e
+
+
+async def iterate_in_threadpool(generator: Generator) -> AsyncGenerator:
+    """Iterates a generator within a threadpool."""
+
+    while True:
+        try:
+            yield await asyncio.to_thread(gen_next, generator)
+        except _StopIteration:
+            break
+
+
 def release_semaphore():
     generate_semaphore.release()

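A hypothetical usage sketch of the iterate_in_threadpool helper added above (slow_numbers is illustrative only, and the import assumes the repository's common.concurrency module is on the path): each next() call runs in a worker thread, so a blocking generator can be consumed without stalling the event loop.

```python
import asyncio
import time

from common.concurrency import iterate_in_threadpool


def slow_numbers():
    # Stands in for a blocking sync generator such as a token stream.
    for n in range(3):
        time.sleep(0.1)
        yield n


async def main():
    # next() runs via asyncio.to_thread, so other tasks keep running
    # while the generator blocks; the wrapped StopIteration ends the loop.
    async for value in iterate_in_threadpool(slow_numbers()):
        print(value)


asyncio.run(main())
```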
@@ -16,19 +44,18 @@ async def generate_with_semaphore(generator: Union[AsyncGenerator, Generator]):
     """Generate with a semaphore."""

     async with generate_semaphore:
-        if inspect.isasyncgenfunction:
-            async for result in generator():
-                yield result
-        else:
-            for result in generator():
-                yield result
+        if not inspect.isasyncgenfunction:
+            generator = iterate_in_threadpool(generator())
+
+        async for result in generator():
+            yield result


 async def call_with_semaphore(callback: partialmethod):
     """Call with a semaphore."""

     async with generate_semaphore:
-        if inspect.iscoroutinefunction(callback):
-            return await callback()
-        else:
-            return callback()
+        if not inspect.iscoroutinefunction:
+            callback = run_in_threadpool(callback)
+
+        return await callback()

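Purely as an illustration of the semaphore behavior (fake_generate is hypothetical, and the import assumes the repository layout): generate_semaphore has a single slot, so concurrent callers of call_with_semaphore are serialized rather than run at the same time.

```python
import asyncio
from functools import partial

from common.concurrency import call_with_semaphore


async def fake_generate(name: str) -> str:
    # Hypothetical stand-in for a model generation call.
    await asyncio.sleep(0.1)
    return f"done: {name}"


async def main():
    # Both callbacks need generate_semaphore; with one slot they run
    # back to back, so the calls finish sequentially.
    results = await asyncio.gather(
        call_with_semaphore(partial(fake_generate, "a")),
        call_with_semaphore(partial(fake_generate, "b")),
    )
    print(results)


asyncio.run(main())
```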
@@ -72,7 +72,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs):


 async def load_model(model_path: pathlib.Path, **kwargs):
-    async for _, _, _ in load_model_gen(model_path, **kwargs):
+    async for _ in load_model_gen(model_path, **kwargs):
         pass

@@ -1,7 +1,7 @@
 """Completion utilities for OAI server."""

-from asyncio import CancelledError
 import pathlib
+from asyncio import CancelledError
 from fastapi import HTTPException
 from typing import Optional
