API: Add ability to use request IDs
Identify which request is being processed to help users disambiguate which logs correspond to which request.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 38185a1ff4
commit cae94b920c

6 changed files with 112 additions and 57 deletions
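At a high level, the change tags every incoming HTTP request with a unique ID in a FastAPI middleware, then threads that ID through the generation call chain so it appears in the console logs and in the OpenAI-style response `id` fields. The sketch below illustrates the pattern on its own; it is not part of this commit's diff, and the route and log text are illustrative only.

    # Minimal sketch of the request-ID pattern (illustrative, not the project's code)
    from uuid import uuid4

    from fastapi import FastAPI, Request
    from loguru import logger

    app = FastAPI()


    @app.middleware("http")
    async def add_request_id(request: Request, call_next):
        # Stamp each request with an ID before it reaches a handler
        request.state.id = uuid4().hex
        return await call_next(request)


    @app.post("/v1/completions")
    async def completion_request(request: Request):
        # Downstream code can log the ID to disambiguate concurrent requests
        logger.info(f"Handling completion request {request.state.id}")
        return {"id": f"cmpl-{request.state.id}"}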
@@ -828,10 +828,10 @@ class ExllamaV2Container:
         return dict(zip_longest(top_tokens, cleaned_values))
 
-    async def generate(self, prompt: str, **kwargs):
+    async def generate(self, prompt: str, request_id: str, **kwargs):
         """Generate a response to a prompt"""
         generations = []
-        async for generation in self.generate_gen(prompt, **kwargs):
+        async for generation in self.generate_gen(prompt, request_id, **kwargs):
             generations.append(generation)
 
         joined_generation = {

@@ -881,7 +881,11 @@ class ExllamaV2Container:
         return kwargs
 
     async def generate_gen(
-        self, prompt: str, abort_event: Optional[asyncio.Event] = None, **kwargs
+        self,
+        prompt: str,
+        request_id: str,
+        abort_event: Optional[asyncio.Event] = None,
+        **kwargs,
     ):
         """
         Create generator function for prompt completion.

@@ -1116,6 +1120,7 @@ class ExllamaV2Container:
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
+            request_id=request_id,
             max_tokens=max_tokens,
             min_tokens=min_tokens,
             stream=kwargs.get("stream"),

@@ -1138,9 +1143,10 @@ class ExllamaV2Container:
         )
 
         # Log prompt to console
-        log_prompt(prompt, negative_prompt)
+        log_prompt(prompt, request_id, negative_prompt)
 
         # Create and add a new job
+        # Don't use the request ID here as there can be multiple jobs per request
         job_id = uuid.uuid4().hex
         job = ExLlamaV2DynamicJobAsync(
             self.generator,
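One detail in the last hunk is worth a note: a single API request can fan out into several generator jobs (for example when more than one completion is requested), so each job keeps its own UUID while the request ID identifies the call as a whole. A rough, self-contained sketch of that relationship; the names are illustrative and nothing here touches ExLlamaV2 itself.

    import uuid

    request_id = uuid.uuid4().hex  # one ID per API request, assigned by the server
    job_ids = [uuid.uuid4().hex for _ in range(2)]  # one ID per generation job (two jobs here)

    for job_id in job_ids:
        # Every job can be traced back to the same originating request in the logs
        print(f"request={request_id} job={job_id}")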
@@ -51,11 +51,13 @@ def log_generation_params(**kwargs):
         logger.info(f"Generation options: {kwargs}\n")
 
 
-def log_prompt(prompt: str, negative_prompt: Optional[str]):
+def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
     """Logs the prompt to console."""
     if PREFERENCES.prompt:
         formatted_prompt = "\n" + prompt
-        logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")
+        logger.info(
+            f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n"
+        )
 
     if negative_prompt:
         formatted_negative_prompt = "\n" + negative_prompt
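With the new signature the request ID is passed between the prompt and the negative prompt and interpolated into the log line. A trimmed, self-contained imitation of `log_prompt` for trying the format locally; the real function also honors the `PREFERENCES.prompt` toggle and logs negative prompts, and the ID shown is made up.

    from loguru import logger


    def log_prompt_demo(prompt: str, request_id: str) -> None:
        # Mirrors the formatting introduced above, minus the preference checks
        formatted_prompt = "\n" + prompt
        logger.info(f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n")


    log_prompt_demo("Tell me a joke", "3fb0f7e0c0d84a2a")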
@@ -107,12 +107,14 @@ async def completion_request(
             ping=maxsize,
         )
     else:
-        generate_task = asyncio.create_task(generate_completion(data, model_path))
+        generate_task = asyncio.create_task(
+            generate_completion(data, request, model_path)
+        )
 
         response = await run_with_request_disconnect(
             request,
             generate_task,
-            disconnect_message="Completion generation cancelled by user.",
+            disconnect_message=f"Completion {request.state.id} cancelled by user.",
         )
         return response
 

@@ -161,13 +163,13 @@ async def chat_completion_request(
         )
     else:
         generate_task = asyncio.create_task(
-            generate_chat_completion(prompt, data, model_path)
+            generate_chat_completion(prompt, data, request, model_path)
         )
 
         response = await run_with_request_disconnect(
             request,
             generate_task,
-            disconnect_message="Chat completion generation cancelled by user.",
+            disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
         )
         return response
 
@@ -5,7 +5,6 @@ import pathlib
 from asyncio import CancelledError
 from copy import deepcopy
 from typing import List, Optional
-from uuid import uuid4
 
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError

@@ -30,9 +29,12 @@ from endpoints.OAI.types.chat_completion import (
     ChatCompletionStreamChoice,
 )
 from endpoints.OAI.types.common import UsageStats
+from endpoints.OAI.utils.completion import _stream_collector
 
 
-def _create_response(generations: List[dict], model_name: Optional[str]):
+def _create_response(
+    request_id: str, generations: List[dict], model_name: Optional[str]
+):
     """Create a chat completion response from the provided text."""
 
     prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)

@@ -77,6 +79,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
         choices.append(choice)
 
     response = ChatCompletionResponse(
+        id=f"chatcmpl-{request_id}",
         choices=choices,
         model=unwrap(model_name, ""),
         usage=UsageStats(

@@ -90,7 +93,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
 
 
 def _create_stream_chunk(
-    const_id: str,
+    request_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
     is_usage_chunk: bool = False,

@@ -150,7 +153,7 @@ def _create_stream_chunk(
         choices.append(choice)
 
     chunk = ChatCompletionStreamChunk(
-        id=const_id,
+        id=f"chatcmpl-{request_id}",
         choices=choices,
         model=unwrap(model_name, ""),
         usage=usage_stats,

@@ -235,39 +238,18 @@ def format_prompt_with_template(data: ChatCompletionRequest):
         raise HTTPException(400, error_message) from exc
 
 
-async def _stream_collector(
-    task_idx: int,
-    gen_queue: asyncio.Queue,
-    prompt: str,
-    abort_event: asyncio.Event,
-    **kwargs,
-):
-    """Collects a stream and places results in a common queue"""
-
-    try:
-        new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
-        async for generation in new_generation:
-            generation["index"] = task_idx
-
-            await gen_queue.put(generation)
-
-            if "finish_reason" in generation:
-                break
-    except Exception as e:
-        await gen_queue.put(e)
-
-
 async def stream_generate_chat_completion(
     prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     """Generator for the generation process."""
-    const_id = f"chatcmpl-{uuid4().hex}"
     abort_event = asyncio.Event()
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
 
     try:
+        logger.info(f"Recieved chat completion streaming request {request.state.id}")
+
         gen_params = data.to_gen_params()
 
         for n in range(0, data.n):

@@ -277,7 +259,14 @@ async def stream_generate_chat_completion(
                 task_gen_params = gen_params
 
             gen_task = asyncio.create_task(
-                _stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params)
+                _stream_collector(
+                    n,
+                    gen_queue,
+                    prompt,
+                    request.state.id,
+                    abort_event,
+                    **task_gen_params,
+                )
             )
 
             gen_tasks.append(gen_task)

@@ -286,7 +275,9 @@ async def stream_generate_chat_completion(
         while True:
             if disconnect_task.done():
                 abort_event.set()
-                handle_request_disconnect("Completion generation cancelled by user.")
+                handle_request_disconnect(
+                    f"Chat completion generation {request.state.id} cancelled by user."
+                )
 
             generation = await gen_queue.get()

@@ -294,7 +285,9 @@ async def stream_generate_chat_completion(
             if isinstance(generation, Exception):
                 raise generation
 
-            response = _create_stream_chunk(const_id, generation, model_path.name)
+            response = _create_stream_chunk(
+                request.state.id, generation, model_path.name
+            )
             yield response.model_dump_json()
 
             # Check if all tasks are completed

@@ -302,10 +295,17 @@ async def stream_generate_chat_completion(
                 # Send a usage chunk
                 if data.stream_options and data.stream_options.include_usage:
                     usage_chunk = _create_stream_chunk(
-                        const_id, generation, model_path.name, is_usage_chunk=True
+                        request.state.id,
+                        generation,
+                        model_path.name,
+                        is_usage_chunk=True,
                     )
                     yield usage_chunk.model_dump_json()
 
+                logger.info(
+                    f"Finished chat completion streaming request {request.state.id}"
+                )
+
                 yield "[DONE]"
                 break
     except CancelledError:

@@ -320,7 +320,7 @@ async def stream_generate_chat_completion(
 
 
 async def generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     gen_tasks: List[asyncio.Task] = []
     gen_params = data.to_gen_params()

@@ -335,16 +335,23 @@ async def generate_chat_completion(
                 task_gen_params = gen_params
 
             gen_tasks.append(
-                asyncio.create_task(model.container.generate(prompt, **task_gen_params))
+                asyncio.create_task(
+                    model.container.generate(
+                        prompt, request.state.id, **task_gen_params
+                    )
+                )
             )
 
         generations = await asyncio.gather(*gen_tasks)
-        response = _create_response(generations, model_path.name)
+        response = _create_response(request.state.id, generations, model_path.name)
+
+        logger.info(f"Finished chat completion request {request.state.id}")
 
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Chat completion aborted. Maybe the model was unloaded? "
+            f"Chat completion {request.state.id} aborted. "
+            "Maybe the model was unloaded? "
             "Please check the server console."
         ).error.message
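The non-streaming chat path above fans one request out into `data.n` generation tasks and gathers them, all tagged with the same request ID. A self-contained sketch of that shape, with a stub standing in for the real `model.container.generate`:

    import asyncio
    from uuid import uuid4


    async def fake_generate(prompt: str, request_id: str, **kwargs) -> dict:
        # Stand-in for model.container.generate(); it only echoes its inputs
        return {"text": f"echo: {prompt}", "request_id": request_id}


    async def main() -> None:
        request_id = uuid4().hex
        gen_tasks = [
            asyncio.create_task(fake_generate("Hello", request_id, index=n))
            for n in range(2)  # data.n == 2
        ]
        generations = await asyncio.gather(*gen_tasks)
        print(generations)


    asyncio.run(main())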
@@ -7,6 +7,8 @@ from copy import deepcopy
 from fastapi import HTTPException, Request
 from typing import List, Union
 
+from loguru import logger
+
 from common import model
 from common.networking import (
     get_generator_error,

@@ -24,7 +26,9 @@ from endpoints.OAI.types.completion import (
 from endpoints.OAI.types.common import UsageStats
 
 
-def _create_response(generations: Union[dict, List[dict]], model_name: str = ""):
+def _create_response(
+    request_id: str, generations: Union[dict, List[dict]], model_name: str = ""
+):
     """Create a completion response from the provided choices."""
 
     # Convert the single choice object into a list

@@ -61,6 +65,7 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "")
     completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)
 
     response = CompletionResponse(
+        id=f"cmpl-{request_id}",
         choices=choices,
         model=model_name,
         usage=UsageStats(

@@ -77,13 +82,16 @@ async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
     prompt: str,
+    request_id: str,
     abort_event: asyncio.Event,
     **kwargs,
 ):
     """Collects a stream and places results in a common queue"""
 
     try:
-        new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
+        new_generation = model.container.generate_gen(
+            prompt, request_id, abort_event, **kwargs
+        )
         async for generation in new_generation:
             generation["index"] = task_idx
 

@@ -106,6 +114,8 @@ async def stream_generate_completion(
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
 
     try:
+        logger.info(f"Recieved streaming completion request {request.state.id}")
+
         gen_params = data.to_gen_params()
 
         for n in range(0, data.n):

@@ -116,7 +126,12 @@ async def stream_generate_completion(
 
             gen_task = asyncio.create_task(
                 _stream_collector(
-                    n, gen_queue, data.prompt, abort_event, **task_gen_params
+                    n,
+                    gen_queue,
+                    data.prompt,
+                    request.state.id,
+                    abort_event,
+                    **task_gen_params,
                 )
             )
 

@@ -126,7 +141,9 @@ async def stream_generate_completion(
         while True:
             if disconnect_task.done():
                 abort_event.set()
-                handle_request_disconnect("Completion generation cancelled by user.")
+                handle_request_disconnect(
+                    f"Completion generation {request.state.id} cancelled by user."
+                )
 
             generation = await gen_queue.get()

@@ -134,31 +151,38 @@ async def stream_generate_completion(
             if isinstance(generation, Exception):
                 raise generation
 
-            response = _create_response(generation, model_path.name)
+            response = _create_response(request.state.id, generation, model_path.name)
             yield response.model_dump_json()
 
             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
                 yield "[DONE]"
+                logger.info(f"Finished streaming completion request {request.state.id}")
                 break
     except CancelledError:
         # Get out if the request gets disconnected
 
         abort_event.set()
-        handle_request_disconnect("Completion generation cancelled by user.")
+        handle_request_disconnect(
+            f"Completion generation {request.state.id} cancelled by user."
+        )
     except Exception:
         yield get_generator_error(
-            "Completion aborted. Please check the server console."
+            f"Completion {request.state.id} aborted. Please check the server console."
         )
 
 
-async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
+async def generate_completion(
+    data: CompletionRequest, request: Request, model_path: pathlib.Path
+):
     """Non-streaming generate for completions"""
 
     gen_tasks: List[asyncio.Task] = []
     gen_params = data.to_gen_params()
 
     try:
+        logger.info(f"Recieved completion request {request.state.id}")
+
         for n in range(0, data.n):
             # Deepcopy gen params above the first index
             # to ensure nested structures aren't shared

@@ -169,17 +193,21 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)
 
             gen_tasks.append(
                 asyncio.create_task(
-                    model.container.generate(data.prompt, **task_gen_params)
+                    model.container.generate(
+                        data.prompt, request.state.id, **task_gen_params
+                    )
                 )
             )
 
         generations = await asyncio.gather(*gen_tasks)
-        response = _create_response(generations, model_path.name)
+        response = _create_response(request.state.id, generations, model_path.name)
+
+        logger.info(f"Finished completion request {request.state.id}")
 
         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Completion aborted. Maybe the model was unloaded? "
+            f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
            "Please check the server console."
         ).error.message
 
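Because the completion and chat completion responses now reuse the request ID in their `id` fields (`cmpl-<request_id>` and `chatcmpl-<request_id>`), a client can correlate the body it received with the server console. A hypothetical client snippet, assuming the `requests` library and a server reachable at the address shown:

    import requests

    resp = requests.post(
        "http://localhost:5000/v1/completions",  # assumed address; adjust to your deployment
        json={"prompt": "Hello", "max_tokens": 8},
    )
    completion_id = resp.json()["id"]            # e.g. "cmpl-<request_id>"
    request_id = completion_id.split("-", 1)[1]
    print(f"Search the server console for {request_id} to find this request's log lines")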
@@ -1,5 +1,6 @@
+from uuid import uuid4
 import uvicorn
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from loguru import logger
 

@@ -25,6 +26,15 @@ app.add_middleware(
 )
 
 
+@app.middleware("http")
+async def add_request_id(request: Request, call_next):
+    """Middleware to append an ID to a request"""
+
+    request.state.id = uuid4().hex
+    response = await call_next(request)
+    return response
+
+
 def setup_app():
     """Includes the correct routers for startup"""
 