API: Append task index to generations with n > 1

Since jobs are tracked via request IDs now, each generation task should
be uniquely identified in the event of cancellation.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-04-28 22:29:48 -04:00
parent b43f0983c8
commit 9157be3e34
2 changed files with 36 additions and 12 deletions

View file

@ -29,7 +29,7 @@ from endpoints.OAI.types.chat_completion import (
ChatCompletionStreamChoice,
)
from endpoints.OAI.types.common import UsageStats
from endpoints.OAI.utils.completion import _stream_collector
from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector
from endpoints.OAI.utils.tools import ToolCallProcessor
@ -326,14 +326,17 @@ async def stream_generate_chat_completion(
try:
logger.info(f"Received chat completion streaming request {request.state.id}")
for n in range(0, data.n):
for idx in range(0, data.n):
task_gen_params = data.model_copy(deep=True)
request_id = _parse_gen_request_id(
data.n, request.state.id, idx
)
gen_task = asyncio.create_task(
_stream_collector(
n,
idx,
gen_queue,
request.state.id,
request_id,
prompt,
task_gen_params,
abort_event,
@ -418,11 +421,15 @@ async def generate_chat_completion(
gen_tasks: List[asyncio.Task] = []
try:
for _ in range(0, data.n):
for idx in range(0, data.n):
request_id = _parse_gen_request_id(
data.n, request.state.id, idx
)
gen_tasks.append(
asyncio.create_task(
model.container.generate(
request.state.id,
request_id,
prompt,
data,
mm_embeddings=embeddings,
@ -484,10 +491,14 @@ async def generate_tool_calls(
data, current_generations
)
request_id = _parse_gen_request_id(
data.n, request.state.id, idx
)
gen_tasks.append(
asyncio.create_task(
model.container.generate(
request.state.id,
request_id,
pre_tool_prompt,
tool_data,
embeddings=mm_embeddings,

View file

@ -31,6 +31,13 @@ from endpoints.OAI.types.completion import (
from endpoints.OAI.types.common import UsageStats
def _parse_gen_request_id(n: int, request_id: str, task_idx: int):
if n > 1:
return f"{request_id}-{task_idx}"
else:
return request_id
def _create_response(
request_id: str, generations: Union[dict, List[dict]], model_name: str = ""
):
@ -193,14 +200,17 @@ async def stream_generate_completion(
try:
logger.info(f"Received streaming completion request {request.state.id}")
for n in range(0, data.n):
for idx in range(0, data.n):
task_gen_params = data.model_copy(deep=True)
request_id = _parse_gen_request_id(
data.n, request.state.id, idx
)
gen_task = asyncio.create_task(
_stream_collector(
n,
idx,
gen_queue,
request.state.id,
request_id,
data.prompt,
task_gen_params,
abort_event,
@ -255,13 +265,16 @@ async def generate_completion(
try:
logger.info(f"Recieved completion request {request.state.id}")
for _ in range(0, data.n):
for idx in range(0, data.n):
task_gen_params = data.model_copy(deep=True)
request_id = _parse_gen_request_id(
data.n, request.state.id, idx
)
gen_tasks.append(
asyncio.create_task(
model.container.generate(
request.state.id,
request_id,
data.prompt,
task_gen_params,
)