Model: Make model params return a model card

The model card is a unified structure for sharing model params.
Rather than kwargs, use this instead.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-04-21 23:15:46 -04:00
parent 9834c7f99b
commit 3f09fcd8c9
4 changed files with 38 additions and 51 deletions

View file

@ -14,6 +14,7 @@ from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate
from common.transformers_utils import GenerationConfig
from endpoints.core.types.model import ModelCard
class BaseModelContainer(abc.ABC):
@ -189,7 +190,7 @@ class BaseModelContainer(abc.ABC):
# TODO: Replace by yielding a model card
@abc.abstractmethod
def get_model_parameters(self) -> Dict[str, Any]:
def model_info(self) -> ModelCard:
"""
Returns a dictionary of the current model's configuration parameters.

View file

@ -52,6 +52,7 @@ from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate, find_prompt_template
from common.transformers_utils import GenerationConfig
from common.utils import calculate_rope_alpha, coalesce, unwrap
from endpoints.core.types.model import ModelCard, ModelCardParameters
class ExllamaV2Container(BaseModelContainer):
@ -379,35 +380,43 @@ class ExllamaV2Container(BaseModelContainer):
# Return the created instance
return self
def get_model_parameters(self):
model_params = {
"name": self.model_dir.name,
"rope_scale": self.config.scale_pos_emb,
"rope_alpha": self.config.scale_alpha_value,
"max_seq_len": self.config.max_seq_len,
"max_batch_size": self.max_batch_size,
"cache_size": self.cache_size,
"cache_mode": self.cache_mode,
"chunk_size": self.config.max_input_len,
"use_vision": self.use_vision,
}
def model_info(self):
draft_model_card: ModelCard = None
if self.draft_config:
draft_model_params = ModelCardParameters(
max_seq_len=self.draft_config.max_seq_len,
rope_scale=self.draft_config.scale_pos_emb,
rope_alpha=self.draft_config.scale_alpha_value,
cache_mode=self.draft_cache_mode,
)
draft_model_card = ModelCard(
id=self.draft_model_dir.name,
parameters=draft_model_params,
)
model_params = ModelCardParameters(
max_seq_len=self.config.max_seq_len,
cache_size=self.cache_size,
rope_scale=self.config.scale_pos_emb,
rope_alpha=self.config.scale_alpha_value,
max_batch_size=self.max_batch_size,
cache_mode=self.cache_mode,
chunk_size=self.config.max_input_len,
use_vision=self.use_vision,
draft=draft_model_card,
)
if self.prompt_template:
model_params["prompt_template"] = self.prompt_template.name
model_params["prompt_template_content"] = self.prompt_template.raw_template
model_params.prompt_template = self.prompt_template.name
model_params.prompt_template_content = self.prompt_template.raw_template
if self.draft_config:
draft_model_params = {
"name": self.draft_model_dir.name,
"rope_scale": self.draft_config.scale_pos_emb,
"rope_alpha": self.draft_config.scale_alpha_value,
"max_seq_len": self.draft_config.max_seq_len,
"cache_mode": self.draft_cache_mode,
}
model_card = ModelCard(
id=self.model_dir.name,
parameters=model_params,
)
model_params["draft"] = draft_model_params
return model_params
return model_card
async def wait_for_jobs(self, skip_wait: bool = False):
"""Polling mechanism to wait for pending generation jobs."""

View file

@ -122,7 +122,7 @@ async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
async def get_max_length() -> MaxLengthResponse:
"""Fetches the max length of the model."""
max_length = model.container.get_model_parameters().get("max_seq_len")
max_length = model.container.model_info().parameters.max_seq_len
return {"value": max_length}

View file

@ -64,30 +64,7 @@ async def get_current_model_list(model_type: str = "model"):
def get_current_model():
"""Gets the current model with all parameters."""
model_params = model.container.get_model_parameters()
draft_model_params = model_params.pop("draft", {})
if draft_model_params:
model_params["draft"] = ModelCard(
id=unwrap(draft_model_params.get("name"), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
else:
draft_model_params = None
model_card = ModelCard(
id=unwrap(model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(model_params),
logging=config.logging,
)
if draft_model_params:
draft_card = ModelCard(
id=unwrap(draft_model_params.pop("name", None), "unknown"),
parameters=ModelCardParameters.model_validate(draft_model_params),
)
model_card.parameters.draft = draft_card
model_card = model.container.model_info()
return model_card