API: Add KoboldAI server
Used for interacting with applications that use KoboldAI's API, such as Horde.

Signed-off-by: kingbri <bdashore3@proton.me>
Parent: 4e808cbed7
Commit: b7cb6f0b91
6 changed files with 330 additions and 3 deletions
@@ -828,10 +828,14 @@ class ExllamaV2Container:
         return dict(zip_longest(top_tokens, cleaned_values))
 
-    async def generate(self, prompt: str, request_id: str, **kwargs):
+    async def generate(
+        self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
+    ):
         """Generate a response to a prompt"""
 
         generations = []
-        async for generation in self.generate_gen(prompt, request_id, **kwargs):
+        async for generation in self.generate_gen(
+            prompt, request_id, abort_event, **kwargs
+        ):
             generations.append(generation)
 
         joined_generation = {
endpoints/Kobold/router.py (new file, 103 lines)
@@ -0,0 +1,103 @@
from sys import maxsize
from fastapi import APIRouter, Depends, Request
from sse_starlette import EventSourceResponse

from common import model
from common.auth import check_api_key
from common.model import check_model_container
from common.utils import unwrap
from endpoints.Kobold.types.generation import (
    AbortRequest,
    CheckGenerateRequest,
    GenerateRequest,
    GenerateResponse,
)
from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse
from endpoints.Kobold.utils.generation import (
    abort_generation,
    generation_status,
    get_generation,
    stream_generation,
)
from endpoints.core.utils.model import get_current_model


router = APIRouter(prefix="/api")


@router.post(
    "/v1/generate",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
    response = await get_generation(data, request)

    return response


@router.post(
    "/extra/generate/stream",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse:
    response = EventSourceResponse(stream_generation(data, request), ping=maxsize)

    return response


@router.post(
    "/extra/abort",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def abort_generate(data: AbortRequest):
    response = await abort_generation(data.genkey)

    return response


@router.get(
    "/extra/generate/check",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@router.post(
    "/extra/generate/check",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
    response = await generation_status(data.genkey)

    return response


@router.get(
    "/v1/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
)
async def current_model():
    """Fetches the current model and who owns it."""

    current_model_card = get_current_model()
    return {"result": f"{current_model_card.owned_by}/{current_model_card.id}"}


@router.post(
    "/extra/tokencount",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_tokencount(data: TokenCountRequest):
    raw_tokens = model.container.encode_tokens(data.prompt)
    tokens = unwrap(raw_tokens, [])
    return TokenCountResponse(value=len(tokens), ids=tokens)


@router.get("/v1/info/version")
async def get_version():
    """Impersonate KAI United."""

    return {"result": "1.2.5"}


@router.get("/extra/version")
async def get_extra_version():
    """Impersonate Koboldcpp."""

    return {"result": "KoboldCpp", "version": "1.61"}
endpoints/Kobold/types/generation.py (new file, 53 lines)
@@ -0,0 +1,53 @@
from typing import List, Optional

from pydantic import BaseModel, Field
from common.sampling import BaseSamplerRequest, get_default_sampler_value


class GenerateRequest(BaseSamplerRequest):
    prompt: str
    use_default_badwordsids: Optional[bool] = False
    genkey: Optional[str] = None

    max_length: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("max_tokens"),
        examples=[150],
    )
    rep_pen_range: Optional[int] = Field(
        default_factory=lambda: get_default_sampler_value("penalty_range", -1),
    )
    rep_pen: Optional[float] = Field(
        default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
    )

    def to_gen_params(self, **kwargs):
        # Swap Kobold generation params to OAI/Exl2 ones
        self.max_tokens = self.max_length
        self.repetition_penalty = self.rep_pen
        self.penalty_range = -1 if self.rep_pen_range == 0 else self.rep_pen_range

        return super().to_gen_params(**kwargs)


class GenerateResponseResult(BaseModel):
    text: str


class GenerateResponse(BaseModel):
    results: List[GenerateResponseResult] = Field(default_factory=list)


class StreamGenerateChunk(BaseModel):
    token: str


class AbortRequest(BaseModel):
    genkey: str


class AbortResponse(BaseModel):
    success: bool


class CheckGenerateRequest(BaseModel):
    genkey: str
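To illustrate the parameter remapping above, a small sketch of what a request translates to. The full set of keys returned by to_gen_params comes from BaseSamplerRequest and is not shown here.

req = GenerateRequest(prompt="Hello", max_length=200, rep_pen=1.1, rep_pen_range=0)
params = req.to_gen_params()

# After the swap, the OAI/ExLlamaV2-style fields carry the Kobold values:
#   max_tokens         <- max_length     (200)
#   repetition_penalty <- rep_pen        (1.1)
#   penalty_range      <- rep_pen_range  (a value of 0 is remapped to -1)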
endpoints/Kobold/types/token.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from pydantic import BaseModel
from typing import List


class TokenCountRequest(BaseModel):
    """Represents a KAI tokenization request."""

    prompt: str


class TokenCountResponse(BaseModel):
    """Represents a KAI tokenization response."""

    value: int
    ids: List[int]
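A matching call to the /extra/tokencount route would look roughly like this (same caveats as the earlier sketch: host, port, and auth header are assumptions):

import requests

resp = requests.post(
    "http://127.0.0.1:5000/api/extra/tokencount",   # assumed host/port
    json={"prompt": "Hello world"},
    headers={"x-api-key": "<your key>"},            # assumed auth header
)
print(resp.json())   # {"value": <token count>, "ids": [<token ids>]}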
endpoints/Kobold/utils/generation.py (new file, 151 lines)
@@ -0,0 +1,151 @@
import asyncio
from asyncio import CancelledError
from fastapi import HTTPException, Request
from loguru import logger
from sse_starlette import ServerSentEvent

from common import model
from common.networking import (
    get_generator_error,
    handle_request_disconnect,
    handle_request_error,
    request_disconnect_loop,
)
from common.utils import unwrap
from endpoints.Kobold.types.generation import (
    AbortResponse,
    GenerateRequest,
    GenerateResponse,
    GenerateResponseResult,
    StreamGenerateChunk,
)


generation_cache = {}


async def override_request_id(request: Request, data: GenerateRequest):
    """Overrides the request ID with a KAI genkey if present."""

    if data.genkey:
        request.state.id = data.genkey


def _create_response(text: str):
    results = [GenerateResponseResult(text=text)]
    return GenerateResponse(results=results)


def _create_stream_chunk(text: str):
    return StreamGenerateChunk(token=text)


async def _stream_collector(data: GenerateRequest, request: Request):
    """Common async generator for generation streams."""

    abort_event = asyncio.Event()
    disconnect_task = asyncio.create_task(request_disconnect_loop(request))

    # Create a new entry in the cache
    generation_cache[data.genkey] = {"abort": abort_event, "text": ""}

    try:
        logger.info(f"Received Kobold generation request {data.genkey}")

        generator = model.container.generate_gen(
            data.prompt, data.genkey, abort_event, **data.to_gen_params()
        )
        async for generation in generator:
            if disconnect_task.done():
                abort_event.set()
                handle_request_disconnect(
                    f"Kobold generation {data.genkey} cancelled by user."
                )

            text = generation.get("text")

            # Update the generation cache with the new chunk
            if text:
                generation_cache[data.genkey]["text"] += text
                yield text

            if "finish_reason" in generation:
                logger.info(f"Finished streaming Kobold request {data.genkey}")
                break
    except CancelledError:
        # If the request disconnects, break out
        if not disconnect_task.done():
            abort_event.set()
            handle_request_disconnect(
                f"Kobold generation {data.genkey} cancelled by user."
            )
    finally:
        # Cleanup the cache
        del generation_cache[data.genkey]


async def stream_generation(data: GenerateRequest, request: Request):
    """Wrapper for stream generations."""

    # If the genkey doesn't exist, set it to the request ID
    if not data.genkey:
        data.genkey = request.state.id

    try:
        async for chunk in _stream_collector(data, request):
            response = _create_stream_chunk(chunk)
            yield ServerSentEvent(
                event="message", data=response.model_dump_json(), sep="\n"
            )
    except Exception:
        yield get_generator_error(
            f"Kobold generation {data.genkey} aborted. "
            "Please check the server console."
        )


async def get_generation(data: GenerateRequest, request: Request):
    """Wrapper to get a static generation."""

    # If the genkey doesn't exist, set it to the request ID
    if not data.genkey:
        data.genkey = request.state.id

    try:
        full_text = ""
        async for chunk in _stream_collector(data, request):
            full_text += chunk

        response = _create_response(full_text)
        return response
    except Exception as exc:
        error_message = handle_request_error(
            f"Completion {request.state.id} aborted. Maybe the model was unloaded? "
            "Please check the server console."
        ).error.message

        # Server error if there's a generation exception
        raise HTTPException(503, error_message) from exc


async def abort_generation(genkey: str):
    """Aborts a generation from the cache."""

    abort_event = unwrap(generation_cache.get(genkey), {}).get("abort")
    if abort_event:
        abort_event.set()
        handle_request_disconnect(f"Kobold generation {genkey} cancelled by user.")

    return AbortResponse(success=True)


async def generation_status(genkey: str):
    """Fetches the status of a generation from the cache."""

    current_text = unwrap(generation_cache.get(genkey), {}).get("text")
    if current_text:
        response = _create_response(current_text)
    else:
        response = GenerateResponse()

    return response
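To round out the picture, a rough sketch of consuming the /extra/generate/stream SSE endpoint from a client. The host, port, and auth header are assumptions, and a real client would normally use a dedicated SSE library instead of parsing lines by hand.

import json
import requests

payload = {"prompt": "Once upon a time", "max_length": 100, "genkey": "KCPP0002"}
with requests.post(
    "http://127.0.0.1:5000/api/extra/generate/stream",   # assumed host/port
    json=payload,
    headers={"x-api-key": "<your key>"},                  # assumed auth header
    stream=True,
) as resp:
    for line in resp.iter_lines():
        # Each "message" event carries a JSON body shaped like {"token": "..."}
        if line.startswith(b"data:"):
            chunk = json.loads(line[len(b"data:"):])
            print(chunk["token"], end="", flush=True)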
@@ -9,6 +9,7 @@ from common.logger import UVICORN_LOG_CONFIG
 from common.networking import get_global_depends
 from common.utils import unwrap
 from endpoints.core.router import router as CoreRouter
+from endpoints.Kobold.router import router as KoboldRouter
 from endpoints.OAI.router import router as OAIRouter
@@ -37,7 +38,7 @@ def setup_app():
     api_servers = unwrap(config.network_config().get("api_servers"), [])
 
     # Map for API id to server router
-    router_mapping = {"oai": OAIRouter}
+    router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
 
     # Include the OAI api by default
     if api_servers: