OAI: Add "n" support for streaming generations
Use a queue-based system to collect choices independently and send them in the overall streaming payload. This method allows generations to be streamed out of order. The current implementation is somewhat redundant, so the code may be optimized in the future.

Signed-off-by: kingbri <bdashore3@proton.me>
parent c8371e0f50
commit e2a8b6e8ae

2 changed files with 100 additions and 37 deletions
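The change boils down to a producer/consumer pattern: one collector task per requested choice streams chunks into a shared asyncio.Queue, tagging each chunk with its choice index, while a single consumer drains the queue and emits chunks in whatever order they arrive. Below is a minimal, self-contained sketch of that pattern; fake_generate_stream, stream_n_choices, and the chunk dictionaries are illustrative stand-ins rather than the project's actual API, and the termination check here counts finished choices instead of polling task state.

import asyncio
from typing import AsyncIterator, List


async def fake_generate_stream(prompt: str) -> AsyncIterator[dict]:
    # Illustrative stand-in for the model's streaming generator.
    for text in ("Hello", " world"):
        await asyncio.sleep(0.01)
        yield {"text": text}
    yield {"text": "", "finish_reason": "stop"}


async def stream_collector(task_idx: int, gen_queue: asyncio.Queue, prompt: str):
    # Producer: tag every chunk with its choice index and push it to the queue.
    async for generation in fake_generate_stream(prompt):
        generation["index"] = task_idx
        await gen_queue.put(generation)
        if "finish_reason" in generation:
            break


async def stream_n_choices(prompt: str, n: int):
    # Consumer: drain the shared queue until every choice has finished.
    gen_queue: asyncio.Queue = asyncio.Queue()
    gen_tasks: List[asyncio.Task] = [
        asyncio.create_task(stream_collector(i, gen_queue, prompt)) for i in range(n)
    ]

    finished = 0
    while finished < n:
        generation = await gen_queue.get()
        # Chunks from different choices arrive interleaved; the "index" field
        # keeps them distinguishable in the streamed payload.
        print(generation)
        if "finish_reason" in generation:
            finished += 1

    # All collectors have produced their final chunk at this point.
    await asyncio.gather(*gen_tasks)


if __name__ == "__main__":
    asyncio.run(stream_n_choices("Tell me a joke", n=3))

The diff itself ends its consumer loop differently: it breaks only when a chunk carries a finish_reason, every collector task reports done, and the queue is empty.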
@@ -96,11 +96,13 @@ def _create_stream_chunk(
 ):
     """Create a chat completion stream chunk from the provided text."""

+    index = generation.get("index")
     logprob_response = None

     if "finish_reason" in generation:
         choice = ChatCompletionStreamChoice(
-            finish_reason=generation.get("finish_reason")
+            index=index,
+            finish_reason=generation.get("finish_reason"),
         )
     else:
         message = ChatCompletionMessage(
@@ -125,6 +127,7 @@ def _create_stream_chunk(
             logprob_response = ChatCompletionLogprobs(content=[token_prob_response])

         choice = ChatCompletionStreamChoice(
+            index=index,
             delta=message,
             logprobs=logprob_response,
         )
@@ -199,34 +202,62 @@ def format_prompt_with_template(data: ChatCompletionRequest):
         raise HTTPException(400, error_message) from exc


+async def _stream_collector(
+    task_idx: int,
+    gen_queue: asyncio.Queue,
+    prompt: str,
+    abort_event: asyncio.Event,
+    **kwargs,
+):
+    """Collects a stream and places results in a common queue"""
+
+    new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
+    async for generation in new_generation:
+        generation["index"] = task_idx

+        await gen_queue.put(generation)

+        if "finish_reason" in generation:
+            break


 async def stream_generate_chat_completion(
     prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     """Generator for the generation process."""
+    const_id = f"chatcmpl-{uuid4().hex}"
+    abort_event = asyncio.Event()
+    gen_queue = asyncio.Queue()
+    gen_tasks: List[asyncio.Task] = []
+    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

     try:
-        const_id = f"chatcmpl-{uuid4().hex}"
+        gen_params = data.to_gen_params()

-        new_generation = model.container.generate_gen(
-            prompt, abort_event, **data.to_gen_params()
-        )
-        # Create a background task to avoid blocking the loop
-        disconnect_task = asyncio.create_task(request_disconnect_loop(request))
+        for n in range(0, data.n):
+            if n > 0:
+                task_gen_params = deepcopy(gen_params)
+            else:
+                task_gen_params = gen_params

-        async for generation in new_generation:
-            # Sometimes this fires, and sometimes a CancelledError will fire
-            # Keep both implementations in to avoid the headache
+            gen_task = asyncio.create_task(
+                _stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params)
+            )

+            gen_tasks.append(gen_task)

+        # Consumer loop
+        while True:
             if disconnect_task.done():
                 abort_event.set()
                 handle_request_disconnect("Completion generation cancelled by user.")

+            generation = await gen_queue.get()
             response = _create_stream_chunk(const_id, generation, model_path.name)

             yield response.model_dump_json()

             # Break if the generation is finished
             if "finish_reason" in generation:
+                # Check if all tasks are completed
+                if all(task.done() for task in gen_tasks) and gen_queue.empty():
                     break
     except CancelledError:
         # Get out if the request gets disconnected
@@ -247,7 +278,6 @@ async def generate_chat_completion(

     try:
         for n in range(0, data.n):

             # Deepcopy gen params above the first index
             # to ensure nested structures aren't shared
             if n > 0:
@@ -256,9 +286,7 @@ async def generate_chat_completion(
                 task_gen_params = gen_params

             gen_tasks.append(
-                asyncio.create_task(
-                    model.container.generate(prompt, **task_gen_params)
-                )
+                asyncio.create_task(model.container.generate(prompt, **task_gen_params))
             )

         generations = await asyncio.gather(*gen_tasks)
@@ -5,7 +5,7 @@ import pathlib
 from asyncio import CancelledError
 from copy import deepcopy
 from fastapi import HTTPException, Request
-from typing import List, Optional
+from typing import List, Union

 from common import model
 from common.networking import (
@@ -24,13 +24,14 @@ from endpoints.OAI.types.completion import (
 from endpoints.OAI.types.common import UsageStats


-def _create_response(generations: List[dict], model_name: Optional[str]):
-    """Create a completion response from the provided text."""
+def _create_response(generations: Union[dict, List[dict]], model_name: str = ""):
+    """Create a completion response from the provided choices."""

-    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
-    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
+    # Convert the single choice object into a list
+    if not isinstance(generations, list):
+        generations = [generations]

-    choices = []
+    choices: List[CompletionRespChoice] = []
     for index, generation in enumerate(generations):
         logprob_response = None
@@ -46,8 +47,9 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
                 top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
             )

+        # The index can be located in the generation itself
         choice = CompletionRespChoice(
-            index=index,
+            index=unwrap(generation.get("index"), index),
             finish_reason=generation.get("finish_reason"),
             text=unwrap(generation.get("text"), ""),
             logprobs=logprob_response,
@@ -55,9 +57,12 @@ def _create_response(generations: List[dict], model_name: Optional[str]):

         choices.append(choice)

+    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
+    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)

     response = CompletionResponse(
         choices=choices,
-        model=unwrap(model_name, ""),
+        model=model_name,
         usage=UsageStats(
             prompt_tokens=prompt_tokens,
             completion_tokens=completion_tokens,
@@ -68,33 +73,64 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
     return response


+async def _stream_collector(
+    task_idx: int,
+    gen_queue: asyncio.Queue,
+    prompt: str,
+    abort_event: asyncio.Event,
+    **kwargs,
+):
+    """Collects a stream and places results in a common queue"""
+
+    new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
+    async for generation in new_generation:
+        generation["index"] = task_idx

+        await gen_queue.put(generation)

+        if "finish_reason" in generation:
+            break


 async def stream_generate_completion(
     data: CompletionRequest, request: Request, model_path: pathlib.Path
 ):
     """Streaming generation for completions."""

     abort_event = asyncio.Event()
+    gen_queue = asyncio.Queue()
+    gen_tasks: List[asyncio.Task] = []
+    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

     try:
-        new_generation = model.container.generate_gen(
-            data.prompt, abort_event, **data.to_gen_params()
-        )
+        gen_params = data.to_gen_params()

-        # Create a background task to avoid blocking the loop
-        disconnect_task = asyncio.create_task(request_disconnect_loop(request))
+        for n in range(0, data.n):
+            if n > 0:
+                task_gen_params = deepcopy(gen_params)
+            else:
+                task_gen_params = gen_params

-        async for generation in new_generation:
-            # Sometimes this fires, and sometimes a CancelledError will fire
-            # Keep both implementations in to avoid the headache
+            gen_task = asyncio.create_task(
+                _stream_collector(
+                    n, gen_queue, data.prompt, abort_event, **task_gen_params
+                )
+            )

+            gen_tasks.append(gen_task)

+        # Consumer loop
+        while True:
             if disconnect_task.done():
                 abort_event.set()
                 handle_request_disconnect("Completion generation cancelled by user.")

-            response = _create_response([generation], model_path.name)
+            generation = await gen_queue.get()
+            response = _create_response(generation, model_path.name)
             yield response.model_dump_json()

             # Break if the generation is finished
             if "finish_reason" in generation:
+                # Check if all tasks are completed
+                if all(task.done() for task in gen_tasks) and gen_queue.empty():
                     yield "[DONE]"
                     break
     except CancelledError:
@@ -116,7 +152,6 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)

     try:
         for n in range(0, data.n):

             # Deepcopy gen params above the first index
             # to ensure nested structures aren't shared
             if n > 0:
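On the client side, the feature is exercised by sending n together with stream in an OpenAI-style chat completion request. The sketch below is hedged: the address, port, auth header, SSE framing, and model name are assumptions about a typical deployment, not values taken from this commit.

import json

import requests

payload = {
    "model": "my-model",  # placeholder model name
    "messages": [{"role": "user", "content": "Write a haiku about queues."}],
    "n": 2,  # request two independent choices
    "stream": True,  # stream chunks as they are generated
}

with requests.post(
    "http://127.0.0.1:5000/v1/chat/completions",  # assumed local address
    headers={"x-api-key": "placeholder-key"},  # auth scheme may differ
    json=payload,
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        data = line.decode().removeprefix("data: ")
        if data == "[DONE]":
            break
        chunk = json.loads(data)
        # Each streamed choice carries an "index" identifying which of the
        # n generations it belongs to; chunks may arrive out of order.
        for choice in chunk["choices"]:
            print(choice["index"], choice.get("delta", {}).get("content", ""))

Because the collectors run concurrently, chunks for different indices interleave; a client that wants ordered output has to bucket them by index.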