tabbyAPI-ollama/endpoints/OAI/utils/completion.py
kingbri e2a8b6e8ae OAI: Add "n" support for streaming generations
Use a queue-based system to collect choices independently and send them
in the overall streaming payload. This method allows for unordered
streaming of generations.

The system is a bit redundant, so the code could be optimized in the
future. (A standalone sketch of the pattern follows the commit metadata.)

Signed-off-by: kingbri <bdashore3@proton.me>
2024-05-28 00:52:30 -04:00
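
A minimal, self-contained sketch of the queue-based fan-in pattern this commit describes (the producer/consume names and the literal chunks are illustrative, not from the codebase):

import asyncio

async def producer(idx: int, queue: asyncio.Queue):
    # Each generation task tags its chunks with its own index and
    # pushes them into the shared queue as soon as they are ready.
    for chunk in ("Hello", " world"):
        await queue.put({"index": idx, "text": chunk})
    await queue.put({"index": idx, "finish_reason": "stop"})

async def consume(n: int):
    queue: asyncio.Queue = asyncio.Queue()
    tasks = [asyncio.create_task(producer(i, queue)) for i in range(n)]

    # Mirrors the consumer loop in the file below: chunks arrive in
    # completion order, not task order, hence "unordered streaming".
    while True:
        print(await queue.get())
        if all(task.done() for task in tasks) and queue.empty():
            break

asyncio.run(consume(2))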


"""Completion utilities for OAI server."""
import asyncio
import pathlib
from asyncio import CancelledError
from copy import deepcopy
from fastapi import HTTPException, Request
from typing import List, Union
from common import model
from common.networking import (
get_generator_error,
handle_request_disconnect,
handle_request_error,
request_disconnect_loop,
)
from common.utils import unwrap
from endpoints.OAI.types.completion import (
CompletionRequest,
CompletionResponse,
CompletionRespChoice,
CompletionLogProbs,
)
from endpoints.OAI.types.common import UsageStats


def _create_response(generations: Union[dict, List[dict]], model_name: str = ""):
    """Create a completion response from the provided choices."""

    # Convert the single choice object into a list
    if not isinstance(generations, list):
        generations = [generations]

    choices: List[CompletionRespChoice] = []
    for index, generation in enumerate(generations):
        logprob_response = None

        token_probs = unwrap(generation.get("token_probs"), {})
        if token_probs:
            logprobs = unwrap(generation.get("logprobs"), [])
            offset = unwrap(generation.get("offset"), [])

            logprob_response = CompletionLogProbs(
                text_offset=offset if isinstance(offset, list) else [offset],
                token_logprobs=token_probs.values(),
                tokens=token_probs.keys(),
                top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
            )

        # The index can be located in the generation itself
        choice = CompletionRespChoice(
            index=unwrap(generation.get("index"), index),
            finish_reason=generation.get("finish_reason"),
            text=unwrap(generation.get("text"), ""),
            logprobs=logprob_response,
        )

        choices.append(choice)

    prompt_tokens = unwrap(generations[-1].get("prompt_tokens"), 0)
    completion_tokens = unwrap(generations[-1].get("generated_tokens"), 0)

    response = CompletionResponse(
        choices=choices,
        model=model_name,
        usage=UsageStats(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )

    return response
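
# Illustrative note (not from the source): with hypothetical inputs,
#   _create_response(
#       [{"text": "Hi", "index": 0, "finish_reason": "stop",
#         "prompt_tokens": 5, "generated_tokens": 1}],
#       "my-model",
#   )
# yields a CompletionResponse with one choice (index 0) and usage of
# 5 prompt + 1 completion = 6 total tokens, with the token counts read
# from the last generation in the list.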


async def _stream_collector(
    task_idx: int,
    gen_queue: asyncio.Queue,
    prompt: str,
    abort_event: asyncio.Event,
    **kwargs,
):
    """Collects a stream and places results in a common queue"""

    new_generation = model.container.generate_gen(prompt, abort_event, **kwargs)
    async for generation in new_generation:
        generation["index"] = task_idx

        await gen_queue.put(generation)

        if "finish_reason" in generation:
            break


async def stream_generate_completion(
    data: CompletionRequest, request: Request, model_path: pathlib.Path
):
    """Streaming generation for completions."""

    abort_event = asyncio.Event()
    gen_queue = asyncio.Queue()
    gen_tasks: List[asyncio.Task] = []
    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

    try:
        gen_params = data.to_gen_params()

        for n in range(0, data.n):
            # Deepcopy gen params above the first index
            # to ensure nested structures aren't shared
            if n > 0:
                task_gen_params = deepcopy(gen_params)
            else:
                task_gen_params = gen_params

            gen_task = asyncio.create_task(
                _stream_collector(
                    n, gen_queue, data.prompt, abort_event, **task_gen_params
                )
            )

            gen_tasks.append(gen_task)

        # Consumer loop
        while True:
            if disconnect_task.done():
                abort_event.set()
                handle_request_disconnect("Completion generation cancelled by user.")

            generation = await gen_queue.get()
            response = _create_response(generation, model_path.name)

            yield response.model_dump_json()

            # Check if all tasks are completed
            if all(task.done() for task in gen_tasks) and gen_queue.empty():
                yield "[DONE]"
                break
    except CancelledError:
        # Get out if the request gets disconnected
        abort_event.set()
        handle_request_disconnect("Completion generation cancelled by user.")
    except Exception:
        yield get_generator_error(
            "Completion aborted. Please check the server console."
        )


async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
    """Non-streaming generate for completions"""

    gen_tasks: List[asyncio.Task] = []
    gen_params = data.to_gen_params()

    try:
        for n in range(0, data.n):
            # Deepcopy gen params above the first index
            # to ensure nested structures aren't shared
            if n > 0:
                task_gen_params = deepcopy(gen_params)
            else:
                task_gen_params = gen_params

            gen_tasks.append(
                asyncio.create_task(
                    model.container.generate(data.prompt, **task_gen_params)
                )
            )

        generations = await asyncio.gather(*gen_tasks)
        response = _create_response(generations, model_path.name)

        return response
    except Exception as exc:
        error_message = handle_request_error(
            "Completion aborted. Maybe the model was unloaded? "
            "Please check the server console."
        ).error.message

        # Server error if there's a generation exception
        raise HTTPException(503, error_message) from exc
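
A hedged sketch of how a client might consume the stream above and regroup the unordered chunks by choice index. The host/port, the /v1/completions path, and the SSE "data:" framing are assumptions about the deployment, not taken from this file:

import json

import requests  # assumed client library; any streaming HTTP client works

payload = {"prompt": "Once upon a time", "n": 2, "stream": True}
texts = {}  # choice index -> accumulated text fragments

with requests.post(
    "http://localhost:5000/v1/completions",  # hypothetical endpoint
    json=payload,
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # Skip keep-alives and anything outside the assumed SSE framing
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data == "[DONE]":
            break
        # Each chunk carries its originating choice index, so chunks can
        # be demultiplexed no matter the order they arrive in
        for choice in json.loads(data)["choices"]:
            texts.setdefault(choice["index"], []).append(choice["text"])

for idx, parts in sorted(texts.items()):
    print(f"choice {idx}: {''.join(parts)}")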