Kobold: Add extra routes for horde compatability
Needed to connect to horde. Also do some reordering to clean the router file up. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
2773517a16
commit
3038f668e8
1 changed files with 46 additions and 4 deletions
|
|
@ -6,12 +6,15 @@ from common import model
|
|||
from common.auth import check_api_key
|
||||
from common.model import check_model_container
|
||||
from common.utils import unwrap
|
||||
from endpoints.core.utils.model import get_current_model
|
||||
from endpoints.Kobold.types.generation import (
|
||||
AbortRequest,
|
||||
AbortResponse,
|
||||
CheckGenerateRequest,
|
||||
GenerateRequest,
|
||||
GenerateResponse,
|
||||
)
|
||||
from endpoints.Kobold.types.model import CurrentModelResponse, MaxLengthResponse
|
||||
from endpoints.Kobold.types.token import TokenCountRequest, TokenCountResponse
|
||||
from endpoints.Kobold.utils.generation import (
|
||||
abort_generation,
|
||||
|
|
@ -19,7 +22,6 @@ from endpoints.Kobold.utils.generation import (
|
|||
get_generation,
|
||||
stream_generation,
|
||||
)
|
||||
from endpoints.core.utils.model import get_current_model
|
||||
|
||||
|
||||
api_name = "KoboldAI"
|
||||
|
|
@ -65,7 +67,7 @@ async def generate_stream(request: Request, data: GenerateRequest) -> GenerateRe
|
|||
"/abort",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def abort_generate(data: AbortRequest):
|
||||
async def abort_generate(data: AbortRequest) -> AbortResponse:
|
||||
response = await abort_generation(data.genkey)
|
||||
|
||||
return response
|
||||
|
|
@ -88,7 +90,7 @@ async def check_generate(data: CheckGenerateRequest) -> GenerateResponse:
|
|||
@kai_router.get(
|
||||
"/model", dependencies=[Depends(check_api_key), Depends(check_model_container)]
|
||||
)
|
||||
async def current_model():
|
||||
async def current_model() -> CurrentModelResponse:
|
||||
"""Fetches the current model and who owns it."""
|
||||
|
||||
current_model_card = get_current_model()
|
||||
|
|
@ -99,12 +101,31 @@ async def current_model():
|
|||
"/tokencount",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def get_tokencount(data: TokenCountRequest):
|
||||
async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
|
||||
raw_tokens = model.container.encode_tokens(data.prompt)
|
||||
tokens = unwrap(raw_tokens, [])
|
||||
return TokenCountResponse(value=len(tokens), ids=tokens)
|
||||
|
||||
|
||||
@kai_router.get(
|
||||
"/config/max_length",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
@kai_router.get(
|
||||
"/config/max_context_length",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
@extra_kai_router.get(
|
||||
"/true_max_context_length",
|
||||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def get_max_length() -> MaxLengthResponse:
|
||||
"""Fetches the max length of the model."""
|
||||
|
||||
max_length = model.container.get_model_parameters().get("max_seq_len")
|
||||
return {"value": max_length}
|
||||
|
||||
|
||||
@kai_router.get("/info/version")
|
||||
async def get_version():
|
||||
"""Impersonate KAI United."""
|
||||
|
|
@ -117,3 +138,24 @@ async def get_extra_version():
|
|||
"""Impersonate Koboldcpp."""
|
||||
|
||||
return {"result": "KoboldCpp", "version": "1.61"}
|
||||
|
||||
|
||||
@kai_router.get("/config/soft_prompts_list")
|
||||
async def get_available_softprompts():
|
||||
"""Used for KAI compliance."""
|
||||
|
||||
return {"values": []}
|
||||
|
||||
|
||||
@kai_router.get("/config/soft_prompt")
|
||||
async def get_current_softprompt():
|
||||
"""Used for KAI compliance."""
|
||||
|
||||
return {"value": ""}
|
||||
|
||||
|
||||
@kai_router.put("/config/soft_prompt")
|
||||
async def set_current_softprompt():
|
||||
"""Used for KAI compliance."""
|
||||
|
||||
return {}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue