tabbyAPI-ollama/endpoints/OAI/utils/completion.py
kingbri 6f03be9523 API: Split functions into their own files
Previously, generation function were bundled with the request function
causing the overall code structure and API to look ugly and unreadable.

Split these up and cleanup a lot of the methods that were previously
overlooked in the API itself.

Signed-off-by: kingbri <bdashore3@proton.me>
2024-03-12 23:59:30 -04:00

104 lines
3.3 KiB
Python

"""Completion utilities for OAI server."""
import pathlib
from fastapi import HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from loguru import logger
from typing import Optional
from common import model
from common.generators import release_semaphore
from common.utils import get_generator_error, handle_request_error, unwrap
from endpoints.OAI.types.completion import (
CompletionRequest,
CompletionResponse,
CompletionRespChoice,
CompletionLogProbs,
)
from endpoints.OAI.types.common import UsageStats
def _create_response(generation: dict, model_name: Optional[str]):
"""Create a completion response from the provided text."""
logprob_response = None
token_probs = unwrap(generation.get("token_probs"), {})
if token_probs:
logprobs = unwrap(generation.get("logprobs"), [])
offset = unwrap(generation.get("offset"), [])
logprob_response = CompletionLogProbs(
text_offset=offset if isinstance(offset, list) else [offset],
token_logprobs=token_probs.values(),
tokens=token_probs.keys(),
top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
)
choice = CompletionRespChoice(
finish_reason="Generated",
text=unwrap(generation.get("text"), ""),
logprobs=logprob_response,
)
prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
completion_tokens = unwrap(generation.get("generated_tokens"), 0)
response = CompletionResponse(
choices=[choice],
model=unwrap(model_name, ""),
usage=UsageStats(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
),
)
return response
async def stream_generate_completion(
request: Request, data: CompletionRequest, model_path: pathlib.Path
):
"""Streaming generation for completions."""
try:
new_generation = model.container.generate_gen(
data.prompt, **data.to_gen_params()
)
for generation in new_generation:
# Get out if the request gets disconnected
if await request.is_disconnected():
release_semaphore()
logger.error("Completion generation cancelled by user.")
return
response = _create_response(generation, model_path.name)
yield response.model_dump_json()
# Yield a finish response on successful generation
yield "[DONE]"
except Exception:
yield get_generator_error(
"Completion aborted. Please check the server console."
)
async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
"""Non-streaming generate for completions"""
try:
generation = await run_in_threadpool(
model.container.generate, data.prompt, **data.to_gen_params()
)
response = _create_response(generation, model_path.name)
return response
except Exception as exc:
error_message = handle_request_error(
"Completion aborted. Maybe the model was unloaded? "
"Please check the server console."
).error.message
# Server error if there's a generation exception
raise HTTPException(503, error_message) from exc