tabbyAPI-ollama/common/concurrency.py
kingbri 2755fd1af0 API: Fix blocking iterator execution
Run these iterators on the background thread. On startup, the API
spawns a background thread as needed to run sync code on without blocking
the event loop.

Use asyncio's to_thread function since it allows for errors to be
propagated.

Signed-off-by: kingbri <bdashore3@proton.me>
2024-03-16 23:23:31 -04:00

61 lines
1.6 KiB
Python

"""Concurrency handling"""
import asyncio
import inspect
from fastapi.concurrency import run_in_threadpool # noqa
from functools import partialmethod
from typing import AsyncGenerator, Generator, Union
# Single-permit semaphore acquired by generate_with_semaphore and
# call_with_semaphore, so only one of them runs its body at a time.
generate_semaphore = asyncio.Semaphore(1)
# Originally from https://github.com/encode/starlette/blob/master/starlette/concurrency.py
# Uses generators instead of generics
class _StopIteration(Exception):
    """Stand-in for ``StopIteration`` that can be raised out of a worker thread.

    ``StopIteration`` itself does not carry across the thread boundary, so
    exhaustion of a generator is signalled with this plain ``Exception``
    subclass instead.
    """
def gen_next(generator: Generator):
    """Advance ``generator`` by one step; meant to be run in a worker thread.

    Returns:
        The next value produced by ``generator``.

    Raises:
        _StopIteration: When ``generator`` is exhausted. The original
            ``StopIteration`` is attached as the cause.
    """
    try:
        item = next(generator)
    except StopIteration as exhausted:
        # Translate into the wrapper type so the caller on the event loop
        # side can detect exhaustion.
        raise _StopIteration from exhausted
    return item
async def iterate_in_threadpool(generator: Generator) -> AsyncGenerator:
    """Expose a synchronous generator as an async generator.

    Every step of ``generator`` is executed through ``asyncio.to_thread`` so
    the event loop is not blocked while the next item is being produced.

    Yields:
        The items of ``generator`` in order, until it is exhausted.
    """
    exhausted = False
    while not exhausted:
        try:
            yield await asyncio.to_thread(gen_next, generator)
        except _StopIteration:
            # gen_next reports exhaustion with the wrapper exception.
            exhausted = True
def release_semaphore():
    """Release one permit on the module-level ``generate_semaphore``."""
    # NOTE(review): asyncio.Semaphore is not bounded, so releasing while no
    # permit is held would raise its counter above 1 - confirm callers only
    # use this to undo a matching acquire.
    generate_semaphore.release()
async def generate_with_semaphore(generator: Union[AsyncGenerator, Generator]):
    """Stream items from a generator factory while holding the semaphore.

    Args:
        generator: Zero-argument callable (for example a ``functools.partial``)
            that returns either an async generator or a synchronous generator
            when called.

    Yields:
        Every item produced by the underlying generator, in order. The
        semaphore is held from before the first item is requested until
        iteration finishes or this async generator is closed.
    """
    async with generate_semaphore:
        # For generator functions, calling the factory only builds the
        # generator object; its body does not start until the first item
        # is requested.
        stream = generator()

        # The previous check was `if not inspect.isasyncgenfunction:`, which
        # tests the truthiness of the function object itself and is therefore
        # never true, so synchronous generators reached `async for` unwrapped.
        # Inspecting the produced object is also reliable for partials and
        # lambdas that merely return a generator.
        if not hasattr(stream, "__aiter__"):
            # Synchronous generator: step it on a worker thread so the event
            # loop is not blocked.
            stream = iterate_in_threadpool(stream)

        async for result in stream:
            yield result
async def call_with_semaphore(callback: partialmethod):
    """Run a zero-argument callback while holding the semaphore.

    Args:
        callback: Zero-argument callable (for example a ``functools.partial``).
            Coroutine functions are awaited directly; any other callable is
            executed through ``run_in_threadpool``.

    Returns:
        The callback's result. If a non-coroutine callable returns an
        awaitable, that awaitable is awaited and its result is returned.
    """
    async with generate_semaphore:
        # The previous check was `if not inspect.iscoroutinefunction:`, which
        # tests the truthiness of the function object itself and is therefore
        # never true, so synchronous callbacks ran on the event loop and their
        # plain return value was then awaited.
        if inspect.iscoroutinefunction(callback):
            return await callback()

        # Synchronous callable: run it off the event loop.
        result = await run_in_threadpool(callback)

        # Keep supporting plain callables (e.g. lambdas) that hand back a
        # coroutine or other awaitable, which the old code path awaited.
        if inspect.isawaitable(result):
            result = await result
        return result