API: Add ability to use request IDs
Identify which request is being processed to help users disambiguate which logs correspond to which request. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
38185a1ff4
commit
cae94b920c
6 changed files with 112 additions and 57 deletions
|
|
@ -7,6 +7,8 @@ from copy import deepcopy
|
|||
from fastapi import HTTPException, Request
|
||||
from typing import List, Union
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from common import model
|
||||
from common.networking import (
|
||||
get_generator_error,
|
||||
|
|
@ -24,7 +26,9 @@ from endpoints.OAI.types.completion import (
|
|||
from endpoints.OAI.types.common import UsageStats
|
||||
|
||||
|
||||
def _create_response(generations: Union[dict, List[dict]], model_name: str = ""):
|
||||
def _create_response(
|
||||
request_id: str, generations: Union[dict, List[dict]], model_name: str = ""
|
||||
):
|
||||
"""Create a completion response from the provided choices."""
|
||||
|
||||
# Convert the single choice object into a list
|
||||
|
|
@ -61,6 +65,7 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "")
|
|||
completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
|
||||
|
||||
response = CompletionResponse(
|
||||
id=f"cmpl-{request_id}",
|
||||
choices=choices,
|
||||
model=model_name,
|
||||
usage=UsageStats(
|
||||
|
|
@ -77,13 +82,16 @@ async def _stream_collector(
|
|||
task_idx: int,
|
||||
gen_queue: asyncio.Queue,
|
||||
prompt: str,
|
||||
request_id: str,
|
||||
abort_event: asyncio.Event,
|
||||
**kwargs,
|
||||
):
|
||||
"""Collects a stream and places results in a common queue"""
|
||||
|
||||
try:
|
||||
new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
|
||||
new_generation = model.container.generate_gen(
|
||||
prompt, request_id, abort_event, **kwargs
|
||||
)
|
||||
async for generation in new_generation:
|
||||
generation["index"] = task_idx
|
||||
|
||||
|
|
@ -106,6 +114,8 @@ async def stream_generate_completion(
|
|||
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
|
||||
|
||||
try:
|
||||
logger.info(f"Recieved streaming completion request {request.state.id}")
|
||||
|
||||
gen_params = data.to_gen_params()
|
||||
|
||||
for n in range(0, data.n):
|
||||
|
|
@ -116,7 +126,12 @@ async def stream_generate_completion(
|
|||
|
||||
gen_task = asyncio.create_task(
|
||||
_stream_collector(
|
||||
n, gen_queue, data.prompt, abort_event, **task_gen_params
|
||||
n,
|
||||
gen_queue,
|
||||
data.prompt,
|
||||
request.state.id,
|
||||
abort_event,
|
||||
**task_gen_params,
|
||||
)
|
||||
)
|
||||
|
||||
|
|
@ -126,7 +141,9 @@ async def stream_generate_completion(
|
|||
while True:
|
||||
if disconnect_task.done():
|
||||
abort_event.set()
|
||||
handle_request_disconnect("Completion generation cancelled by user.")
|
||||
handle_request_disconnect(
|
||||
f"Completion generation {request.state.id} cancelled by user."
|
||||
)
|
||||
|
||||
generation = await gen_queue.get()
|
||||
|
||||
|
|
@ -134,31 +151,38 @@ async def stream_generate_completion(
|
|||
if isinstance(generation, Exception):
|
||||
raise generation
|
||||
|
||||
response = _create_response(generation, model_path.name)
|
||||
response = _create_response(request.state.id, generation, model_path.name)
|
||||
yield response.model_dump_json()
|
||||
|
||||
# Check if all tasks are completed
|
||||
if all(task.done() for task in gen_tasks) and gen_queue.empty():
|
||||
yield "[DONE]"
|
||||
logger.info(f"Finished streaming completion request {request.state.id}")
|
||||
break
|
||||
except CancelledError:
|
||||
# Get out if the request gets disconnected
|
||||
|
||||
abort_event.set()
|
||||
handle_request_disconnect("Completion generation cancelled by user.")
|
||||
handle_request_disconnect(
|
||||
f"Completion generation {request.state.id} cancelled by user."
|
||||
)
|
||||
except Exception:
|
||||
yield get_generator_error(
|
||||
"Completion aborted. Please check the server console."
|
||||
f"Completion {request.state.id} aborted. Please check the server console."
|
||||
)
|
||||
|
||||
|
||||
async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
|
||||
async def generate_completion(
|
||||
data: CompletionRequest, request: Request, model_path: pathlib.Path
|
||||
):
|
||||
"""Non-streaming generate for completions"""
|
||||
|
||||
gen_tasks: List[asyncio.Task] = []
|
||||
gen_params = data.to_gen_params()
|
||||
|
||||
try:
|
||||
logger.info(f"Recieved completion request {request.state.id}")
|
||||
|
||||
for n in range(0, data.n):
|
||||
# Deepcopy gen params above the first index
|
||||
# to ensure nested structures aren't shared
|
||||
|
|
@ -169,17 +193,21 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
|
|||
|
||||
gen_tasks.append(
|
||||
asyncio.create_task(
|
||||
model.container.generate(data.prompt, **task_gen_params)
|
||||
model.container.generate(
|
||||
data.prompt, request.state.id, **task_gen_params
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
generations = await asyncio.gather(*gen_tasks)
|
||||
response = _create_response(generations, model_path.name)
|
||||
response = _create_response(request.state.id, generations, model_path.name)
|
||||
|
||||
logger.info(f"Finished completion request {request.state.id}")
|
||||
|
||||
return response
|
||||
except Exception as exc:
|
||||
error_message = handle_request_error(
|
||||
"Completion aborted. Maybe the model was unloaded? "
|
||||
f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
|
||||
"Please check the server console."
|
||||
).error.message
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue