OAI: Add "n" support for streaming generations

Use a queue-based system to collect choices independently and send them
in the overall streaming payload. This method allows generations to be
streamed in an unordered fashion.

The current system has some redundancy, so the code may be optimized
in the future.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-05-27 23:53:25 -04:00 committed by Brian Dashore
parent c8371e0f50
commit e2a8b6e8ae
2 changed files with 100 additions and 37 deletions
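
Below is a minimal, self-contained sketch of the queue-based fan-in pattern the commit message describes. It is not the repository's code: fake_stream, collector, and stream_n are hypothetical stand-ins for model.container.generate_gen, _stream_collector, and the endpoint's consumer loop.

import asyncio
from typing import AsyncIterator, List


async def fake_stream(task_idx: int) -> AsyncIterator[dict]:
    # Hypothetical stand-in for model.container.generate_gen: yield a few
    # text chunks, then a chunk carrying "finish_reason".
    for text in ("Hello", " world"):
        await asyncio.sleep(0.01)
        yield {"index": task_idx, "text": text}
    yield {"index": task_idx, "finish_reason": "stop"}


async def collector(task_idx: int, gen_queue: asyncio.Queue) -> None:
    # Producer: forward one generation's chunks into the shared queue,
    # stopping once the finish chunk has been queued.
    async for chunk in fake_stream(task_idx):
        await gen_queue.put(chunk)
        if "finish_reason" in chunk:
            break


async def stream_n(n: int) -> None:
    # Consumer: drain the shared queue until every producer task is done
    # and no chunks remain, emitting chunks in arrival order.
    gen_queue: asyncio.Queue = asyncio.Queue()
    tasks: List[asyncio.Task] = [
        asyncio.create_task(collector(i, gen_queue)) for i in range(n)
    ]

    while True:
        chunk = await gen_queue.get()
        print(chunk)  # the real endpoint turns this into a stream chunk/response
        if "finish_reason" in chunk:
            if all(task.done() for task in tasks) and gen_queue.empty():
                break


if __name__ == "__main__":
    asyncio.run(stream_n(2))

Each producer tags its chunks with its task index, so the consumer can emit choices in whatever order they arrive, which is the unordered streaming behavior this commit adds.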


@@ -96,11 +96,13 @@ def _create_stream_chunk(
):
"""Create a chat completion stream chunk from the provided text."""
index = generation.get("index")
logprob_response = None
if "finish_reason" in generation:
choice = ChatCompletionStreamChoice(
finish_reason=generation.get("finish_reason")
index=index,
finish_reason=generation.get("finish_reason"),
)
else:
message = ChatCompletionMessage(
@@ -125,6 +127,7 @@ def _create_stream_chunk(
logprob_response = ChatCompletionLogprobs(content=[token_prob_response])
choice = ChatCompletionStreamChoice(
index=index,
delta=message,
logprobs=logprob_response,
)
@@ -199,34 +202,62 @@ def format_prompt_with_template(data: ChatCompletionRequest):
raise HTTPException(400, error_message) from exc
async def _stream_collector(
task_idx: int,
gen_queue: asyncio.Queue,
prompt: str,
abort_event: asyncio.Event,
**kwargs,
):
"""Collects a stream and places results in a common queue"""
new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
async for generation in new_generation:
generation["index"] = task_idx
await gen_queue.put(generation)
if "finish_reason" in generation:
break
async def stream_generate_chat_completion(
prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
):
"""Generator for the generation process."""
const_id = f"chatcmpl-{uuid4().hex}"
abort_event = asyncio.Event()
gen_queue = asyncio.Queue()
gen_tasks: List[asyncio.Task] = []
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
try:
const_id = f"chatcmpl-{uuid4().hex}"
gen_params = data.to_gen_params()
new_generation = model.container.generate_gen(
prompt, abort_event, **data.to_gen_params()
)
# Create a background task to avoid blocking the loop
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
for n in range(0, data.n):
if n > 0:
task_gen_params = deepcopy(gen_params)
else:
task_gen_params = gen_params
async for generation in new_generation:
# Sometimes this fires, and sometimes a CancelledError will fire
# Keep both implementations in to avoid the headache
gen_task = asyncio.create_task(
_stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params)
)
gen_tasks.append(gen_task)
# Consumer loop
while True:
if disconnect_task.done():
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")
generation = await gen_queue.get()
response = _create_stream_chunk(const_id, generation, model_path.name)
yield response.model_dump_json()
# Break if the generation is finished
if "finish_reason" in generation:
# Check if all tasks are completed
if all(task.done() for task in gen_tasks) and gen_queue.empty():
break
except CancelledError:
# Get out if the request gets disconnected
@@ -247,7 +278,6 @@ async def generate_chat_completion(
try:
for n in range(0, data.n):
# Deepcopy gen params above the first index
# to ensure nested structures aren't shared
if n > 0:
@@ -256,9 +286,7 @@ async def generate_chat_completion(
task_gen_params = gen_params
gen_tasks.append(
asyncio.create_task(
model.container.generate(prompt, **task_gen_params)
)
asyncio.create_task(model.container.generate(prompt, **task_gen_params))
)
generations = await asyncio.gather(*gen_tasks)


@@ -5,7 +5,7 @@ import pathlib
from asyncio import CancelledError
from copy import deepcopy
from fastapi import HTTPException, Request
from typing import List, Optional
from typing import List, Union
from common import model
from common.networking import (
@@ -24,13 +24,14 @@ from endpoints.OAI.types.completion import (
from endpoints.OAI.types.common import UsageStats
def _create_response(generations: List[dict], model_name: Optional[str]):
"""Create a completion response from the provided text."""
def _create_response(generations: Union[dict, List[dict]], model_name: str = ""):
"""Create a completion response from the provided choices."""
prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
# Convert the single choice object into a list
if not isinstance(generations, list):
generations = [generations]
choices = []
choices: List[CompletionRespChoice] = []
for index, generation in enumerate(generations):
logprob_response = None
@@ -46,8 +47,9 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
)
# The index can be located in the generation itself
choice = CompletionRespChoice(
index=index,
index=unwrap(generation.get("index"), index),
finish_reason=generation.get("finish_reason"),
text=unwrap(generation.get("text"), ""),
logprobs=logprob_response,
@@ -55,9 +57,12 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
choices.append(choice)
prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
response = CompletionResponse(
choices=choices,
model=unwrap(model_name, ""),
model=model_name,
usage=UsageStats(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
@@ -68,33 +73,64 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
return response
async def _stream_collector(
task_idx: int,
gen_queue: asyncio.Queue,
prompt: str,
abort_event: asyncio.Event,
**kwargs,
):
"""Collects a stream and places results in a common queue"""
new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
async for generation in new_generation:
generation["index"] = task_idx
await gen_queue.put(generation)
if "finish_reason" in generation:
break
async def stream_generate_completion(
data: CompletionRequest, request: Request, model_path: pathlib.Path
):
"""Streaming generation for completions."""
abort_event = asyncio.Event()
gen_queue = asyncio.Queue()
gen_tasks: List[asyncio.Task] = []
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
try:
new_generation = model.container.generate_gen(
data.prompt, abort_event, **data.to_gen_params()
)
gen_params = data.to_gen_params()
# Create a background task to avoid blocking the loop
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
for n in range(0, data.n):
if n > 0:
task_gen_params = deepcopy(gen_params)
else:
task_gen_params = gen_params
async for generation in new_generation:
# Sometimes this fires, and sometimes a CancelledError will fire
# Keep both implementations in to avoid the headache
gen_task = asyncio.create_task(
_stream_collector(
n, gen_queue, data.prompt, abort_event, **task_gen_params
)
)
gen_tasks.append(gen_task)
# Consumer loop
while True:
if disconnect_task.done():
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")
response = _create_response([generation], model_path.name)
generation = await gen_queue.get()
response = _create_response(generation, model_path.name)
yield response.model_dump_json()
# Break if the generation is finished
if "finish_reason" in generation:
# Check if all tasks are completed
if all(task.done() for task in gen_tasks) and gen_queue.empty():
yield "[DONE]"
break
except CancelledError:
@@ -116,7 +152,6 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
try:
for n in range(0, data.n):
# Deepcopy gen params above the first index
# to ensure nested structures aren't shared
if n > 0: