API: Add ability to use request IDs

Identify which request is being processed to help users disambiguate
which logs correspond to which request.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-07-21 21:01:05 -04:00
parent 38185a1ff4
commit cae94b920c
6 changed files with 112 additions and 57 deletions

File 1 of 6: ExllamaV2Container (generation backend)

@@ -828,10 +828,10 @@ class ExllamaV2Container:
         return dict(zip_longest(top_tokens, cleaned_values))

-    async def generate(self, prompt: str, **kwargs):
+    async def generate(self, prompt: str, request_id: str, **kwargs):
         """Generate a response to a prompt"""
         generations = []
-        async for generation in self.generate_gen(prompt, **kwargs):
+        async for generation in self.generate_gen(prompt, request_id, **kwargs):
             generations.append(generation)

         joined_generation = {
@@ -881,7 +881,11 @@ class ExllamaV2Container:
         return kwargs

     async def generate_gen(
-        self, prompt: str, abort_event: Optional[asyncio.Event] = None, **kwargs
+        self,
+        prompt: str,
+        request_id: str,
+        abort_event: Optional[asyncio.Event] = None,
+        **kwargs,
     ):
         """
         Create generator function for prompt completion.
@@ -1116,6 +1120,7 @@ class ExllamaV2Container:
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
+            request_id=request_id,
             max_tokens=max_tokens,
             min_tokens=min_tokens,
             stream=kwargs.get("stream"),
@@ -1138,9 +1143,10 @@ class ExllamaV2Container:
         )

         # Log prompt to console
-        log_prompt(prompt, negative_prompt)
+        log_prompt(prompt, request_id, negative_prompt)

         # Create and add a new job
+        # Don't use the request ID here as there can be multiple jobs per request
         job_id = uuid.uuid4().hex
         job = ExLlamaV2DynamicJobAsync(
             self.generator,
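
The container's generate() and generate_gen() now take the request ID as their second positional parameter, ahead of any keyword sampling options, so every caller has to be updated. A minimal, self-contained sketch of the new call shape (DummyContainer is a hypothetical stand-in for ExllamaV2Container; only the argument order is the point here):

    import asyncio


    class DummyContainer:
        async def generate(self, prompt: str, request_id: str, **kwargs):
            # The request ID now rides along with every generation call
            return {"text": prompt.upper(), "request_id": request_id, **kwargs}


    async def main():
        container = DummyContainer()
        result = await container.generate("hello", "deadbeef", max_tokens=8)
        print(result)  # {'text': 'HELLO', 'request_id': 'deadbeef', 'max_tokens': 8}


    asyncio.run(main())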

File 2 of 6: generation logging helpers (log_generation_params, log_prompt)

@@ -51,11 +51,13 @@ def log_generation_params(**kwargs):
         logger.info(f"Generation options: {kwargs}\n")


-def log_prompt(prompt: str, negative_prompt: Optional[str]):
+def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
     """Logs the prompt to console."""
     if PREFERENCES.prompt:
         formatted_prompt = "\n" + prompt
-        logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")
+        logger.info(
+            f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n"
+        )

     if negative_prompt:
         formatted_negative_prompt = "\n" + negative_prompt
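
log_prompt gains request_id as its second positional parameter, so the logged prompt can be matched to the request that produced it. A hedged, self-contained sketch of the new shape, with loguru and the PREFERENCES toggle replaced by plain printing (and a default added to negative_prompt) purely for illustration:

    from typing import Optional


    def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str] = None):
        # request_id now sits between prompt and negative_prompt, so positional
        # callers must pass it explicitly
        formatted_prompt = "\n" + prompt if prompt else "Empty"
        print(f"Prompt (ID: {request_id}): {formatted_prompt}\n")
        if negative_prompt:
            print(f"Negative prompt: {negative_prompt}\n")


    log_prompt("Tell me a joke", "deadbeef")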

File 3 of 6: OAI router (completion and chat completion endpoints)

@@ -107,12 +107,14 @@ async def completion_request(
             ping=maxsize,
         )
     else:
-        generate_task = asyncio.create_task(generate_completion(data, model_path))
+        generate_task = asyncio.create_task(
+            generate_completion(data, request, model_path)
+        )

         response = await run_with_request_disconnect(
             request,
             generate_task,
-            disconnect_message="Completion generation cancelled by user.",
+            disconnect_message=f"Completion {request.state.id} cancelled by user.",
         )
         return response
@@ -161,13 +163,13 @@ async def chat_completion_request(
         )
     else:
         generate_task = asyncio.create_task(
-            generate_chat_completion(prompt, data, model_path)
+            generate_chat_completion(prompt, data, request, model_path)
         )

         response = await run_with_request_disconnect(
             request,
             generate_task,
-            disconnect_message="Chat completion generation cancelled by user.",
+            disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
         )
         return response
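
The non-streaming branches of both endpoints now pass the Request object down and tag the disconnect message with request.state.id. A self-contained, hedged sketch of the cancel-on-disconnect idea behind run_with_request_disconnect (the real helper lives in the project's networking module and differs in detail; client disconnection is simulated here with an asyncio.Event):

    import asyncio


    async def run_with_disconnect(disconnected: asyncio.Event, task: asyncio.Task, message: str):
        # Wait for either the generation task or the (simulated) client disconnect
        watcher = asyncio.create_task(disconnected.wait())
        done, _ = await asyncio.wait({task, watcher}, return_when=asyncio.FIRST_COMPLETED)
        if watcher in done and task not in done:
            task.cancel()
            print(message)  # e.g. f"Completion {request_id} cancelled by user."
            return None
        watcher.cancel()
        return task.result()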

File 4 of 6: chat completion utilities

@@ -5,7 +5,6 @@ import pathlib
 from asyncio import CancelledError
 from copy import deepcopy
 from typing import List, Optional
-from uuid import uuid4

 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
@@ -30,9 +29,12 @@ from endpoints.OAI.types.chat_completion import (
     ChatCompletionStreamChoice,
 )
 from endpoints.OAI.types.common import UsageStats
+from endpoints.OAI.utils.completion import _stream_collector


-def _create_response(generations: List[dict], model_name: Optional[str]):
+def _create_response(
+    request_id: str, generations: List[dict], model_name: Optional[str]
+):
     """Create a chat completion response from the provided text."""

     prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
@@ -77,6 +79,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
         choices.append(choice)

     response = ChatCompletionResponse(
+        id=f"chatcmpl-{request_id}",
         choices=choices,
         model=unwrap(model_name, ""),
         usage=UsageStats(
@@ -90,7 +93,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]):

 def _create_stream_chunk(
-    const_id: str,
+    request_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
     is_usage_chunk: bool = False,
@@ -150,7 +153,7 @@ def _create_stream_chunk(
         choices.append(choice)

     chunk = ChatCompletionStreamChunk(
-        id=const_id,
+        id=f"chatcmpl-{request_id}",
         choices=choices,
         model=unwrap(model_name, ""),
         usage=usage_stats,
@@ -235,39 +238,18 @@ def format_prompt_with_template(data: ChatCompletionRequest):
         raise HTTPException(400, error_message) from exc


-async def _stream_collector(
-    task_idx: int,
-    gen_queue: asyncio.Queue,
-    prompt: str,
-    abort_event: asyncio.Event,
-    **kwargs,
-):
-    """Collects a stream and places results in a common queue"""
-
-    try:
-        new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
-        async for generation in new_generation:
-            generation["index"] = task_idx
-
-            await gen_queue.put(generation)
-
-            if "finish_reason" in generation:
-                break
-    except Exception as e:
-        await gen_queue.put(e)
-
-
 async def stream_generate_chat_completion(
     prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     """Generator for the generation process."""

-    const_id = f"chatcmpl-{uuid4().hex}"
     abort_event = asyncio.Event()
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))

     try:
+        logger.info(f"Recieved chat completion streaming request {request.state.id}")
+
         gen_params = data.to_gen_params()

         for n in range(0, data.n):
@@ -277,7 +259,14 @@ async def stream_generate_chat_completion(
             task_gen_params = gen_params

             gen_task = asyncio.create_task(
-                _stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params)
+                _stream_collector(
+                    n,
+                    gen_queue,
+                    prompt,
+                    request.state.id,
+                    abort_event,
+                    **task_gen_params,
+                )
             )

             gen_tasks.append(gen_task)
@@ -286,7 +275,9 @@ async def stream_generate_chat_completion(
         while True:
             if disconnect_task.done():
                 abort_event.set()
-                handle_request_disconnect("Completion generation cancelled by user.")
+                handle_request_disconnect(
+                    f"Chat completion generation {request.state.id} cancelled by user."
+                )

             generation = await gen_queue.get()
@@ -294,7 +285,9 @@ async def stream_generate_chat_completion(
             if isinstance(generation, Exception):
                 raise generation

-            response = _create_stream_chunk(const_id, generation, model_path.name)
+            response = _create_stream_chunk(
+                request.state.id, generation, model_path.name
+            )
             yield response.model_dump_json()

             # Check if all tasks are completed
@@ -302,10 +295,17 @@ async def stream_generate_chat_completion(
                 # Send a usage chunk
                 if data.stream_options and data.stream_options.include_usage:
                     usage_chunk = _create_stream_chunk(
-                        const_id, generation, model_path.name, is_usage_chunk=True
+                        request.state.id,
+                        generation,
+                        model_path.name,
+                        is_usage_chunk=True,
                     )
                     yield usage_chunk.model_dump_json()

+                logger.info(
+                    f"Finished chat completion streaming request {request.state.id}"
+                )
+
                 yield "[DONE]"
                 break
     except CancelledError:
except CancelledError: except CancelledError:
@@ -320,7 +320,7 @@ async def stream_generate_chat_completion(

 async def generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     gen_tasks: List[asyncio.Task] = []
     gen_params = data.to_gen_params()
@@ -335,16 +335,23 @@ async def generate_chat_completion(
             task_gen_params = gen_params

             gen_tasks.append(
-                asyncio.create_task(model.container.generate(prompt, **task_gen_params))
+                asyncio.create_task(
+                    model.container.generate(
+                        prompt, request.state.id, **task_gen_params
+                    )
+                )
             )

         generations = await asyncio.gather(*gen_tasks)
-        response = _create_response(generations, model_path.name)
+        response = _create_response(request.state.id, generations, model_path.name)
+
+        logger.info(f"Finished chat completion request {request.state.id}")

         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Chat completion aborted. Maybe the model was unloaded? "
+            f"Chat completion {request.state.id} aborted. "
+            "Maybe the model was unloaded? "
             "Please check the server console."
         ).error.message
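
With the per-request ID available, the chat completion response and every streamed chunk reuse it for their OpenAI-style id fields instead of minting a fresh UUID per stream, so API payloads, console logs, and disconnect messages all point at the same identifier. A small illustrative sketch of the derivation (in practice the middleware supplies request_id):

    from uuid import uuid4

    request_id = uuid4().hex  # assigned once per HTTP request by the middleware
    chat_response_id = f"chatcmpl-{request_id}"  # final response and stream chunks
    completion_id = f"cmpl-{request_id}"         # plain completion responses
    print(chat_response_id, completion_id)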

File 5 of 6: completion utilities

@@ -7,6 +7,8 @@ from copy import deepcopy
 from fastapi import HTTPException, Request
 from typing import List, Union

+from loguru import logger
+
 from common import model
 from common.networking import (
     get_generator_error,
@@ -24,7 +26,9 @@ from endpoints.OAI.types.completion import (
 from endpoints.OAI.types.common import UsageStats


-def _create_response(generations: Union[dict, List[dict]], model_name: str = ""):
+def _create_response(
+    request_id: str, generations: Union[dict, List[dict]], model_name: str = ""
+):
     """Create a completion response from the provided choices."""

     # Convert the single choice object into a list
@@ -61,6 +65,7 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "")
     completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)

     response = CompletionResponse(
+        id=f"cmpl-{request_id}",
         choices=choices,
         model=model_name,
         usage=UsageStats(
@@ -77,13 +82,16 @@ async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
     prompt: str,
+    request_id: str,
     abort_event: asyncio.Event,
     **kwargs,
 ):
     """Collects a stream and places results in a common queue"""

     try:
-        new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
+        new_generation = model.container.generate_gen(
+            prompt, request_id, abort_event, **kwargs
+        )
         async for generation in new_generation:
             generation["index"] = task_idx
@@ -106,6 +114,8 @@ async def stream_generate_completion(
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))

     try:
+        logger.info(f"Recieved streaming completion request {request.state.id}")
+
         gen_params = data.to_gen_params()

         for n in range(0, data.n):
@@ -116,7 +126,12 @@ async def stream_generate_completion(

             gen_task = asyncio.create_task(
                 _stream_collector(
-                    n, gen_queue, data.prompt, abort_event, **task_gen_params
+                    n,
+                    gen_queue,
+                    data.prompt,
+                    request.state.id,
+                    abort_event,
+                    **task_gen_params,
                 )
             )
@@ -126,7 +141,9 @@ async def stream_generate_completion(
         while True:
             if disconnect_task.done():
                 abort_event.set()
-                handle_request_disconnect("Completion generation cancelled by user.")
+                handle_request_disconnect(
+                    f"Completion generation {request.state.id} cancelled by user."
+                )

             generation = await gen_queue.get()
@@ -134,31 +151,38 @@ async def stream_generate_completion(
             if isinstance(generation, Exception):
                 raise generation

-            response = _create_response(generation, model_path.name)
+            response = _create_response(request.state.id, generation, model_path.name)
             yield response.model_dump_json()

             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
                 yield "[DONE]"
+                logger.info(f"Finished streaming completion request {request.state.id}")
                 break
     except CancelledError:
         # Get out if the request gets disconnected
         abort_event.set()
-        handle_request_disconnect("Completion generation cancelled by user.")
+        handle_request_disconnect(
+            f"Completion generation {request.state.id} cancelled by user."
+        )
     except Exception:
         yield get_generator_error(
-            "Completion aborted. Please check the server console."
+            f"Completion {request.state.id} aborted. Please check the server console."
         )


-async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
+async def generate_completion(
+    data: CompletionRequest, request: Request, model_path: pathlib.Path
+):
     """Non-streaming generate for completions"""

     gen_tasks: List[asyncio.Task] = []
     gen_params = data.to_gen_params()

     try:
+        logger.info(f"Recieved completion request {request.state.id}")
+
         for n in range(0, data.n):
             # Deepcopy gen params above the first index
             # to ensure nested structures aren't shared
@@ -169,17 +193,21 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
             gen_tasks.append(
                 asyncio.create_task(
-                    model.container.generate(data.prompt, **task_gen_params)
+                    model.container.generate(
+                        data.prompt, request.state.id, **task_gen_params
+                    )
                 )
             )

         generations = await asyncio.gather(*gen_tasks)
-        response = _create_response(generations, model_path.name)
+        response = _create_response(request.state.id, generations, model_path.name)
+
+        logger.info(f"Finished completion request {request.state.id}")

         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Completion aborted. Maybe the model was unloaded? "
+            f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
             "Please check the server console."
         ).error.message
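
_stream_collector now lives only in the completion utilities (the chat module imports it), carries the request ID through to generate_gen, and fans each of the n parallel generations into one shared queue tagged with its task index. A self-contained, hedged sketch of that fan-in pattern, with a fake generator standing in for the model container:

    import asyncio


    async def fake_generate_gen(prompt: str, request_id: str):
        # Stand-in for model.container.generate_gen
        for token in prompt.split():
            yield {"text": token, "request_id": request_id}


    async def stream_collector(task_idx: int, queue: asyncio.Queue, prompt: str, request_id: str):
        try:
            async for chunk in fake_generate_gen(prompt, request_id):
                chunk["index"] = task_idx
                await queue.put(chunk)
        except Exception as exc:
            # Surface errors through the same queue, as the real collector does
            await queue.put(exc)


    async def main():
        queue: asyncio.Queue = asyncio.Queue()
        tasks = [
            asyncio.create_task(stream_collector(n, queue, "hello stream world", "deadbeef"))
            for n in range(2)
        ]
        await asyncio.gather(*tasks)
        while not queue.empty():
            print(queue.get_nowait())


    asyncio.run(main())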

File 6 of 6: FastAPI application setup

@@ -1,5 +1,6 @@
+from uuid import uuid4
 import uvicorn
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from loguru import logger
@@ -25,6 +26,15 @@ app.add_middleware(
 )


+@app.middleware("http")
+async def add_request_id(request: Request, call_next):
+    """Middleware to append an ID to a request"""
+
+    request.state.id = uuid4().hex
+    response = await call_next(request)
+    return response
+
+
 def setup_app():
     """Includes the correct routers for startup"""