Kobold: Add extra routes for horde compatability

Needed to connect to horde. Also do some reordering to clean the
router file up.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-07-26 22:50:01 -04:00
parent 2773517a16
commit 3038f668e8

View file

@ -6,12 +6,15 @@ from common import model
from common.auth import check_api_key
from common.model import check_model_container
from common.utils import unwrap
from endpoints.core.utils.model import get_current_model
from endpoints.Kobold.types.generation import (
AbortRequest,
AbortResponse,
CheckGenerateRequest,
GenerateRequest,
GenerateResponse,
)
from endpoints.Kobold.types.model import CurrentModelResponse, MaxLengthResponse
from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse
from endpoints.Kobold.utils.generation import (
abort_generation,
@ -19,7 +22,6 @@ from endpoints.Kobold.utils.generation import (
get_generation,
stream_generation,
)
from endpoints.core.utils.model import get_current_model
api_name = "KoboldAI"
@ -65,7 +67,7 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateRe
"/abort",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def abort_generate(data: AbortRequest):
async def abort_generate(data: AbortRequest) -> AbortResponse:
response = await abort_generation(data.genkey)
return response
@ -88,7 +90,7 @@ async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
@kai_router.get(
"/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
)
async def current_model():
async def current_model() -> CurrentModelResponse:
"""Fetches the current model and who owns it."""
current_model_card = get_current_model()
@ -99,12 +101,31 @@ async def current_model():
"/tokencount",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_tokencount(data: TokenCountRequest):
async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
raw_tokens = model.container.encode_tokens(data.prompt)
tokens = unwrap(raw_tokens, [])
return TokenCountResponse(value=len(tokens), ids=tokens)
@kai_router.get(
"/config/max_length",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@kai_router.get(
"/config/max_context_length",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@extra_kai_router.get(
"/true_max_context_length",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def get_max_length() -> MaxLengthResponse:
"""Fetches the max length of the model."""
max_length = model.container.get_model_parameters().get("max_seq_len")
return {"value": max_length}
@kai_router.get("/info/version")
async def get_version():
"""Impersonate KAI United."""
@ -117,3 +138,24 @@ async def get_extra_version():
"""Impersonate Koboldcpp."""
return {"result": "KoboldCpp", "version": "1.61"}
@kai_router.get("/config/soft_prompts_list")
async def get_available_softprompts():
"""Used for KAI compliance."""
return {"values": []}
@kai_router.get("/config/soft_prompt")
async def get_current_softprompt():
"""Used for KAI compliance."""
return {"value": ""}
@kai_router.put("/config/soft_prompt")
async def set_current_softprompt():
"""Used for KAI compliance."""
return {}