The model card is a unified structure for sharing model params. Rather than kwargs, use this instead. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
161 lines
4.3 KiB
Python
161 lines
4.3 KiB
Python
from sys import maxsize
|
|
from fastapi import APIRouter, Depends, Request
|
|
from sse_starlette import EventSourceResponse
|
|
|
|
from common import model
|
|
from common.auth import check_api_key
|
|
from common.model import check_model_container
|
|
from common.utils import unwrap
|
|
from endpoints.core.utils.model import get_current_model
|
|
from endpoints.Kobold.types.generation import (
|
|
AbortRequest,
|
|
AbortResponse,
|
|
CheckGenerateRequest,
|
|
GenerateRequest,
|
|
GenerateResponse,
|
|
)
|
|
from endpoints.Kobold.types.model import CurrentModelResponse, MaxLengthResponse
|
|
from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse
|
|
from endpoints.Kobold.utils.generation import (
|
|
abort_generation,
|
|
generation_status,
|
|
get_generation,
|
|
stream_generation,
|
|
)
|
|
|
|
|
|
# Human-readable name of the API flavor this router emulates
api_name = "KoboldAI"
# All Kobold-compatible routes are served under the /api prefix
router = APIRouter(prefix="/api")
# Advertised endpoint URL templates; {host}/{port} placeholders are
# presumably formatted by the server on startup — TODO confirm caller
urls = {
    "Generation": "http://{host}:{port}/api/v1/generate",
    "Streaming": "http://{host}:{port}/api/extra/generate/stream",
}

# Core KoboldAI endpoints (mounted under /v1 and mirrored at /latest in setup())
kai_router = APIRouter()
# Koboldcpp "extra" endpoints (mounted under /extra in setup())
extra_kai_router = APIRouter()
|
|
|
|
|
|
def setup():
    """Attach the Kobold sub-routers to the main /api router and return it.

    /v1 is the canonical prefix; /latest mirrors it but stays out of the
    generated OpenAPI schema. The Koboldcpp extras live under /extra.
    """

    for prefix, visible in (("/v1", True), ("/latest", False)):
        router.include_router(kai_router, prefix=prefix, include_in_schema=visible)

    router.include_router(extra_kai_router, prefix="/extra")

    return router
|
|
|
|
|
|
@kai_router.post(
    "/generate",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate(request: Request, data: GenerateRequest) -> GenerateResponse:
    """Run a blocking (non-streamed) generation and return the full result."""

    return await get_generation(data, request)
|
|
|
|
|
|
@extra_kai_router.post(
    "/generate/stream",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def generate_stream(request: Request, data: GenerateRequest) -> GenerateResponse:
    """Stream generation tokens back to the client over SSE.

    NOTE(review): the actual return is an EventSourceResponse; the
    GenerateResponse annotation appears to exist only for the OpenAPI
    schema — confirm before changing.
    """

    # ping=maxsize makes the keepalive ping interval practically infinite
    sse_response = EventSourceResponse(stream_generation(data, request), ping=maxsize)

    return sse_response
|
|
|
|
|
|
@extra_kai_router.post(
    "/abort",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def abort_generate(data: AbortRequest) -> AbortResponse:
    """Cancel an in-flight generation identified by its genkey."""

    return await abort_generation(data.genkey)
|
|
|
|
|
|
@extra_kai_router.get(
    "/generate/check",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@extra_kai_router.post(
    "/generate/check",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
    """Poll the status of an in-flight generation (exposed as GET and POST)."""

    return await generation_status(data.genkey)
|
|
|
|
|
|
@kai_router.get(
    "/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
)
async def current_model() -> CurrentModelResponse:
    """Report the loaded model as "<owner>/<id>"."""

    card = get_current_model()

    return {"result": f"{card.owned_by}/{card.id}"}
|
|
|
|
|
|
@extra_kai_router.post(
    "/tokencount",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
    """Tokenize the prompt and return both the count and the token ids."""

    # encode_tokens may yield None; treat that as an empty token list
    token_ids = unwrap(model.container.encode_tokens(data.prompt), [])

    return TokenCountResponse(value=len(token_ids), ids=token_ids)
|
|
|
|
|
|
@kai_router.get(
    "/config/max_length",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@kai_router.get(
    "/config/max_context_length",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@extra_kai_router.get(
    "/true_max_context_length",
    dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_max_length() -> MaxLengthResponse:
    """Report the model's maximum sequence length (served on three routes)."""

    params = model.container.model_info().parameters

    return {"value": params.max_seq_len}
|
|
|
|
|
|
@kai_router.get("/info/version")
async def get_version():
    """Impersonate KAI United."""

    # Fixed version string expected by KAI-compatible clients
    version_payload = {"result": "1.2.5"}

    return version_payload
|
|
|
|
|
|
@extra_kai_router.get("/version")
async def get_extra_version():
    """Impersonate Koboldcpp."""

    # Fixed identity/version pair expected by Koboldcpp-aware clients
    identity = {"result": "KoboldCpp", "version": "1.74"}

    return identity
|
|
|
|
|
|
@kai_router.get("/config/soft_prompts_list")
async def get_available_softprompts():
    """KAI-compliance stub: soft prompts are unsupported, so none are listed."""

    empty_list_payload = {"values": []}

    return empty_list_payload
|
|
|
|
|
|
@kai_router.get("/config/soft_prompt")
async def get_current_softprompt():
    """KAI-compliance stub: always reports no active soft prompt."""

    no_softprompt = {"value": ""}

    return no_softprompt
|
|
|
|
|
|
@kai_router.put("/config/soft_prompt")
async def set_current_softprompt():
    """KAI-compliance stub: accepts the request but sets nothing."""

    return {}
|