API: Add ability to use request IDs
Identify which request is being processed to help users disambiguate which logs correspond to which request.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 38185a1ff4
commit cae94b920c

6 changed files with 112 additions and 57 deletions
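The pattern this commit follows is a FastAPI HTTP middleware that stamps every incoming request with a UUID on `request.state.id`, then threads that ID through generation calls, log lines, and the OpenAI-style response `id` fields. Below is a minimal, self-contained sketch of the idea; the route and response shape are illustrative placeholders, not this repo's exact handlers.

```python
from uuid import uuid4

from fastapi import FastAPI, Request
from loguru import logger

app = FastAPI()


@app.middleware("http")
async def add_request_id(request: Request, call_next):
    """Attach a unique ID to every request so its log lines can be correlated."""
    request.state.id = uuid4().hex
    response = await call_next(request)
    return response


@app.post("/v1/completions")
async def completion_request(request: Request):
    # Downstream handlers receive the ID explicitly and include it in every log line
    logger.info(f"Received completion request {request.state.id}")
    return {"id": f"cmpl-{request.state.id}"}
```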
@@ -828,10 +828,10 @@ class ExllamaV2Container:
         return dict(zip_longest(top_tokens, cleaned_values))

-    async def generate(self, prompt: str, **kwargs):
+    async def generate(self, prompt: str, request_id: str, **kwargs):
         """Generate a response to a prompt"""
         generations = []
-        async for generation in self.generate_gen(prompt, **kwargs):
+        async for generation in self.generate_gen(prompt, request_id, **kwargs):
             generations.append(generation)

         joined_generation = {
@@ -881,7 +881,11 @@ class ExllamaV2Container:
         return kwargs

     async def generate_gen(
-        self, prompt: str, abort_event: Optional[asyncio.Event] = None, **kwargs
+        self,
+        prompt: str,
+        request_id: str,
+        abort_event: Optional[asyncio.Event] = None,
+        **kwargs,
     ):
         """
         Create generator function for prompt completion.
@@ -1116,6 +1120,7 @@ class ExllamaV2Container:
         # Log generation options to console
         # Some options are too large, so log the args instead
         log_generation_params(
+            request_id=request_id,
             max_tokens=max_tokens,
             min_tokens=min_tokens,
             stream=kwargs.get("stream"),
@@ -1138,9 +1143,10 @@ class ExllamaV2Container:
         )

         # Log prompt to console
-        log_prompt(prompt, negative_prompt)
+        log_prompt(prompt, request_id, negative_prompt)

         # Create and add a new job
+        # Don't use the request ID here as there can be multiple jobs per request
         job_id = uuid.uuid4().hex
         job = ExLlamaV2DynamicJobAsync(
             self.generator,
@@ -51,11 +51,13 @@ def log_generation_params(**kwargs):
         logger.info(f"Generation options: {kwargs}\n")


-def log_prompt(prompt: str, negative_prompt: Optional[str]):
+def log_prompt(prompt: str, request_id: str, negative_prompt: Optional[str]):
     """Logs the prompt to console."""
     if PREFERENCES.prompt:
         formatted_prompt = "\n" + prompt
-        logger.info(f"Prompt: {formatted_prompt if prompt else 'Empty'}\n")
+        logger.info(
+            f"Prompt (ID: {request_id}): {formatted_prompt if prompt else 'Empty'}\n"
+        )

     if negative_prompt:
         formatted_negative_prompt = "\n" + negative_prompt
@@ -107,12 +107,14 @@ async def completion_request(
             ping=maxsize,
         )
     else:
-        generate_task = asyncio.create_task(generate_completion(data, model_path))
+        generate_task = asyncio.create_task(
+            generate_completion(data, request, model_path)
+        )

         response = await run_with_request_disconnect(
             request,
             generate_task,
-            disconnect_message="Completion generation cancelled by user.",
+            disconnect_message=f"Completion {request.state.id} cancelled by user.",
         )
         return response

@@ -161,13 +163,13 @@ async def chat_completion_request(
         )
     else:
         generate_task = asyncio.create_task(
-            generate_chat_completion(prompt, data, model_path)
+            generate_chat_completion(prompt, data, request, model_path)
         )

         response = await run_with_request_disconnect(
             request,
             generate_task,
-            disconnect_message="Chat completion generation cancelled by user.",
+            disconnect_message=f"Chat completion {request.state.id} cancelled by user.",
         )
         return response
@@ -5,7 +5,6 @@ import pathlib
 from asyncio import CancelledError
 from copy import deepcopy
 from typing import List, Optional
-from uuid import uuid4

 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
@@ -30,9 +29,12 @@ from endpoints.OAI.types.chat_completion import (
     ChatCompletionStreamChoice,
 )
 from endpoints.OAI.types.common import UsageStats
+from endpoints.OAI.utils.completion import _stream_collector


-def _create_response(generations: List[dict], model_name: Optional[str]):
+def _create_response(
+    request_id: str, generations: List[dict], model_name: Optional[str]
+):
     """Create a chat completion response from the provided text."""

     prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
@@ -77,6 +79,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]):
         choices.append(choice)

     response = ChatCompletionResponse(
+        id=f"chatcmpl-{request_id}",
         choices=choices,
         model=unwrap(model_name, ""),
         usage=UsageStats(
@@ -90,7 +93,7 @@ def _create_response(generations: List[dict], model_name: Optional[str]):


 def _create_stream_chunk(
-    const_id: str,
+    request_id: str,
     generation: Optional[dict] = None,
     model_name: Optional[str] = None,
     is_usage_chunk: bool = False,
@@ -150,7 +153,7 @@ def _create_stream_chunk(
     choices.append(choice)

     chunk = ChatCompletionStreamChunk(
-        id=const_id,
+        id=f"chatcmpl-{request_id}",
         choices=choices,
         model=unwrap(model_name, ""),
         usage=usage_stats,
@@ -235,39 +238,18 @@ def format_prompt_with_template(data: ChatCompletionRequest):
         raise HTTPException(400, error_message) from exc


-async def _stream_collector(
-    task_idx: int,
-    gen_queue: asyncio.Queue,
-    prompt: str,
-    abort_event: asyncio.Event,
-    **kwargs,
-):
-    """Collects a stream and places results in a common queue"""
-
-    try:
-        new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
-        async for generation in new_generation:
-            generation["index"] = task_idx
-
-            await gen_queue.put(generation)
-
-            if "finish_reason" in generation:
-                break
-    except Exception as e:
-        await gen_queue.put(e)
-
-
 async def stream_generate_chat_completion(
     prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     """Generator for the generation process."""
-    const_id = f"chatcmpl-{uuid4().hex}"
     abort_event = asyncio.Event()
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))

     try:
+        logger.info(f"Recieved chat completion streaming request {request.state.id}")
+
         gen_params = data.to_gen_params()

         for n in range(0, data.n):
@@ -277,7 +259,14 @@ async def stream_generate_chat_completion(
                 task_gen_params = gen_params

             gen_task = asyncio.create_task(
-                _stream_collector(n, gen_queue, prompt, abort_event, **task_gen_params)
+                _stream_collector(
+                    n,
+                    gen_queue,
+                    prompt,
+                    request.state.id,
+                    abort_event,
+                    **task_gen_params,
+                )
             )

             gen_tasks.append(gen_task)
@@ -286,7 +275,9 @@
         while True:
             if disconnect_task.done():
                 abort_event.set()
-                handle_request_disconnect("Completion generation cancelled by user.")
+                handle_request_disconnect(
+                    f"Chat completion generation {request.state.id} cancelled by user."
+                )

             generation = await gen_queue.get()
@@ -294,7 +285,9 @@
             if isinstance(generation, Exception):
                 raise generation

-            response = _create_stream_chunk(const_id, generation, model_path.name)
+            response = _create_stream_chunk(
+                request.state.id, generation, model_path.name
+            )
             yield response.model_dump_json()

             # Check if all tasks are completed
@@ -302,10 +295,17 @@
                 # Send a usage chunk
                 if data.stream_options and data.stream_options.include_usage:
                     usage_chunk = _create_stream_chunk(
-                        const_id, generation, model_path.name, is_usage_chunk=True
+                        request.state.id,
+                        generation,
+                        model_path.name,
+                        is_usage_chunk=True,
                     )
                     yield usage_chunk.model_dump_json()

+                logger.info(
+                    f"Finished chat completion streaming request {request.state.id}"
+                )
+
                 yield "[DONE]"
                 break
     except CancelledError:
@@ -320,7 +320,7 @@ async def stream_generate_chat_completion(


 async def generate_chat_completion(
-    prompt: str, data: ChatCompletionRequest, model_path: pathlib.Path
+    prompt: str, data: ChatCompletionRequest, request: Request, model_path: pathlib.Path
 ):
     gen_tasks: List[asyncio.Task] = []
     gen_params = data.to_gen_params()
@@ -335,16 +335,23 @@ async def generate_chat_completion(
                 task_gen_params = gen_params

             gen_tasks.append(
-                asyncio.create_task(model.container.generate(prompt, **task_gen_params))
+                asyncio.create_task(
+                    model.container.generate(
+                        prompt, request.state.id, **task_gen_params
+                    )
+                )
             )

         generations = await asyncio.gather(*gen_tasks)
-        response = _create_response(generations, model_path.name)
+        response = _create_response(request.state.id, generations, model_path.name)
+
+        logger.info(f"Finished chat completion request {request.state.id}")

         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Chat completion aborted. Maybe the model was unloaded? "
+            f"Chat completion {request.state.id} aborted. "
+            "Maybe the model was unloaded? "
             "Please check the server console."
         ).error.message
@@ -7,6 +7,8 @@ from copy import deepcopy
 from fastapi import HTTPException, Request
 from typing import List, Union

+from loguru import logger
+
 from common import model
 from common.networking import (
     get_generator_error,
@@ -24,7 +26,9 @@ from endpoints.OAI.types.completion import (
 from endpoints.OAI.types.common import UsageStats


-def _create_response(generations: Union[dict, List[dict]], model_name: str = ""):
+def _create_response(
+    request_id: str, generations: Union[dict, List[dict]], model_name: str = ""
+):
     """Create a completion response from the provided choices."""

     # Convert the single choice object into a list
@@ -61,6 +65,7 @@ def _create_response(generations: Union[dict, List[dict]], model_name: str = "")
     completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)

     response = CompletionResponse(
+        id=f"cmpl-{request_id}",
         choices=choices,
         model=model_name,
         usage=UsageStats(
@@ -77,13 +82,16 @@ async def _stream_collector(
     task_idx: int,
     gen_queue: asyncio.Queue,
     prompt: str,
+    request_id: str,
     abort_event: asyncio.Event,
     **kwargs,
 ):
     """Collects a stream and places results in a common queue"""

     try:
-        new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
+        new_generation = model.container.generate_gen(
+            prompt, request_id, abort_event, **kwargs
+        )
         async for generation in new_generation:
             generation["index"] = task_idx
@@ -106,6 +114,8 @@ async def stream_generate_completion(
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))

     try:
+        logger.info(f"Recieved streaming completion request {request.state.id}")
+
         gen_params = data.to_gen_params()

         for n in range(0, data.n):
@@ -116,7 +126,12 @@ async def stream_generate_completion(

             gen_task = asyncio.create_task(
                 _stream_collector(
-                    n, gen_queue, data.prompt, abort_event, **task_gen_params
+                    n,
+                    gen_queue,
+                    data.prompt,
+                    request.state.id,
+                    abort_event,
+                    **task_gen_params,
                 )
             )

@@ -126,7 +141,9 @@
         while True:
             if disconnect_task.done():
                 abort_event.set()
-                handle_request_disconnect("Completion generation cancelled by user.")
+                handle_request_disconnect(
+                    f"Completion generation {request.state.id} cancelled by user."
+                )

             generation = await gen_queue.get()
@@ -134,31 +151,38 @@
             if isinstance(generation, Exception):
                 raise generation

-            response = _create_response(generation, model_path.name)
+            response = _create_response(request.state.id, generation, model_path.name)
             yield response.model_dump_json()

             # Check if all tasks are completed
             if all(task.done() for task in gen_tasks) and gen_queue.empty():
                 yield "[DONE]"
+                logger.info(f"Finished streaming completion request {request.state.id}")
                 break
     except CancelledError:
         # Get out if the request gets disconnected

         abort_event.set()
-        handle_request_disconnect("Completion generation cancelled by user.")
+        handle_request_disconnect(
+            f"Completion generation {request.state.id} cancelled by user."
+        )
     except Exception:
         yield get_generator_error(
-            "Completion aborted. Please check the server console."
+            f"Completion {request.state.id} aborted. Please check the server console."
         )


-async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
+async def generate_completion(
+    data: CompletionRequest, request: Request, model_path: pathlib.Path
+):
     """Non-streaming generate for completions"""

     gen_tasks: List[asyncio.Task] = []
     gen_params = data.to_gen_params()

     try:
+        logger.info(f"Recieved completion request {request.state.id}")
+
         for n in range(0, data.n):
             # Deepcopy gen params above the first index
             # to ensure nested structures aren't shared
@@ -169,17 +193,21 @@ async def generate_completion(data: CompletionRequest, model_path: pathlib.Path)

             gen_tasks.append(
                 asyncio.create_task(
-                    model.container.generate(data.prompt, **task_gen_params)
+                    model.container.generate(
+                        data.prompt, request.state.id, **task_gen_params
+                    )
                 )
             )

         generations = await asyncio.gather(*gen_tasks)
-        response = _create_response(generations, model_path.name)
+        response = _create_response(request.state.id, generations, model_path.name)
+
+        logger.info(f"Finished completion request {request.state.id}")

         return response
     except Exception as exc:
         error_message = handle_request_error(
-            "Completion aborted. Maybe the model was unloaded? "
+            f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
             "Please check the server console."
         ).error.message
@@ -1,5 +1,6 @@
+from uuid import uuid4
 import uvicorn
-from fastapi import FastAPI
+from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
 from loguru import logger
@@ -25,6 +26,15 @@ app.add_middleware(
 )


+@app.middleware("http")
+async def add_request_id(request: Request, call_next):
+    """Middleware to append an ID to a request"""
+
+    request.state.id = uuid4().hex
+    response = await call_next(request)
+    return response
+
+
 def setup_app():
     """Includes the correct routers for startup"""
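Because the request ID is also embedded in the `id` field of completion and chat completion responses (`cmpl-…` / `chatcmpl-…`), a client or operator can recover it from an API response and search the server console for the matching log lines. A rough illustration of that workflow; the URL, port, and request body are placeholders, not values taken from this repo:

```python
import requests

# Hypothetical call against a locally running server; adjust URL and auth as needed.
resp = requests.post(
    "http://localhost:5000/v1/completions",
    json={"prompt": "Hello", "max_tokens": 8},
).json()

# Strip the OpenAI-style prefix to get the server-side request ID,
# then grep the server console output for it.
request_id = resp["id"].removeprefix("cmpl-")
print(f"Search the server log for: {request_id}")
```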