API: Add ability to use request IDs

Identify which request is being processed to help users disambiguate
which logs correspond to which request.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-07-21 21:01:05 -04:00
parent 38185a1ff4
commit cae94b920c
6 changed files with 112 additions and 57 deletions
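At a glance, the whole change is one pattern: an HTTP middleware stamps each incoming request with uuid4().hex on request.state.id, and that ID is then passed down to the generator and quoted in every log message and response ID. A minimal, self-contained sketch of the pattern (simplified illustration only; the real wiring is in the diffs below, and the handler body here is assumed):

    # Minimal sketch of the request-ID pattern (simplified; the real project
    # threads the ID down into ExllamaV2Container.generate_gen).
    from uuid import uuid4

    from fastapi import FastAPI, Request
    from loguru import logger

    app = FastAPI()

    @app.middleware("http")
    async def add_request_id(request: Request, call_next):
        """Stamp every incoming request with an ID for log correlation."""
        request.state.id = uuid4().hex
        return await call_next(request)

    @app.post("/v1/completions")
    async def completion_request(request: Request):
        # Every log line produced while serving this request quotes the same ID.
        logger.info(f"Received completion request {request.state.id}")
        return {"id": f"cmpl-{request.state.id}"}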


@@ -828,10 +828,10 @@ class ExllamaV2Container:
return dict(zip_longest(top_tokens, cleaned_values))
async def generate(self, prompt: str, **kwargs):
async def generate(self, prompt: str, request_id: str, **kwargs):
"""Generate a response to a prompt"""
generations = []
async for generation in self.generate_gen(prompt, **kwargs):
async for generation in self.generate_gen(prompt, request_id, **kwargs):
generations.append(generation)
joined_generation = {
@@ -881,7 +881,11 @@ class ExllamaV2Container:
return kwargs
async def generate_gen(
self, prompt: str, abort_event: Optional[asyncio.Event] = None, **kwargs
self,
prompt: str,
request_id: str,
abort_event: Optional[asyncio.Event] = None,
**kwargs,
):
"""
Create generator function for prompt completion.
@@ -1116,6 +1120,7 @@ class ExllamaV2Container:
# Log generation options to console
# Some options are too large, so log the args instead
log_generation_params(
request_id=request_id,
max_tokens=max_tokens,
min_tokens=min_tokens,
stream=kwargs.get("stream"),
@@ -1138,9 +1143,10 @@ class ExllamaV2Container:
)
# Log prompt to console
log_prompt(prompt, negative_prompt)
log_prompt(prompt, request_id, negative_prompt)
# Create and add a new job
# Don't use the request ID here as there can be multiple jobs per request
job_id = uuid.uuid4().hex
job = ExLlamaV2DynamicJobAsync(
self.generator,

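Since request_id is now the second positional parameter of both generate and generate_gen, every call site has to pass it explicitly; a hypothetical caller under the new signature (the real call sites are in the endpoint utilities further down):

    # Hypothetical call site for the updated container API (sketch only).
    async def run_generation(container, prompt: str, request_id: str, **gen_params):
        # prompt first, then the per-request ID, then the usual generation kwargs.
        return await container.generate(prompt, request_id, **gen_params)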

@@ -51,11 +51,13 @@ def log_generation_params(**kwargs):
logger.info(f"Generation options: {kwargs}\n")
def log_prompt(prompt: str, negative_prompt: Optional[str]):
def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
"""Logs the prompt to console."""
if PREFERENCES.prompt:
formatted_prompt = "\n" + prompt
logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")
logger.info(
f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n"
)
if negative_prompt:
formatted_negative_prompt = "\n" + negative_prompt

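With the new log_prompt signature, each prompt log entry carries the request ID on its header line; an illustrative example (ID value made up, assuming prompt logging is enabled in PREFERENCES):

    Prompt (ID: 9f0c1d2e5a7b4c3d8e6f0a1b2c3d4e5f):
    Once upon a time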

@@ -107,12 +107,14 @@ async def completion_request(
ping=maxsize,
)
else:
generate_task = asyncio.create_task(generate_completion(data, model_path))
generate_task = asyncio.create_task(
generate_completion(data, request, model_path)
)
response = await run_with_request_disconnect(
request,
generate_task,
disconnect_message="Completion generation cancelled by user.",
disconnect_message=f"Completion {request.state.id} cancelled by user.",
)
return response
@@ -161,13 +163,13 @@ async def chat_completion_request(
)
else:
generate_task = asyncio.create_task(
generate_chat_completion(prompt, data, model_path)
generate_chat_completion(prompt, data, request, model_path)
)
response = await run_with_request_disconnect(
request,
generate_task,
disconnect_message="Chat completion generation cancelled by user.",
disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
)
return response


@@ -5,7 +5,6 @@ import pathlib
from asyncio import CancelledError
from copy import deepcopy
from typing import List, Optional
from uuid import uuid4
from fastapi import HTTPException, Request
from jinja2 import TemplateError
@@ -30,9 +29,12 @@ from endpoints.OAI.types.chat_completion import (
ChatCompletionStreamChoice,
)
from endpoints.OAI.types.common import UsageStats
from endpoints.OAI.utils.completion import _stream_collector
def _create_response(generations: List[dict], model_name: Optional[str]):
def _create_response(
request_id: str, generations: List[dict], model_name: Optional[str]
):
"""Create a chat completion response from the provided text."""
prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
@@ -77,6 +79,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
choices.append(choice)
response = ChatCompletionResponse(
id=f"chatcmpl-{request_id}",
choices=choices,
model=unwrap(model_name, ""),
usage=UsageStats(
@@ -90,7 +93,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
def _create_stream_chunk(
const_id: str,
request_id: str,
generation: Optional[dict] = None,
model_name: Optional[str] = None,
is_usage_chunk: bool = False,
@@ -150,7 +153,7 @@ def _create_stream_chunk(
choices.append(choice)
chunk = ChatCompletionStreamChunk(
id=const_id,
id=f"chatcmpl-{request_id}",
choices=choices,
model=unwrap(model_name, ""),
usage=usage_stats,
@@ -235,39 +238,18 @@ def format_prompt_with_template(data: ChatCompletionRequest):
raise HTTPException(400, error_message) from exc
async def _stream_collector(
task_idx: int,
gen_queue: asyncio.Queue,
prompt: str,
abort_event: asyncio.Event,
**kwargs,
):
"""Collects a stream and places results in a common queue"""
try:
new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
async for generation in new_generation:
generation["index"] = task_idx
await gen_queue.put(generation)
if "finish_reason" in generation:
break
except Exception as e:
await gen_queue.put(e)
async def stream_generate_chat_completion(
prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
):
"""Generator for the generation process."""
const_id = f"chatcmpl-{uuid4().hex}"
abort_event = asyncio.Event()
gen_queue = asyncio.Queue()
gen_tasks: List[asyncio.Task] = []
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
try:
logger.info(f"Recieved chat completion streaming request {request.state.id}")
gen_params = data.to_gen_params()
for n in range(0, data.n):
@@ -277,7 +259,14 @@ async def stream_generate_chat_completion(
task_gen_params = gen_params
gen_task = asyncio.create_task(
_stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params)
_stream_collector(
n,
gen_queue,
prompt,
request.state.id,
abort_event,
**task_gen_params,
)
)
gen_tasks.append(gen_task)
@@ -286,7 +275,9 @@ async def stream_generate_chat_completion(
while True:
if disconnect_task.done():
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")
handle_request_disconnect(
f"Chat completion generation {request.state.id} cancelled by user."
)
generation = await gen_queue.get()
@@ -294,7 +285,9 @@ async def stream_generate_chat_completion(
if isinstance(generation, Exception):
raise generation
response = _create_stream_chunk(const_id, generation, model_path.name)
response = _create_stream_chunk(
request.state.id, generation, model_path.name
)
yield response.model_dump_json()
# Check if all tasks are completed
@@ -302,10 +295,17 @@ async def stream_generate_chat_completion(
# Send a usage chunk
if data.stream_options and data.stream_options.include_usage:
usage_chunk = _create_stream_chunk(
const_id, generation, model_path.name, is_usage_chunk=True
request.state.id,
generation,
model_path.name,
is_usage_chunk=True,
)
yield usage_chunk.model_dump_json()
logger.info(
f"Finished chat completion streaming request {request.state.id}"
)
yield "[DONE]"
break
except CancelledError:
@@ -320,7 +320,7 @@ async def stream_generate_chat_completion(
async def generate_chat_completion(
prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
):
gen_tasks: List[asyncio.Task] = []
gen_params = data.to_gen_params()
@@ -335,16 +335,23 @@ async def generate_chat_completion(
task_gen_params = gen_params
gen_tasks.append(
asyncio.create_task(model.container.generate(prompt, **task_gen_params))
asyncio.create_task(
model.container.generate(
prompt, request.state.id, **task_gen_params
)
)
)
generations = await asyncio.gather(*gen_tasks)
response = _create_response(generations, model_path.name)
response = _create_response(request.state.id, generations, model_path.name)
logger.info(f"Finished chat completion request {request.state.id}")
return response
except Exception as exc:
error_message = handle_request_error(
"Chat completion aborted. Maybe the model was unloaded? "
f"Chat completion {request.state.id} aborted. "
"Maybe the model was unloaded? "
"Please check the server console."
).error.message

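A side effect worth noting: because _create_response and _create_stream_chunk now take the request ID, the id field of the chat completion response becomes chatcmpl-<request id>, so the value a client receives in the API response matches the ID printed in the server logs. Illustrative fragment (field values made up):

    {"id": "chatcmpl-9f0c1d2e5a7b4c3d8e6f0a1b2c3d4e5f", "model": "my-model", ...}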

@@ -7,6 +7,8 @@ from copy import deepcopy
from fastapi import HTTPException, Request
from typing import List, Union
from loguru import logger
from common import model
from common.networking import (
get_generator_error,
@@ -24,7 +26,9 @@ from endpoints.OAI.types.completion import (
from endpoints.OAI.types.common import UsageStats
def _create_response(generations: Union[dict, List[dict]], model_name: str = ""):
def _create_response(
request_id: str, generations: Union[dict, List[dict]], model_name: str = ""
):
"""Create a completion response from the provided choices."""
# Convert the single choice object into a list
@@ -61,6 +65,7 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "")
completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
response = CompletionResponse(
id=f"cmpl-{request_id}",
choices=choices,
model=model_name,
usage=UsageStats(
@@ -77,13 +82,16 @@ async def _stream_collector(
task_idx: int,
gen_queue: asyncio.Queue,
prompt: str,
request_id: str,
abort_event: asyncio.Event,
**kwargs,
):
"""Collects a stream and places results in a common queue"""
try:
new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
new_generation = model.container.generate_gen(
prompt, request_id, abort_event, **kwargs
)
async for generation in new_generation:
generation["index"] = task_idx
@@ -106,6 +114,8 @@ async def stream_generate_completion(
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
try:
logger.info(f"Recieved streaming completion request {request.state.id}")
gen_params = data.to_gen_params()
for n in range(0, data.n):
@@ -116,7 +126,12 @@
gen_task = asyncio.create_task(
_stream_collector(
n, gen_queue, data.prompt, abort_event, **task_gen_params
n,
gen_queue,
data.prompt,
request.state.id,
abort_event,
**task_gen_params,
)
)
@@ -126,7 +141,9 @@
while True:
if disconnect_task.done():
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")
handle_request_disconnect(
f"Completion generation {request.state.id} cancelled by user."
)
generation = await gen_queue.get()
@@ -134,31 +151,38 @@
if isinstance(generation, Exception):
raise generation
response = _create_response(generation, model_path.name)
response = _create_response(request.state.id, generation, model_path.name)
yield response.model_dump_json()
# Check if all tasks are completed
if all(task.done() for task in gen_tasks) and gen_queue.empty():
yield "[DONE]"
logger.info(f"Finished streaming completion request {request.state.id}")
break
except CancelledError:
# Get out if the request gets disconnected
abort_event.set()
handle_request_disconnect("Completion generation cancelled by user.")
handle_request_disconnect(
f"Completion generation {request.state.id} cancelled by user."
)
except Exception:
yield get_generator_error(
"Completion aborted. Please check the server console."
f"Completion {request.state.id} aborted. Please check the server console."
)
async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
async def generate_completion(
data: CompletionRequest, request: Request, model_path: pathlib.Path
):
"""Non-streaming generate for completions"""
gen_tasks: List[asyncio.Task] = []
gen_params = data.to_gen_params()
try:
logger.info(f"Recieved completion request {request.state.id}")
for n in range(0, data.n):
# Deepcopy gen params above the first index
# to ensure nested structures aren't shared
@@ -169,17 +193,21 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
gen_tasks.append(
asyncio.create_task(
model.container.generate(data.prompt, **task_gen_params)
model.container.generate(
data.prompt, request.state.id, **task_gen_params
)
)
)
generations = await asyncio.gather(*gen_tasks)
response = _create_response(generations, model_path.name)
response = _create_response(request.state.id, generations, model_path.name)
logger.info(f"Finished completion request {request.state.id}")
return response
except Exception as exc:
error_message = handle_request_error(
"Completion aborted. Maybe the model was unloaded? "
f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
"Please check the server console."
).error.message


@@ -1,5 +1,6 @@
from uuid import uuid4
import uvicorn
from fastapi import FastAPI
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
@@ -25,6 +26,15 @@ app.add_middleware(
)
@app.middleware("http")
async def add_request_id(request: Request, call_next):
"""Middleware to append an ID to a request"""
request.state.id = uuid4().hex
response = await call_next(request)
return response
def setup_app():
"""Includes the correct routers for startup"""