tabbyAPI-ollama/endpoints/OAI/utils/completion.py
kingbri 660f9b8432 OAI: Fix request cancellation behavior
Depending on the day of the week, Starlette can signal a client disconnect
either by raising a CancelledError or through await request.is_disconnected().
Run the same behavior for both cases and allow cancellation.

Streaming requests now set an event to cancel the batched job and break
out of the generation loop.

Signed-off-by: kingbri <bdashore3@proton.me>
2024-05-26 13:00:33 -04:00

115 lines
3.7 KiB
Python

"""Completion utilities for OAI server."""
import asyncio
import pathlib
from asyncio import CancelledError
from fastapi import HTTPException, Request
from typing import Optional
from common import model
from common.networking import (
get_generator_error,
handle_request_disconnect,
handle_request_error,
)
from common.utils import unwrap
from endpoints.OAI.types.completion import (
CompletionRequest,
CompletionResponse,
CompletionRespChoice,
CompletionLogProbs,
)
from endpoints.OAI.types.common import UsageStats
def _create_response(generation: dict, model_name: Optional[str]):
    """Build a CompletionResponse out of one generation result dict.

    Args:
        generation: Result dict from the model container. The keys read
            here are ``token_probs``, ``logprobs``, ``offset``,
            ``finish_reason``, ``text``, ``prompt_tokens`` and
            ``generated_tokens``; every lookup uses ``dict.get`` so absent
            keys are tolerated.
        model_name: Name reported in the response's ``model`` field; passed
            through ``unwrap`` with ``""`` as the fallback.

    Returns:
        A CompletionResponse holding a single choice and the usage stats.
    """

    def _listify(value):
        # Logprob fields may arrive as a single item or as a list;
        # anything that is not already a list is wrapped in one.
        return value if isinstance(value, list) else [value]

    token_probs = unwrap(generation.get("token_probs"), {})

    if not token_probs:
        # No per-token probabilities: the choice carries no logprobs.
        logprob_response = None
    else:
        top_logprobs = unwrap(generation.get("logprobs"), [])
        text_offset = unwrap(generation.get("offset"), [])

        logprob_response = CompletionLogProbs(
            text_offset=_listify(text_offset),
            token_logprobs=token_probs.values(),
            tokens=token_probs.keys(),
            top_logprobs=_listify(top_logprobs),
        )

    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
    completion_tokens = unwrap(generation.get("generated_tokens"), 0)
    usage = UsageStats(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )

    choice = CompletionRespChoice(
        finish_reason=generation.get("finish_reason"),
        text=unwrap(generation.get("text"), ""),
        logprobs=logprob_response,
    )

    return CompletionResponse(
        choices=[choice],
        model=unwrap(model_name, ""),
        usage=usage,
    )
async def stream_generate_completion(
    data: CompletionRequest, request: Request, model_path: pathlib.Path
):
    """Async generator that streams completion chunks as JSON strings.

    Each item produced by ``model.container.generate_gen`` is converted with
    ``_create_response`` and yielded via ``model_dump_json()``. Once an item
    contains a ``finish_reason`` key, a final ``"[DONE]"`` marker is yielded
    and the stream ends.

    Args:
        data: Completion request supplying the prompt and generation params.
        request: Incoming request, polled for client disconnects.
        model_path: Path of the loaded model; its ``name`` is reported in
            every response.

    Yields:
        JSON-serialized CompletionResponse strings, then ``"[DONE]"``; on an
        unexpected exception, the value of ``get_generator_error`` instead.
    """

    cancel_notice = "Completion generation cancelled by user."
    abort_event = asyncio.Event()

    try:
        chunk_stream = model.container.generate_gen(
            data.prompt, abort_event, **data.to_gen_params()
        )

        async for chunk in chunk_stream:
            # A disconnect may show up here or as a CancelledError below,
            # so both paths set the abort event and report the cancellation.
            client_gone = await request.is_disconnected()
            if client_gone:
                abort_event.set()
                handle_request_disconnect(cancel_notice)

            yield _create_response(chunk, model_path.name).model_dump_json()

            # Keep streaming until a chunk carries a finish reason
            if "finish_reason" not in chunk:
                continue

            yield "[DONE]"
            break
    except CancelledError:
        # The request task was cancelled (client went away): signal the
        # generator to stop and report it.
        abort_event.set()
        handle_request_disconnect(cancel_notice)
    except Exception:
        yield get_generator_error(
            "Completion aborted. Please check the server console."
        )
async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
    """Run a non-streaming completion and return the finished response.

    Args:
        data: Completion request supplying the prompt and generation params.
        model_path: Path of the loaded model; its ``name`` is reported in
            the response.

    Returns:
        The CompletionResponse built by ``_create_response``.

    Raises:
        HTTPException: Status 503 when generation or response construction
            raises; the original exception is chained as the cause.
    """

    try:
        generation = await model.container.generate(data.prompt, **data.to_gen_params())
        return _create_response(generation, model_path.name)
    except Exception as exc:
        failure = handle_request_error(
            "Completion aborted. Maybe the model was unloaded? "
            "Please check the server console."
        )

        # Any generation failure is reported to the client as a server error
        raise HTTPException(503, failure.error.message) from exc