import asyncio
from asyncio import CancelledError
from fastapi import HTTPException, Request
from loguru import logger
from sse_starlette.event import ServerSentEvent

from common import model
from common.networking import (
    get_generator_error,
    handle_request_disconnect,
    handle_request_error,
    request_disconnect_loop,
)
from common.utils import unwrap
from endpoints.Kobold.types.generation import (
    AbortResponse,
    GenerateRequest,
    GenerateResponse,
    GenerateResponseResult,
    StreamGenerateChunk,
)
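

# Active generations keyed by genkey. Each entry holds the abort event and the
# text accumulated so far, which the status and abort endpoints read.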
generation_cache = {}


async def override_request_id(request: Request, data: GenerateRequest):
    """Overrides the request ID with a KAI genkey if present."""

    if data.genkey:
        request.state.id = data.genkey


def _create_response(text: str):
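    """Wraps generated text in a KAI generate response."""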
    results = [GenerateResponseResult(text=text)]
    return GenerateResponse(results=results)


def _create_stream_chunk(text: str):
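    """Wraps a single text chunk in a KAI streaming token event."""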
    return StreamGenerateChunk(token=text)


async def _stream_collector(data: GenerateRequest, request: Request):
    """Common async generator for generation streams."""

    abort_event = asyncio.Event()
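    # Background task that completes once the client disconnects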
    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

    # Create a new entry in the cache
    generation_cache[data.genkey] = {"abort": abort_event, "text": ""}

    try:
        logger.info(f"Received Kobold generation request {data.genkey}")

        generator = model.container.stream_generate(
            request_id=data.genkey,
            prompt=data.prompt,
            params=data,
            abort_event=abort_event,
        )

        async for generation in generator:
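            # Tell the backend to abort if the client has disconnected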
            if disconnect_task.done():
                abort_event.set()
                handle_request_disconnect(
                    f"Kobold generation {data.genkey} cancelled by user."
                )

            text = generation.get("text")

            # Update the generation cache with the new chunk
            if text:
                generation_cache[data.genkey]["text"] += text
                yield text

            if "finish_reason" in generation:
                logger.info(f"Finished streaming Kobold request {data.genkey}")
                break
    except CancelledError:
        # If this coroutine was cancelled before the client disconnected,
        # signal the backend to abort the generation
        if not disconnect_task.done():
            abort_event.set()
            handle_request_disconnect(
                f"Kobold generation {data.genkey} cancelled by user."
            )
    finally:
        # Clean up the cache entry
        del generation_cache[data.genkey]


async def stream_generation(data: GenerateRequest, request: Request):
    """Wrapper for stream generations."""

    # If the genkey doesn't exist, set it to the request ID
    if not data.genkey:
        data.genkey = request.state.id

    try:
        async for chunk in _stream_collector(data, request):
            response = _create_stream_chunk(chunk)
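            # Forward each chunk to the client as an SSE message event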
            yield ServerSentEvent(
                event="message", data=response.model_dump_json(), sep="\n"
            )
    except Exception:
        yield get_generator_error(
            f"Kobold generation {data.genkey} aborted. Please check the server console."
        )


async def get_generation(data: GenerateRequest, request: Request):
    """Wrapper to get a static generation."""

    # If the genkey doesn't exist, set it to the request ID
    if not data.genkey:
        data.genkey = request.state.id

    try:
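        # Drain the stream and collect the full completion text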
        full_text = ""
        async for chunk in _stream_collector(data, request):
            full_text += chunk

        response = _create_response(full_text)
        return response
    except Exception as exc:
        error_message = handle_request_error(
            f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
            "Please check the server console."
        ).error.message

        # Server error if there's a generation exception
        raise HTTPException(503, error_message) from exc


async def abort_generation(genkey: str):
    """Aborts a generation from the cache."""

    abort_event = unwrap(generation_cache.get(genkey), {}).get("abort")
    if abort_event:
        abort_event.set()
        handle_request_disconnect(f"Kobold generation {genkey} cancelled by user.")
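
    # Report success regardless of whether a matching generation was found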
    return AbortResponse(success=True)


async def generation_status(genkey: str):
    """Fetches the status of a generation from the cache."""

    current_text = unwrap(generation_cache.get(genkey), {}).get("text")
    if current_text:
        response = _create_response(current_text)
    else:
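        # Unknown genkey or nothing generated yet, so return an empty response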
        response = GenerateResponse()

    return response