API: Fix sequential requests

FastAPI's request queueing is unintuitive: if an await is used inside
an async def endpoint, requests are not executed sequentially. Restore
sequential handling by using a semaphore to limit concurrent execution
of the streaming generator functions (see the sketch below).

Also scaffold the framework for moving generator functions into their
own file (generators.py).

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2023-12-03 22:54:34 -05:00
parent e740b53478
commit ed6c962aad
2 changed files with 26 additions and 5 deletions
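
For context, a minimal self-contained sketch of the problem and the fix
described in the commit message (illustrative only, not part of this
commit; the stream/serialized helpers are made up): every await inside
an async def handler yields control to the event loop, so two in-flight
streaming requests interleave unless something serializes them. Holding
a shared Semaphore(1) for the lifetime of each stream restores
sequential execution.

import asyncio
from asyncio import Semaphore

# One permit: only one stream may be generating at any time.
generate_semaphore = Semaphore(1)

async def stream(name: str):
    # Stand-in for token-by-token generation inside an endpoint.
    for i in range(3):
        await asyncio.sleep(0.01)
        yield f"{name}: chunk {i}"

async def serialized(name: str):
    # Hold the semaphore for the whole stream so concurrent requests
    # finish one at a time instead of interleaving.
    async with generate_semaphore:
        async for chunk in stream(name):
            print(chunk)

async def main():
    # With the semaphore, every chunk from "a" prints before "b" starts;
    # without it, the two streams would interleave.
    await asyncio.gather(serialized("a"), serialized("b"))

asyncio.run(main())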

generators.py (new file, 10 additions)

@@ -0,0 +1,10 @@
from asyncio import Semaphore
from typing import AsyncGenerator

generate_semaphore = Semaphore(1)

# Async generation that blocks on a semaphore
async def generate_with_semaphore(generator: AsyncGenerator):
    async with generate_semaphore:
        async for result in generator():
            yield result
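
As a usage sketch, a hypothetical endpoint (route and handler invented
for illustration; only generate_with_semaphore comes from this commit)
passes the generator function itself rather than a generator object, so
the wrapper can call it once the semaphore is acquired:

from fastapi import FastAPI
from fastapi.responses import StreamingResponse

from generators import generate_with_semaphore

app = FastAPI()

@app.get("/v1/demo")  # hypothetical route, for illustration only
async def demo():
    async def generator():
        # Stand-in for the real SSE token stream.
        for i in range(3):
            yield f"data: chunk {i}\n\n"

    # Pass the function itself (not generator()); it is called inside
    # the wrapper only after the shared semaphore has been acquired, so
    # streams from concurrent requests run one at a time.
    return StreamingResponse(
        generate_with_semaphore(generator),
        media_type="text/event-stream"
    )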

main.py (16 additions, 5 deletions)

@@ -7,6 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from model import ModelContainer
from progress.bar import IncrementalBar
from generators import generate_with_semaphore
from OAI.types.completion import CompletionRequest
from OAI.types.chat_completion import ChatCompletionRequest
from OAI.types.model import ModelCard, ModelLoadRequest, ModelLoadResponse
@@ -200,10 +201,15 @@ async def generate_completion(request: Request, data: CompletionRequest):
                                                          model_path.name)
                    yield get_sse_packet(response.json(ensure_ascii=False))
            except GeneratorExit:
                print("Completion response aborted")
            except Exception as e:
                yield get_generator_error(e)
        return StreamingResponse(generator(), media_type = "text/event-stream")
        return StreamingResponse(
            generate_with_semaphore(generator),
            media_type = "text/event-stream"
        )
    else:
        response_text, prompt_tokens, completion_tokens = model_container.generate(data.prompt, **data.to_gen_params())
        response = create_completion_response(response_text,
@@ -238,12 +244,14 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                        model_path.name
                    )
                    yield get_sse_packet(response.json(ensure_ascii=False))
            except GeneratorExit:
                print("Chat completion response aborted")
            except Exception as e:
                yield get_generator_error(e)
            finally:
                # Always finish no matter what
                # FIXME: An error currently fires here since the generator is closed, move this somewhere else
                finish_response = create_chat_completion_stream_chunk(
                    const_id,
                    finish_reason = "stop"
@@ -251,7 +259,10 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
                yield get_sse_packet(finish_response.json(ensure_ascii=False))
        return StreamingResponse(generator(), media_type = "text/event-stream")
        return StreamingResponse(
            generate_with_semaphore(generator),
            media_type = "text/event-stream"
        )
    else:
        response_text, prompt_tokens, completion_tokens = model_container.generate(prompt, **data.to_gen_params())
        response = create_chat_completion_response(response_text,
@@ -283,7 +294,7 @@ if __name__ == "__main__":
    if "model_name" in model_config:
        model_path = pathlib.Path(model_config.get("model_dir") or "models")
        model_path = model_path / model_config.get("model_name")
        model_container = ModelContainer(model_path.resolve(), False, **model_config)
        load_status = model_container.load_gen(load_progress)
        for (module, modules) in load_status: