Endpoints: Add props endpoint and add more values to model params

The props endpoint is a standard used by llama.cpp APIs that returns
various properties of the loaded model to clients. It's still recommended
to use /v1/model to get all the parameters a TabbyAPI model has.

Also include the contents of the prompt template when fetching the current
model.

Signed-off-by: kingbri <8082010+bdashore3@users.noreply.github.com>
kingbri 2024-12-26 17:32:19 -05:00
parent fa8035ef72
commit 7878d351a7
3 changed files with 51 additions and 3 deletions
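
For reference, a minimal sketch of how a client might call the new endpoint. The base URL, port, and x-api-key header are assumptions for a local TabbyAPI instance; the response fields come from the ModelPropsResponse model added below.

# Sketch: querying /props from a client (host, port, and auth header are assumptions)
import requests

BASE_URL = "http://127.0.0.1:5000"
HEADERS = {"x-api-key": "<your-api-key>"}

response = requests.get(f"{BASE_URL}/props", headers=HEADERS)
response.raise_for_status()
props = response.json()

# Fields defined by ModelPropsResponse in this commit
print(props["total_slots"])                           # the loaded model's max_batch_size
print(props["default_generation_settings"]["n_ctx"])  # the loaded model's max_seq_len
print(props.get("chat_template", ""))                 # raw prompt template, empty if none is set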


@@ -486,16 +486,18 @@ class ExllamaV2Container:
"rope_scale": self.config.scale_pos_emb,
"rope_alpha": self.config.scale_alpha_value,
"max_seq_len": self.config.max_seq_len,
"max_batch_size": self.max_batch_size,
"cache_size": self.cache_size,
"cache_mode": self.cache_mode,
"chunk_size": self.config.max_input_len,
"num_experts_per_token": self.config.num_experts_per_token,
"prompt_template": self.prompt_template.name
if self.prompt_template
else None,
"use_vision": self.use_vision,
}
if self.prompt_template:
model_params["prompt_template"] = self.prompt_template.name
model_params["prompt_template_content"] = self.prompt_template.raw_template
if self.draft_config:
draft_model_params = {
"name": self.draft_model_dir.name,
@@ -759,6 +761,10 @@ class ExllamaV2Container:
max_batch_size=self.max_batch_size,
paged=self.paged,
)
# Update the state of the container var
if self.max_batch_size is None:
self.max_batch_size = self.generator.generator.max_batch_size
finally:
# This means the generator is being recreated
# The load lock is already released in the load function
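
The net effect of the first hunk in this file is that the template name and raw content are attached to the model parameters only when a template is actually loaded. A minimal sketch of that pattern in isolation (the PromptTemplate stand-in and the field values are illustrative, not taken from this diff):

# Illustrative stand-in for the container's prompt template object
class PromptTemplate:
    def __init__(self, name: str, raw_template: str):
        self.name = name
        self.raw_template = raw_template


def build_model_params(prompt_template):
    # Base parameters are always present (values are placeholders)
    model_params = {"max_seq_len": 4096, "max_batch_size": 1}

    # Template fields are added only when a template is loaded,
    # mirroring the conditional block added in this commit
    if prompt_template:
        model_params["prompt_template"] = prompt_template.name
        model_params["prompt_template_content"] = prompt_template.raw_template

    return model_params


print(build_model_params(None))  # no template keys at all
print(build_model_params(PromptTemplate("chatml", "{% for message in messages %}...")))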


@@ -23,9 +23,11 @@ from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadRespons
from endpoints.core.types.model import (
EmbeddingModelLoadRequest,
ModelCard,
ModelDefaultGenerationSettings,
ModelList,
ModelLoadRequest,
ModelLoadResponse,
ModelPropsResponse,
)
from endpoints.core.types.health import HealthCheckResponse
from endpoints.core.types.sampler_overrides import (
@@ -131,6 +133,30 @@ async def current_model() -> ModelCard:
return get_current_model()
@router.get(
"/props", dependencies=[Depends(check_api_key), Depends(check_model_container)]
)
async def model_props() -> ModelPropsResponse:
"""
Returns specific properties of a model for clients.
To get all properties, use /v1/model instead.
"""
current_model_card = get_current_model()
resp = ModelPropsResponse(
total_slots=current_model_card.parameters.max_batch_size,
default_generation_settings=ModelDefaultGenerationSettings(
n_ctx=current_model_card.parameters.max_seq_len,
),
)
if current_model_card.parameters.prompt_template_content:
resp.chat_template = current_model_card.parameters.prompt_template_content
return resp
@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models(request: Request) -> ModelList:
"""


@@ -16,10 +16,12 @@ class ModelCardParameters(BaseModel):
max_seq_len: Optional[int] = None
rope_scale: Optional[float] = 1.0
rope_alpha: Optional[float] = 1.0
max_batch_size: Optional[int] = 1
cache_size: Optional[int] = None
cache_mode: Optional[str] = "FP16"
chunk_size: Optional[int] = 2048
prompt_template: Optional[str] = None
prompt_template_content: Optional[str] = None
num_experts_per_token: Optional[int] = None
use_vision: Optional[bool] = False
@@ -139,3 +141,17 @@ class ModelLoadResponse(BaseModel):
module: int
modules: int
status: str
class ModelDefaultGenerationSettings(BaseModel):
"""Contains default generation settings for model props."""
n_ctx: int
class ModelPropsResponse(BaseModel):
"""Represents a model props response."""
total_slots: int = 1
chat_template: str = ""
default_generation_settings: ModelDefaultGenerationSettings
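
A minimal sketch of constructing and serializing the new response models on their own, the same way the /props handler does (assuming Pydantic v2's model_dump_json; the field values are illustrative):

# Sketch: building the props response outside the endpoint (values are illustrative)
from endpoints.core.types.model import (
    ModelDefaultGenerationSettings,
    ModelPropsResponse,
)

resp = ModelPropsResponse(
    total_slots=4,
    default_generation_settings=ModelDefaultGenerationSettings(n_ctx=8192),
)
resp.chat_template = "{{ bos_token }}..."  # only set when a template is loaded

# Assuming Pydantic v2; use .json() on Pydantic v1 instead
print(resp.model_dump_json())
# {"total_slots":4,"chat_template":"{{ bos_token }}...","default_generation_settings":{"n_ctx":8192}}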