Model: Make model params return a model card
The model card is a unified structure for sharing model params. Rather than kwargs, use this instead.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
parent 9834c7f99b
commit 3f09fcd8c9

4 changed files with 38 additions and 51 deletions
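For orientation: the diff below constructs ModelCard and ModelCardParameters objects but never shows their definitions (they live in endpoints/core/types/model.py, which this commit does not touch). The sketch that follows is an assumption, not the real code: the field names are inferred from the keyword arguments used in the new model_info(), and the Pydantic v2 base is guessed from the ModelCardParameters.model_validate(...) call in the removed endpoint code.

    from typing import Any, Optional

    from pydantic import BaseModel


    class ModelCardParameters(BaseModel):
        # Fields inferred from the keyword arguments in the new model_info();
        # the real definition may differ.
        max_seq_len: Optional[int] = None
        cache_size: Optional[int] = None
        rope_scale: Optional[float] = None
        rope_alpha: Optional[float] = None
        max_batch_size: Optional[int] = None
        cache_mode: Optional[str] = None
        chunk_size: Optional[int] = None
        prompt_template: Optional[str] = None
        prompt_template_content: Optional[str] = None
        use_vision: Optional[bool] = None
        draft: Optional["ModelCard"] = None  # nested card for the draft model


    class ModelCard(BaseModel):
        id: str = "unknown"
        parameters: Optional[ModelCardParameters] = None
        logging: Optional[Any] = None  # passed as config.logging in the endpoint code


    # Resolve the forward reference to ModelCard used above.
    ModelCardParameters.model_rebuild()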
@@ -14,6 +14,7 @@ from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate
 from common.transformers_utils import GenerationConfig
+from endpoints.core.types.model import ModelCard


 class BaseModelContainer(abc.ABC):
@@ -189,7 +190,7 @@ class BaseModelContainer(abc.ABC):

-    # TODO: Replace by yielding a model card
     @abc.abstractmethod
-    def get_model_parameters(self) -> Dict[str, Any]:
+    def model_info(self) -> ModelCard:
         """
         Returns a dictionary of the current model's configuration parameters.
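The practical effect of this hunk: every backend container must now implement model_info() and hand back a card object instead of a loose kwargs dict. A minimal sketch of a conforming backend, assuming the types sketched above (the real interface has other abstract methods, omitted here, so this class could not be instantiated as-is):

    class DummyContainer(BaseModelContainer):
        """Hypothetical backend showing only the new model_info() contract."""

        def model_info(self) -> ModelCard:
            # Build the typed card directly; no kwargs dict to assemble.
            return ModelCard(
                id="dummy-model",
                parameters=ModelCardParameters(max_seq_len=4096),
            )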
@@ -52,6 +52,7 @@ from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
 from common.transformers_utils import GenerationConfig
 from common.utils import calculate_rope_alpha, coalesce, unwrap
+from endpoints.core.types.model import ModelCard, ModelCardParameters


 class ExllamaV2Container(BaseModelContainer):
@@ -379,35 +380,43 @@ class ExllamaV2Container(BaseModelContainer):
         # Return the created instance
         return self

-    def get_model_parameters(self):
-        model_params = {
-            "name": self.model_dir.name,
-            "rope_scale": self.config.scale_pos_emb,
-            "rope_alpha": self.config.scale_alpha_value,
-            "max_seq_len": self.config.max_seq_len,
-            "max_batch_size": self.max_batch_size,
-            "cache_size": self.cache_size,
-            "cache_mode": self.cache_mode,
-            "chunk_size": self.config.max_input_len,
-            "use_vision": self.use_vision,
-        }
+    def model_info(self):
+        draft_model_card: ModelCard = None
+        if self.draft_config:
+            draft_model_params = ModelCardParameters(
+                max_seq_len=self.draft_config.max_seq_len,
+                rope_scale=self.draft_config.scale_pos_emb,
+                rope_alpha=self.draft_config.scale_alpha_value,
+                cache_mode=self.draft_cache_mode,
+            )
+
+            draft_model_card = ModelCard(
+                id=self.draft_model_dir.name,
+                parameters=draft_model_params,
+            )
+
+        model_params = ModelCardParameters(
+            max_seq_len=self.config.max_seq_len,
+            cache_size=self.cache_size,
+            rope_scale=self.config.scale_pos_emb,
+            rope_alpha=self.config.scale_alpha_value,
+            max_batch_size=self.max_batch_size,
+            cache_mode=self.cache_mode,
+            chunk_size=self.config.max_input_len,
+            use_vision=self.use_vision,
+            draft=draft_model_card,
+        )

         if self.prompt_template:
-            model_params["prompt_template"] = self.prompt_template.name
-            model_params["prompt_template_content"] = self.prompt_template.raw_template
+            model_params.prompt_template = self.prompt_template.name
+            model_params.prompt_template_content = self.prompt_template.raw_template

-        if self.draft_config:
-            draft_model_params = {
-                "name": self.draft_model_dir.name,
-                "rope_scale": self.draft_config.scale_pos_emb,
-                "rope_alpha": self.draft_config.scale_alpha_value,
-                "max_seq_len": self.draft_config.max_seq_len,
-                "cache_mode": self.draft_cache_mode,
-            }
+        model_card = ModelCard(
+            id=self.model_dir.name,
+            parameters=model_params,
+        )

-            model_params["draft"] = draft_model_params
-
-        return model_params
+        return model_card

     async def wait_for_jobs(self, skip_wait: bool = False):
         """Polling mechanism to wait for pending generation jobs."""
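Note how the draft model is now attached as a nested ModelCard under parameters.draft instead of a "draft" key that callers must pop off a dict. A sketch of what a consumer of the nested structure might look like, again assuming the types sketched earlier:

    def describe_draft(card: ModelCard) -> str:
        """Summarize a card's draft model, if one is attached (sketch)."""
        draft = card.parameters.draft if card.parameters else None
        if draft is None:
            return "no draft model loaded"
        # Assumes draft cards always carry parameters, as in model_info() above.
        return (
            f"draft {draft.id}: "
            f"max_seq_len={draft.parameters.max_seq_len}, "
            f"cache_mode={draft.parameters.cache_mode}"
        )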
@@ -122,7 +122,7 @@ async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
 async def get_max_length() -> MaxLengthResponse:
     """Fetches the max length of the model."""

-    max_length = model.container.get_model_parameters().get("max_seq_len")
+    max_length = model.container.model_info().parameters.max_seq_len
     return {"value": max_length}
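One behavioral nuance worth flagging: the old lookup used .get("max_seq_len") and silently yielded None when the key was missing, while attribute access assumes parameters is populated on the card. A defensive variant that mirrors the old behavior, assuming the sketch types above:

    def max_length_from_card(card: ModelCard) -> Optional[int]:
        # Return None instead of raising AttributeError when no
        # parameters are attached to the card, like dict.get() did.
        return card.parameters.max_seq_len if card.parameters else None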
@@ -64,30 +64,7 @@ async def get_current_model_list(model_type: str = "model"):
 def get_current_model():
     """Gets the current model with all parameters."""

-    model_params = model.container.get_model_parameters()
-    draft_model_params = model_params.pop("draft", {})
-
-    if draft_model_params:
-        model_params["draft"] = ModelCard(
-            id=unwrap(draft_model_params.get("name"), "unknown"),
-            parameters=ModelCardParameters.model_validate(draft_model_params),
-        )
-    else:
-        draft_model_params = None
-
-    model_card = ModelCard(
-        id=unwrap(model_params.pop("name", None), "unknown"),
-        parameters=ModelCardParameters.model_validate(model_params),
-        logging=config.logging,
-    )
-
-    if draft_model_params:
-        draft_card = ModelCard(
-            id=unwrap(draft_model_params.pop("name", None), "unknown"),
-            parameters=ModelCardParameters.model_validate(draft_model_params),
-        )
-
-        model_card.parameters.draft = draft_card
+    model_card = model.container.model_info()

     return model_card
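Taken together, the translation from backend state into a card now happens once, inside the container, and the endpoint layer just returns the object. A rough end-to-end illustration with hypothetical model names, again assuming the Pydantic sketch above:

    draft = ModelCard(
        id="tiny-draft-exl2",  # hypothetical draft model directory name
        parameters=ModelCardParameters(max_seq_len=8192, cache_mode="FP16"),
    )
    card = ModelCard(
        id="big-model-exl2",  # hypothetical main model directory name
        parameters=ModelCardParameters(max_seq_len=8192, draft=draft),
    )

    # One nested, typed structure; exclude_none keeps the payload compact.
    print(card.model_dump(exclude_none=True))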