API: Append task index to generations with n > 1
Since jobs are tracked via request IDs now, each generation task should be uniquely identified in the event of cancellation. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
b43f0983c8
commit
9157be3e34
2 changed files with 36 additions and 12 deletions
|
|
@ -29,7 +29,7 @@ from endpoints.OAI.types.chat_completion import (
|
|||
ChatCompletionStreamChoice,
|
||||
)
|
||||
from endpoints.OAI.types.common import UsageStats
|
||||
from endpoints.OAI.utils.completion import _stream_collector
|
||||
from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector
|
||||
from endpoints.OAI.utils.tools import ToolCallProcessor
|
||||
|
||||
|
||||
|
|
@ -326,14 +326,17 @@ async def stream_generate_chat_completion(
|
|||
try:
|
||||
logger.info(f"Received chat completion streaming request {request.state.id}")
|
||||
|
||||
for n in range(0, data.n):
|
||||
for idx in range(0, data.n):
|
||||
task_gen_params = data.model_copy(deep=True)
|
||||
request_id = _parse_gen_request_id(
|
||||
data.n, request.state.id, idx
|
||||
)
|
||||
|
||||
gen_task = asyncio.create_task(
|
||||
_stream_collector(
|
||||
n,
|
||||
idx,
|
||||
gen_queue,
|
||||
request.state.id,
|
||||
request_id,
|
||||
prompt,
|
||||
task_gen_params,
|
||||
abort_event,
|
||||
|
|
@ -418,11 +421,15 @@ async def generate_chat_completion(
|
|||
gen_tasks: List[asyncio.Task] = []
|
||||
|
||||
try:
|
||||
for _ in range(0, data.n):
|
||||
for idx in range(0, data.n):
|
||||
request_id = _parse_gen_request_id(
|
||||
data.n, request.state.id, idx
|
||||
)
|
||||
|
||||
gen_tasks.append(
|
||||
asyncio.create_task(
|
||||
model.container.generate(
|
||||
request.state.id,
|
||||
request_id,
|
||||
prompt,
|
||||
data,
|
||||
mm_embeddings=embeddings,
|
||||
|
|
@ -484,10 +491,14 @@ async def generate_tool_calls(
|
|||
data, current_generations
|
||||
)
|
||||
|
||||
request_id = _parse_gen_request_id(
|
||||
data.n, request.state.id, idx
|
||||
)
|
||||
|
||||
gen_tasks.append(
|
||||
asyncio.create_task(
|
||||
model.container.generate(
|
||||
request.state.id,
|
||||
request_id,
|
||||
pre_tool_prompt,
|
||||
tool_data,
|
||||
embeddings=mm_embeddings,
|
||||
|
|
|
|||
|
|
@ -31,6 +31,13 @@ from endpoints.OAI.types.completion import (
|
|||
from endpoints.OAI.types.common import UsageStats
|
||||
|
||||
|
||||
def _parse_gen_request_id(n: int, request_id: str, task_idx: int):
|
||||
if n > 1:
|
||||
return f"{request_id}-{task_idx}"
|
||||
else:
|
||||
return request_id
|
||||
|
||||
|
||||
def _create_response(
|
||||
request_id: str, generations: Union[dict, List[dict]], model_name: str = ""
|
||||
):
|
||||
|
|
@ -193,14 +200,17 @@ async def stream_generate_completion(
|
|||
try:
|
||||
logger.info(f"Received streaming completion request {request.state.id}")
|
||||
|
||||
for n in range(0, data.n):
|
||||
for idx in range(0, data.n):
|
||||
task_gen_params = data.model_copy(deep=True)
|
||||
request_id = _parse_gen_request_id(
|
||||
data.n, request.state.id, idx
|
||||
)
|
||||
|
||||
gen_task = asyncio.create_task(
|
||||
_stream_collector(
|
||||
n,
|
||||
idx,
|
||||
gen_queue,
|
||||
request.state.id,
|
||||
request_id,
|
||||
data.prompt,
|
||||
task_gen_params,
|
||||
abort_event,
|
||||
|
|
@ -255,13 +265,16 @@ async def generate_completion(
|
|||
try:
|
||||
logger.info(f"Received completion request {request.state.id}")
|
||||
|
||||
for _ in range(0, data.n):
|
||||
for idx in range(0, data.n):
|
||||
task_gen_params = data.model_copy(deep=True)
|
||||
request_id = _parse_gen_request_id(
|
||||
data.n, request.state.id, idx
|
||||
)
|
||||
|
||||
gen_tasks.append(
|
||||
asyncio.create_task(
|
||||
model.container.generate(
|
||||
request.state.id,
|
||||
request_id,
|
||||
data.prompt,
|
||||
task_gen_params,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue