tabbyAPI-ollama/endpoints/OAI/utils/completion.py

"""Completion utilities for OAI server."""
import pathlib
from fastapi import HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from loguru import logger
from typing import Optional
from common import model
from common.generators import release_semaphore
from common.utils import get_generator_error, handle_request_error, unwrap
from endpoints.OAI.types.completion import (
CompletionRequest,
CompletionResponse,
CompletionRespChoice,
CompletionLogProbs,
)
from endpoints.OAI.types.common import UsageStats


def _create_response(generation: dict, model_name: Optional[str]):
    """Create a completion response from the provided generation."""

    logprob_response = None

    # Only build logprob info if the backend reported token probabilities
    token_probs = unwrap(generation.get("token_probs"), {})
    if token_probs:
        logprobs = unwrap(generation.get("logprobs"), [])
        offset = unwrap(generation.get("offset"), [])

        logprob_response = CompletionLogProbs(
            text_offset=offset if isinstance(offset, list) else [offset],
            token_logprobs=token_probs.values(),
            tokens=token_probs.keys(),
            top_logprobs=logprobs if isinstance(logprobs, list) else [logprobs],
        )

    choice = CompletionRespChoice(
        finish_reason="Generated",
        text=unwrap(generation.get("text"), ""),
        logprobs=logprob_response,
    )

    prompt_tokens = unwrap(generation.get("prompt_tokens"), 0)
    completion_tokens = unwrap(generation.get("generated_tokens"), 0)

    response = CompletionResponse(
        choices=[choice],
        model=unwrap(model_name, ""),
        usage=UsageStats(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
        ),
    )

    return response
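

# Usage sketch (illustrative only): the payload below mirrors the .get()
# calls in _create_response; the field names come from the function above,
# while the values and model name are hypothetical.
def _example_create_response() -> CompletionResponse:
    generation = {
        "text": " world",
        "prompt_tokens": 1,
        "generated_tokens": 1,
        "token_probs": {" world": -0.21},
        "logprobs": [{" world": -0.21, " there": -1.87}],
        "offset": [5],
    }

    # usage.total_tokens comes out as prompt_tokens + generated_tokens == 2
    return _create_response(generation, "example-model")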


async def stream_generate_completion(
    request: Request, data: CompletionRequest, model_path: pathlib.Path
):
    """Streaming generation for completions."""

    try:
        new_generation = model.container.generate_gen(
            data.prompt, **data.to_gen_params()
        )
        for generation in new_generation:
            # Get out if the request gets disconnected
            if await request.is_disconnected():
                release_semaphore()
                logger.error("Completion generation cancelled by user.")
                return

            response = _create_response(generation, model_path.name)

            yield response.model_dump_json()

        # Yield a finish response on successful generation
        yield "[DONE]"
    except Exception:
        yield get_generator_error(
            "Completion aborted. Please check the server console."
        )
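

# Consumption sketch (illustrative only): the generator above yields one
# JSON-serialized CompletionResponse per chunk, then the "[DONE]" sentinel;
# on failure it yields the payload from get_generator_error rather than
# raising. The model path below is hypothetical.
async def _example_drain_stream(request: Request, data: CompletionRequest):
    async for chunk in stream_generate_completion(
        request, data, pathlib.Path("models/example-model")
    ):
        if chunk == "[DONE]":
            break
        print(chunk)  # One serialized CompletionResponse per iteration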


async def generate_completion(data: CompletionRequest, model_path: pathlib.Path):
    """Non-streaming generation for completions."""

    try:
        generation = await run_in_threadpool(
            model.container.generate, data.prompt, **data.to_gen_params()
        )
        response = _create_response(generation, model_path.name)

        return response
    except Exception as exc:
        error_message = handle_request_error(
            "Completion aborted. Maybe the model was unloaded? "
            "Please check the server console."
        ).error.message

        # Server error if there's a generation exception
        raise HTTPException(503, error_message) from exc
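

# Wiring sketch (illustrative only): the real route lives in the API layer,
# not in this module. EventSourceResponse comes from sse-starlette, which
# this project uses for streaming; the route path, router argument, and
# model_path argument are assumptions for illustration.
def _example_register_route(router, model_path: pathlib.Path):
    from sse_starlette.sse import EventSourceResponse

    @router.post("/v1/completions")
    async def completion_request(request: Request, data: CompletionRequest):
        # Dispatch on the request's stream flag, as the OAI API does
        if data.stream:
            return EventSourceResponse(
                stream_generate_completion(request, data, model_path)
            )
        return await generate_completion(data, model_path)

    return completion_request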