From b7cb6f0b91c9299ca447107cac82f5dc65ca4d59 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Fri, 26 Jul 2024 16:37:30 -0400
Subject: [PATCH] API: Add KoboldAI server

Used for interacting with applications that use KoboldAI's API, such as
AI Horde.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py          |   8 +-
 endpoints/Kobold/router.py           | 103 ++++++++++++++++++
 endpoints/Kobold/types/generation.py |  53 ++++++++++
 endpoints/Kobold/types/token.py      |  15 +++
 endpoints/Kobold/utils/generation.py | 151 +++++++++++++++++++++++++++
 endpoints/server.py                  |   3 +-
 6 files changed, 330 insertions(+), 3 deletions(-)
 create mode 100644 endpoints/Kobold/router.py
 create mode 100644 endpoints/Kobold/types/generation.py
 create mode 100644 endpoints/Kobold/types/token.py
 create mode 100644 endpoints/Kobold/utils/generation.py

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 200be6b..3515123 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -828,10 +828,14 @@ class ExllamaV2Container:
 
         return dict(zip_longest(top_tokens, cleaned_values))
 
-    async def generate(self, prompt: str, request_id: str, **kwargs):
+    async def generate(
+        self, prompt: str, request_id: str, abort_event: asyncio.Event = None, **kwargs
+    ):
         """Generate a response to a prompt"""
         generations = []
-        async for generation in self.generate_gen(prompt, request_id, **kwargs):
+        async for generation in self.generate_gen(
+            prompt, request_id, abort_event, **kwargs
+        ):
             generations.append(generation)
 
         joined_generation = {
diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py
new file mode 100644
index 0000000..c3265fc
--- /dev/null
+++ b/endpoints/Kobold/router.py
@@ -0,0 +1,103 @@
+from sys import maxsize
+from fastapi import APIRouter, Depends, Request
+from sse_starlette import EventSourceResponse
+
+from common import model
+from common.auth import check_api_key
+from common.model import check_model_container
+from common.utils import unwrap
+from endpoints.Kobold.types.generation import (
+    AbortRequest,
+    CheckGenerateRequest,
+    GenerateRequest,
+    GenerateResponse,
+)
+from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse
+from endpoints.Kobold.utils.generation import (
+    abort_generation,
+    generation_status,
+    get_generation,
+    stream_generation,
+)
+from endpoints.core.utils.model import get_current_model
+
+
+router = APIRouter(prefix="/api")
+
+
+@router.post(
+    "/v1/generate",
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+)
+async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
+    response = await get_generation(data, request)
+
+    return response
+
+
+@router.post(
+    "/extra/generate/stream",
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+)
+async def generate_stream(request: Request, data: GenerateRequest):
+    response = EventSourceResponse(stream_generation(data, request), ping=maxsize)
+
+    return response
+
+
+@router.post(
+    "/extra/abort",
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+)
+async def abort_generate(data: AbortRequest):
+    response = await abort_generation(data.genkey)
+
+    return response
+
+
+@router.get(
+    "/extra/generate/check",
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+)
+@router.post(
+    "/extra/generate/check",
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+)
+async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
+    response = await generation_status(data.genkey)
+
+    return response
+
+
+@router.get(
+    "/v1/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
+)
+async def current_model():
+    """Fetches the current model and who owns it."""
+
+    current_model_card = get_current_model()
+    return {"result": f"{current_model_card.owned_by}/{current_model_card.id}"}
+
+
+@router.post(
+    "/extra/tokencount",
+    dependencies=[Depends(check_api_key), Depends(check_model_container)],
+)
+async def get_tokencount(data: TokenCountRequest):
+    raw_tokens = model.container.encode_tokens(data.prompt)
+    tokens = unwrap(raw_tokens, [])
+    return TokenCountResponse(value=len(tokens), ids=tokens)
+
+
+@router.get("/v1/info/version")
+async def get_version():
+    """Impersonate KAI United."""
+
+    return {"result": "1.2.5"}
+
+
+@router.get("/extra/version")
+async def get_extra_version():
+    """Impersonate Koboldcpp."""
+
+    return {"result": "KoboldCpp", "version": "1.61"}
diff --git a/endpoints/Kobold/types/generation.py b/endpoints/Kobold/types/generation.py
new file mode 100644
index 0000000..5468741
--- /dev/null
+++ b/endpoints/Kobold/types/generation.py
@@ -0,0 +1,53 @@
+from typing import List, Optional
+
+from pydantic import BaseModel, Field
+from common.sampling import BaseSamplerRequest, get_default_sampler_value
+
+
+class GenerateRequest(BaseSamplerRequest):
+    prompt: str
+    use_default_badwordsids: Optional[bool] = False
+    genkey: Optional[str] = None
+
+    max_length: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("max_tokens"),
+        examples=[150],
+    )
+    rep_pen_range: Optional[int] = Field(
+        default_factory=lambda: get_default_sampler_value("penalty_range", -1),
+    )
+    rep_pen: Optional[float] = Field(
+        default_factory=lambda: get_default_sampler_value("repetition_penalty", 1.0),
+    )
+
+    def to_gen_params(self, **kwargs):
+        # Swap kobold generation params to OAI/Exl2 ones
+        self.max_tokens = self.max_length
+        self.repetition_penalty = self.rep_pen
+        self.penalty_range = -1 if self.rep_pen_range == 0 else self.rep_pen_range
+
+        return super().to_gen_params(**kwargs)
+
+
+class GenerateResponseResult(BaseModel):
+    text: str
+
+
+class GenerateResponse(BaseModel):
+    results: List[GenerateResponseResult] = Field(default_factory=list)
+
+
+class StreamGenerateChunk(BaseModel):
+    token: str
+
+
+class AbortRequest(BaseModel):
+    genkey: str
+
+
+class AbortResponse(BaseModel):
+    success: bool
+
+
+class CheckGenerateRequest(BaseModel):
+    genkey: str
diff --git a/endpoints/Kobold/types/token.py b/endpoints/Kobold/types/token.py
new file mode 100644
index 0000000..e6639d9
--- /dev/null
+++ b/endpoints/Kobold/types/token.py
@@ -0,0 +1,15 @@
+from pydantic import BaseModel
+from typing import List
+
+
+class TokenCountRequest(BaseModel):
+    """Represents a KAI tokenization request."""
+
+    prompt: str
+
+
+class TokenCountResponse(BaseModel):
+    """Represents a KAI tokenization response."""
+
+    value: int
+    ids: List[int]
diff --git a/endpoints/Kobold/utils/generation.py b/endpoints/Kobold/utils/generation.py
new file mode 100644
index 0000000..5febcff
--- /dev/null
+++ b/endpoints/Kobold/utils/generation.py
@@ -0,0 +1,151 @@
+import asyncio
+from asyncio import CancelledError
+from fastapi import HTTPException, Request
+from loguru import logger
+from sse_starlette import ServerSentEvent
+
+from common import model
+from common.networking import (
+    get_generator_error,
+    handle_request_disconnect,
+    handle_request_error,
+    request_disconnect_loop,
+)
+from common.utils import unwrap
+from endpoints.Kobold.types.generation import (
+    AbortResponse,
+    GenerateRequest,
+    GenerateResponse,
+    GenerateResponseResult,
+    StreamGenerateChunk,
+)
+
+
+generation_cache = {}
+
+
+async def override_request_id(request: Request, data: GenerateRequest):
+    """Overrides the request ID with a KAI genkey if present."""
+
+    if data.genkey:
+        request.state.id = data.genkey
+
+
+def _create_response(text: str):
+    results = [GenerateResponseResult(text=text)]
+    return GenerateResponse(results=results)
+
+
+def _create_stream_chunk(text: str):
+    return StreamGenerateChunk(token=text)
+
+
+async def _stream_collector(data: GenerateRequest, request: Request):
+    """Common async generator for generation streams."""
+
+    abort_event = asyncio.Event()
+    disconnect_task = asyncio.create_task(request_disconnect_loop(request))
+
+    # Create a new entry in the cache
+    generation_cache[data.genkey] = {"abort": abort_event, "text": ""}
+
+    try:
+        logger.info(f"Received Kobold generation request {data.genkey}")
+
+        generator = model.container.generate_gen(
+            data.prompt, data.genkey, abort_event, **data.to_gen_params()
+        )
+        async for generation in generator:
+            if disconnect_task.done():
+                abort_event.set()
+                handle_request_disconnect(
+                    f"Kobold generation {data.genkey} cancelled by user."
+                )
+
+            text = generation.get("text")
+
+            # Update the generation cache with the new chunk
+            if text:
+                generation_cache[data.genkey]["text"] += text
+                yield text
+
+            if "finish_reason" in generation:
+                logger.info(f"Finished streaming Kobold request {data.genkey}")
+                break
+    except CancelledError:
+        # Abort the ongoing generation if the request gets cancelled
+        if not disconnect_task.done():
+            abort_event.set()
+            handle_request_disconnect(
+                f"Kobold generation {data.genkey} cancelled by user."
+            )
+    finally:
+        # Cleanup the cache
+        del generation_cache[data.genkey]
+
+
+async def stream_generation(data: GenerateRequest, request: Request):
+    """Wrapper for stream generations."""
+
+    # If the genkey doesn't exist, set it to the request ID
+    if not data.genkey:
+        data.genkey = request.state.id
+
+    try:
+        async for chunk in _stream_collector(data, request):
+            response = _create_stream_chunk(chunk)
+            yield ServerSentEvent(
+                event="message", data=response.model_dump_json(), sep="\n"
+            )
+    except Exception:
+        yield get_generator_error(
+            f"Kobold generation {data.genkey} aborted. "
+            "Please check the server console."
+        )
+
+
+async def get_generation(data: GenerateRequest, request: Request):
+    """Wrapper to get a static generation."""
+
+    # If the genkey doesn't exist, set it to the request ID
+    if not data.genkey:
+        data.genkey = request.state.id
+
+    try:
+        full_text = ""
+        async for chunk in _stream_collector(data, request):
+            full_text += chunk
+
+        response = _create_response(full_text)
+        return response
+    except Exception as exc:
+        error_message = handle_request_error(
+            f"Kobold generation {data.genkey} aborted. Maybe the model was unloaded? "
+            "Please check the server console."
+        ).error.message
+
+        # Server error if there's a generation exception
+        raise HTTPException(503, error_message) from exc
+
+
+async def abort_generation(genkey: str):
+    """Aborts a generation from the cache."""
+
+    abort_event = unwrap(generation_cache.get(genkey), {}).get("abort")
+    if abort_event:
+        abort_event.set()
+        handle_request_disconnect(f"Kobold generation {genkey} cancelled by user.")
+
+    return AbortResponse(success=True)
+
+
+async def generation_status(genkey: str):
+    """Fetches the status of a generation from the cache."""
+
+    current_text = unwrap(generation_cache.get(genkey), {}).get("text")
+    if current_text:
+        response = _create_response(current_text)
+    else:
+        response = GenerateResponse()
+
+    return response
diff --git a/endpoints/server.py b/endpoints/server.py
index 401b211..dfb1cdd 100644
--- a/endpoints/server.py
+++ b/endpoints/server.py
@@ -9,6 +9,7 @@ from common.logger import UVICORN_LOG_CONFIG
 from common.networking import get_global_depends
 from common.utils import unwrap
 from endpoints.core.router import router as CoreRouter
+from endpoints.Kobold.router import router as KoboldRouter
 from endpoints.OAI.router import router as OAIRouter
 
 
@@ -37,7 +38,7 @@ def setup_app():
     api_servers = unwrap(config.network_config().get("api_servers"), [])
 
     # Map for API id to server router
-    router_mapping = {"oai": OAIRouter}
+    router_mapping = {"oai": OAIRouter, "kobold": KoboldRouter}
 
     # Include the OAI api by default
     if api_servers:
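
Usage sketch (illustrative, not part of the patch): per the endpoints/server.py
hunk above, the new server is enabled by adding "kobold" to api_servers in the
network config. The snippet below drives the new endpoints directly; the base
URL, port 5000, and the "x-api-key" auth header are assumptions for
illustration, so adjust them to your deployment. Note how the genkey doubles
as the request ID and cache key, which is what lets the /extra endpoints find
and cancel a run.

    import requests

    BASE_URL = "http://127.0.0.1:5000/api"
    HEADERS = {"x-api-key": "<your-api-key>"}  # assumed auth header

    # Blocking generation: POST /api/v1/generate
    # returns {"results": [{"text": ...}]}
    payload = {"prompt": "Once upon a time", "max_length": 150, "genkey": "KCPP0001"}
    result = requests.post(f"{BASE_URL}/v1/generate", json=payload, headers=HEADERS)
    print(result.json()["results"][0]["text"])

    # Poll partial text for an in-flight genkey: POST /api/extra/generate/check
    # (empty "results" if nothing is cached under that genkey)
    check = requests.post(
        f"{BASE_URL}/extra/generate/check",
        json={"genkey": "KCPP0001"},
        headers=HEADERS,
    )
    print(check.json()["results"])

    # Cancel an in-flight generation: POST /api/extra/abort
    requests.post(
        f"{BASE_URL}/extra/abort", json={"genkey": "KCPP0001"}, headers=HEADERS
    )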