API: Add KoboldAI server

Adds endpoints for interacting with applications that consume KoboldAI's API,
such as AI Horde.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-07-26 16:37:30 -04:00
parent 4e808cbed7
commit b7cb6f0b91
6 changed files with 330 additions and 3 deletions
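
As a quick illustration of the new API surface, a minimal client call might look like the sketch below. The host, port, x-api-key header name, and genkey value are deployment assumptions, not part of this commit:

# Minimal sketch of a KoboldAI-style generate request; the base URL and
# auth header are assumptions about the deployment, not this commit.
import requests

response = requests.post(
    "http://127.0.0.1:5000/api/v1/generate",
    headers={"x-api-key": "your-api-key"},
    json={"prompt": "Once upon a time,", "max_length": 150, "genkey": "KCPP0001"},
    timeout=120,
)
print(response.json())  # {"results": [{"text": "..."}]}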

@@ -828,10 +828,14 @@ class ExllamaV2Container:
         return dict(zip_longest(top_tokens, cleaned_values))

-    async def generate(self, prompt: str, request_id: str, **kwargs):
+    async def generate(
+        self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
+    ):
         """Generate a response to a prompt"""

         generations = []
-        async for generation in self.generate_gen(prompt, request_id, **kwargs):
+        async for generation in self.generate_gen(
+            prompt, request_id, abort_event, **kwargs
+        ):
             generations.append(generation)

         joined_generation = {

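The new abort_event parameter threads a cooperative cancellation signal through the generator. A rough standalone sketch of the pattern, using a stand-in generator rather than the real ExllamaV2Container:

# Illustrative sketch of the abort_event pattern; fake_generate_gen is a
# stand-in for ExllamaV2Container.generate_gen, not the real implementation.
import asyncio


async def fake_generate_gen(prompt: str, abort_event: asyncio.Event):
    for word in prompt.split():
        if abort_event.is_set():
            return  # Stop yielding as soon as the abort event fires
        yield word
        await asyncio.sleep(0.1)


async def main():
    abort_event = asyncio.Event()
    # Fire the abort shortly after the stream starts
    asyncio.get_running_loop().call_later(0.15, abort_event.set)

    async for token in fake_generate_gen("one two three four", abort_event):
        print(token)  # Prints "one" and "two" before the abort lands


asyncio.run(main())
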
endpoints/Kobold/router.py Normal file

@@ -0,0 +1,103 @@
from sys import maxsize

from fastapi import APIRouter, Depends, Request
from sse_starlette import EventSourceResponse

from common import model
from common.auth import check_api_key
from common.model import check_model_container
from common.utils import unwrap
from endpoints.Kobold.types.generation import (
    AbortRequest,
    CheckGenerateRequest,
    GenerateRequest,
    GenerateResponse,
)
from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse
from endpoints.Kobold.utils.generation import (
    abort_generation,
    generation_status,
    get_generation,
    stream_generation,
)
from endpoints.core.utils.model import get_current_model

router = APIRouter(prefix="/api")


@router.post(
    "/v1/generate",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
    response = await get_generation(data, request)
    return response


@router.post(
    "/extra/generate/stream",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate_stream(
    request: Request, data: GenerateRequest
) -> EventSourceResponse:
    response = EventSourceResponse(stream_generation(data, request), ping=maxsize)
    return response


@router.post(
    "/extra/abort",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def abort_generate(data: AbortRequest):
    response = await abort_generation(data.genkey)
    return response


@router.get(
    "/extra/generate/check",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@router.post(
    "/extra/generate/check",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
    response = await generation_status(data.genkey)
    return response


@router.get(
    "/v1/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
)
async def current_model():
    """Fetches the current model and who owns it."""

    current_model_card = get_current_model()
    return {"result": f"{current_model_card.owned_by}/{current_model_card.id}"}


@router.post(
    "/extra/tokencount",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
    raw_tokens = model.container.encode_tokens(data.prompt)
    tokens = unwrap(raw_tokens, [])
    return TokenCountResponse(value=len(tokens), ids=tokens)


@router.get("/v1/info/version")
async def get_version():
    """Impersonate KAI United."""

    return {"result": "1.2.5"}


@router.get("/extra/version")
async def get_extra_version():
    """Impersonate KoboldCpp."""

    return {"result": "KoboldCpp", "version": "1.61"}
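
The streaming endpoint above emits server-sent events whose data payload is a JSON chunk shaped like {"token": "..."}. A hedged sketch of consuming it with plain requests (base URL and auth header are again assumptions):

# Rough sketch of consuming /api/extra/generate/stream as SSE with requests;
# the base URL and x-api-key header are deployment assumptions.
import json

import requests

with requests.post(
    "http://127.0.0.1:5000/api/extra/generate/stream",
    headers={"x-api-key": "your-api-key"},
    json={"prompt": "Once upon a time,", "max_length": 150, "genkey": "KCPP0002"},
    stream=True,
    timeout=300,
) as response:
    for line in response.iter_lines(decode_unicode=True):
        # Each SSE event carries its JSON payload on a "data:" line
        if line and line.startswith("data:"):
            chunk = json.loads(line.removeprefix("data:").strip())
            print(chunk["token"], end="", flush=True)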

endpoints/Kobold/types/generation.py Normal file

@@ -0,0 +1,53 @@
from typing import List, Optional

from pydantic import BaseModel, Field

from common.sampling import BaseSamplerRequest, get_default_sampler_value


class GenerateRequest(BaseSamplerRequest):
    prompt: str
    use_default_badwordsids: Optional[bool] = False
    genkey: Optional[str] = None
    max_length: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("max_tokens"),
        examples=[150],
    )
    rep_pen_range: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("penalty_range", -1),
    )
    rep_pen: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
    )

    def to_gen_params(self, **kwargs):
        # Swap Kobold generation params to OAI/Exl2 ones
        self.max_tokens = self.max_length
        self.repetition_penalty = self.rep_pen
        self.penalty_range = -1 if self.rep_pen_range == 0 else self.rep_pen_range

        return super().to_gen_params(**kwargs)


class GenerateResponseResult(BaseModel):
    text: str


class GenerateResponse(BaseModel):
    results: List[GenerateResponseResult] = Field(default_factory=list)


class StreamGenerateChunk(BaseModel):
    token: str


class AbortRequest(BaseModel):
    genkey: str


class AbortResponse(BaseModel):
    success: bool


class CheckGenerateRequest(BaseModel):
    genkey: str
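
to_gen_params is the translation layer between KAI's sampler names and the OAI/Exl2 ones. A standalone restatement of that mapping (the helper function below is hypothetical, written only to show the rule):

# Hypothetical helper restating GenerateRequest.to_gen_params() above.
def map_kobold_params(max_length: int, rep_pen: float, rep_pen_range: int) -> dict:
    return {
        "max_tokens": max_length,
        "repetition_penalty": rep_pen,
        # The diff maps a rep_pen_range of 0 to -1 (full-context penalty range)
        "penalty_range": -1 if rep_pen_range == 0 else rep_pen_range,
    }


print(map_kobold_params(max_length=150, rep_pen=1.1, rep_pen_range=0))
# {'max_tokens': 150, 'repetition_penalty': 1.1, 'penalty_range': -1}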

endpoints/Kobold/types/token.py Normal file

@@ -0,0 +1,15 @@
from pydantic import BaseModel
from typing import List


class TokenCountRequest(BaseModel):
    """Represents a KAI tokenization request."""

    prompt: str


class TokenCountResponse(BaseModel):
    """Represents a KAI tokenization response."""

    value: int
    ids: List[int]
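
These two models back the /api/extra/tokencount route shown in the router. A minimal request against it might look like this (URL, header, and the example output are assumptions):

# Hypothetical token-count request; base URL and auth header are assumed.
import requests

response = requests.post(
    "http://127.0.0.1:5000/api/extra/tokencount",
    headers={"x-api-key": "your-api-key"},
    json={"prompt": "Hello there!"},
    timeout=30,
)
print(response.json())  # e.g. {"value": 3, "ids": [15496, 612, 0]}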

endpoints/Kobold/utils/generation.py Normal file

@@ -0,0 +1,151 @@
import asyncio
from asyncio import CancelledError

from fastapi import HTTPException, Request
from loguru import logger
from sse_starlette import ServerSentEvent

from common import model
from common.networking import (
    get_generator_error,
    handle_request_disconnect,
    handle_request_error,
    request_disconnect_loop,
)
from common.utils import unwrap
from endpoints.Kobold.types.generation import (
    AbortResponse,
    GenerateRequest,
    GenerateResponse,
    GenerateResponseResult,
    StreamGenerateChunk,
)

generation_cache = {}


async def override_request_id(request: Request, data: GenerateRequest):
    """Overrides the request ID with a KAI genkey if present."""

    if data.genkey:
        request.state.id = data.genkey


def _create_response(text: str):
    results = [GenerateResponseResult(text=text)]
    return GenerateResponse(results=results)


def _create_stream_chunk(text: str):
    return StreamGenerateChunk(token=text)


async def _stream_collector(data: GenerateRequest, request: Request):
    """Common async generator for generation streams."""

    abort_event = asyncio.Event()
    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

    # Create a new entry in the cache
    generation_cache[data.genkey] = {"abort": abort_event, "text": ""}

    try:
        logger.info(f"Received Kobold generation request {data.genkey}")

        generator = model.container.generate_gen(
            data.prompt, data.genkey, abort_event, **data.to_gen_params()
        )
        async for generation in generator:
            if disconnect_task.done():
                abort_event.set()
                handle_request_disconnect(
                    f"Kobold generation {data.genkey} cancelled by user."
                )

            text = generation.get("text")

            # Update the generation cache with the new chunk
            if text:
                generation_cache[data.genkey]["text"] += text
                yield text

            if "finish_reason" in generation:
                logger.info(f"Finished streaming Kobold request {data.genkey}")
                break
    except CancelledError:
        # If the request disconnects, break out
        if not disconnect_task.done():
            abort_event.set()
            handle_request_disconnect(
                f"Kobold generation {data.genkey} cancelled by user."
            )
    finally:
        # Clean up the cache
        del generation_cache[data.genkey]


async def stream_generation(data: GenerateRequest, request: Request):
    """Wrapper for stream generations."""

    # If the genkey doesn't exist, set it to the request ID
    if not data.genkey:
        data.genkey = request.state.id

    try:
        async for chunk in _stream_collector(data, request):
            response = _create_stream_chunk(chunk)

            yield ServerSentEvent(
                event="message", data=response.model_dump_json(), sep="\n"
            )
    except Exception:
        yield get_generator_error(
            f"Kobold generation {data.genkey} aborted. "
            "Please check the server console."
        )


async def get_generation(data: GenerateRequest, request: Request):
    """Wrapper to get a static generation."""

    # If the genkey doesn't exist, set it to the request ID
    if not data.genkey:
        data.genkey = request.state.id

    try:
        full_text = ""
        async for chunk in _stream_collector(data, request):
            full_text += chunk

        response = _create_response(full_text)
        return response
    except Exception as exc:
        error_message = handle_request_error(
            f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
            "Please check the server console."
        ).error.message

        # Server error if there's a generation exception
        raise HTTPException(503, error_message) from exc


async def abort_generation(genkey: str):
    """Aborts a generation from the cache."""

    abort_event = unwrap(generation_cache.get(genkey), {}).get("abort")
    if abort_event:
        abort_event.set()
        handle_request_disconnect(f"Kobold generation {genkey} cancelled by user.")

    return AbortResponse(success=True)


async def generation_status(genkey: str):
    """Fetches the status of a generation from the cache."""

    current_text = unwrap(generation_cache.get(genkey), {}).get("text")
    if current_text:
        response = _create_response(current_text)
    else:
        response = GenerateResponse()

    return response
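
Taken together, the genkey ties the generation cache, the status check, and the abort event into one client-visible flow. A rough end-to-end sketch (host, port, header name, and the genkey value are assumptions):

# Hypothetical end-to-end flow: start a generation under a known genkey,
# poll its partial text, then abort it. URL and header are assumptions.
import threading
import time

import requests

BASE = "http://127.0.0.1:5000/api"
HEADERS = {"x-api-key": "your-api-key"}  # assumed auth header name
GENKEY = "KCPP4242"


def run_generation():
    # Blocks until the generation finishes or is aborted server-side
    requests.post(
        f"{BASE}/v1/generate",
        headers=HEADERS,
        json={"prompt": "Write a long story.", "genkey": GENKEY, "max_length": 512},
        timeout=600,
    )


thread = threading.Thread(target=run_generation)
thread.start()
time.sleep(2)

# Read the partial text accumulated in generation_cache for this genkey
status = requests.post(
    f"{BASE}/extra/generate/check", headers=HEADERS, json={"genkey": GENKEY}
)
print(status.json())  # {"results": [{"text": "...partial output..."}]}

# Set the abort event tied to this genkey; the generate call returns early
requests.post(f"{BASE}/extra/abort", headers=HEADERS, json={"genkey": GENKEY})
thread.join()
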

@@ -9,6 +9,7 @@ from common.logger import UVICORN_LOG_CONFIG
 from common.networking import get_global_depends
 from common.utils import unwrap
 from endpoints.core.router import router as CoreRouter
+from endpoints.Kobold.router import router as KoboldRouter
 from endpoints.OAI.router import router as OAIRouter
@@ -37,7 +38,7 @@ def setup_app():
     api_servers = unwrap(config.network_config().get("api_servers"), [])

     # Map for API id to server router
-    router_mapping = {"oai": OAIRouter}
+    router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}

     # Include the OAI api by default
     if api_servers:
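
For reference, the router selection this hunk feeds might be exercised like the standalone sketch below; the api_servers value and the loop body are inferred, since only the mapping line is shown in the diff:

# Standalone sketch of the router-selection logic in setup_app(); the router
# objects are stand-ins and the loop is inferred, not shown in the hunk above.
OAIRouter, KoboldRouter = object(), object()

api_servers = ["oai", "kobold"]  # hypothetical config value
router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}

for server_id in api_servers:
    router = router_mapping.get(server_id)
    if router:
        print(f"Including API server: {server_id}")
        # The real application would call app.include_router(router) here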