From 7878d351a71d51bb09f3829d1f40f7748995ad55 Mon Sep 17 00:00:00 2001
From: kingbri <8082010+bdashore3@users.noreply.github.com>
Date: Thu, 26 Dec 2024 17:32:19 -0500
Subject: [PATCH] Endpoints: Add props endpoint and add more values to model params

The props endpoint is a standard llama.cpp server API endpoint that
returns various properties of the loaded model to clients. It's still
recommended to use /v1/model to get all the parameters a TabbyAPI
model has.

Also include the contents of the prompt template when fetching the
current model.

Signed-off-by: kingbri <8082010+bdashore3@users.noreply.github.com>
---
 backends/exllamav2/model.py   | 12 +++++++++---
 endpoints/core/router.py      | 26 ++++++++++++++++++++++++++
 endpoints/core/types/model.py | 16 ++++++++++++++++
 3 files changed, 51 insertions(+), 3 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 50cef42..2beca23 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -486,16 +486,18 @@ class ExllamaV2Container:
             "rope_scale": self.config.scale_pos_emb,
             "rope_alpha": self.config.scale_alpha_value,
             "max_seq_len": self.config.max_seq_len,
+            "max_batch_size": self.max_batch_size,
             "cache_size": self.cache_size,
             "cache_mode": self.cache_mode,
             "chunk_size": self.config.max_input_len,
             "num_experts_per_token": self.config.num_experts_per_token,
-            "prompt_template": self.prompt_template.name
-            if self.prompt_template
-            else None,
             "use_vision": self.use_vision,
         }
 
+        if self.prompt_template:
+            model_params["prompt_template"] = self.prompt_template.name
+            model_params["prompt_template_content"] = self.prompt_template.raw_template
+
         if self.draft_config:
             draft_model_params = {
                 "name": self.draft_model_dir.name,
@@ -759,6 +761,10 @@ class ExllamaV2Container:
                 max_batch_size=self.max_batch_size,
                 paged=self.paged,
             )
+
+            # Update the state of the container var
+            if self.max_batch_size is None:
+                self.max_batch_size = self.generator.generator.max_batch_size
         finally:
             # This means the generator is being recreated
             # The load lock is already released in the load function
diff --git a/endpoints/core/router.py b/endpoints/core/router.py
index 6b48182..6217475 100644
--- a/endpoints/core/router.py
+++ b/endpoints/core/router.py
@@ -23,9 +23,11 @@ from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadRespons
 from endpoints.core.types.model import (
     EmbeddingModelLoadRequest,
     ModelCard,
+    ModelDefaultGenerationSettings,
     ModelList,
     ModelLoadRequest,
     ModelLoadResponse,
+    ModelPropsResponse,
 )
 from endpoints.core.types.health import HealthCheckResponse
 from endpoints.core.types.sampler_overrides import (
@@ -131,6 +133,30 @@ async def current_model() -> ModelCard:
     return get_current_model()
 
 
+@router.get(
+    "/props", dependencies=[Depends(check_api_key), Depends(check_model_container)]
+)
+async def model_props() -> ModelPropsResponse:
+    """
+    Returns specific properties of a model for clients.
+
+    To get all properties, use /v1/model instead.
+    """
+
+    current_model_card = get_current_model()
+    resp = ModelPropsResponse(
+        total_slots=current_model_card.parameters.max_batch_size,
+        default_generation_settings=ModelDefaultGenerationSettings(
+            n_ctx=current_model_card.parameters.max_seq_len,
+        ),
+    )
+
+    if current_model_card.parameters.prompt_template_content:
+        resp.chat_template = current_model_card.parameters.prompt_template_content
+
+    return resp
+
+
 @router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
 async def list_draft_models(request: Request) -> ModelList:
     """
diff --git a/endpoints/core/types/model.py b/endpoints/core/types/model.py
index ddf1cc2..8a2e55e 100644
--- a/endpoints/core/types/model.py
+++ b/endpoints/core/types/model.py
@@ -16,10 +16,12 @@ class ModelCardParameters(BaseModel):
     max_seq_len: Optional[int] = None
     rope_scale: Optional[float] = 1.0
     rope_alpha: Optional[float] = 1.0
+    max_batch_size: Optional[int] = 1
     cache_size: Optional[int] = None
     cache_mode: Optional[str] = "FP16"
     chunk_size: Optional[int] = 2048
     prompt_template: Optional[str] = None
+    prompt_template_content: Optional[str] = None
     num_experts_per_token: Optional[int] = None
     use_vision: Optional[bool] = False
 
@@ -139,3 +141,17 @@ class ModelLoadResponse(BaseModel):
     module: int
     modules: int
     status: str
+
+
+class ModelDefaultGenerationSettings(BaseModel):
+    """Contains default generation settings for model props."""
+
+    n_ctx: int
+
+
+class ModelPropsResponse(BaseModel):
+    """Represents a model props response."""
+
+    total_slots: int = 1
+    chat_template: str = ""
+    default_generation_settings: ModelDefaultGenerationSettings