diff --git a/backends/base_model_container.py b/backends/base_model_container.py
index 022a382..1eeae5a 100644
--- a/backends/base_model_container.py
+++ b/backends/base_model_container.py
@@ -14,6 +14,7 @@ from common.multimodal import MultimodalEmbeddingWrapper
 from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate
 from common.transformers_utils import GenerationConfig
+from endpoints.core.types.model import ModelCard
 
 
 class BaseModelContainer(abc.ABC):
@@ -189,7 +190,7 @@ class BaseModelContainer(abc.ABC):
 
     # TODO: Replace by yielding a model card
     @abc.abstractmethod
-    def get_model_parameters(self) -> Dict[str, Any]:
+    def model_info(self) -> ModelCard:
         """
-        Returns a dictionary of the current model's configuration parameters.
+        Returns a ModelCard containing the current model's configuration parameters.
 
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 0cf1076..01d13d4 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -52,6 +52,7 @@ from common.sampling import BaseSamplerRequest
 from common.templating import PromptTemplate, find_prompt_template
 from common.transformers_utils import GenerationConfig
 from common.utils import calculate_rope_alpha, coalesce, unwrap
+from endpoints.core.types.model import ModelCard, ModelCardParameters
 
 
 class ExllamaV2Container(BaseModelContainer):
@@ -379,35 +380,43 @@ class ExllamaV2Container(BaseModelContainer):
         # Return the created instance
         return self
 
-    def get_model_parameters(self):
-        model_params = {
-            "name": self.model_dir.name,
-            "rope_scale": self.config.scale_pos_emb,
-            "rope_alpha": self.config.scale_alpha_value,
-            "max_seq_len": self.config.max_seq_len,
-            "max_batch_size": self.max_batch_size,
-            "cache_size": self.cache_size,
-            "cache_mode": self.cache_mode,
-            "chunk_size": self.config.max_input_len,
-            "use_vision": self.use_vision,
-        }
+    def model_info(self) -> ModelCard:
+        draft_model_card: ModelCard = None
+        if self.draft_config:
+            draft_model_params = ModelCardParameters(
+                max_seq_len=self.draft_config.max_seq_len,
+                rope_scale=self.draft_config.scale_pos_emb,
+                rope_alpha=self.draft_config.scale_alpha_value,
+                cache_mode=self.draft_cache_mode,
+            )
+
+            draft_model_card = ModelCard(
+                id=self.draft_model_dir.name,
+                parameters=draft_model_params,
+            )
+
+        model_params = ModelCardParameters(
+            max_seq_len=self.config.max_seq_len,
+            cache_size=self.cache_size,
+            rope_scale=self.config.scale_pos_emb,
+            rope_alpha=self.config.scale_alpha_value,
+            max_batch_size=self.max_batch_size,
+            cache_mode=self.cache_mode,
+            chunk_size=self.config.max_input_len,
+            use_vision=self.use_vision,
+            draft=draft_model_card,
+        )
 
         if self.prompt_template:
-            model_params["prompt_template"] = self.prompt_template.name
-            model_params["prompt_template_content"] = self.prompt_template.raw_template
+            model_params.prompt_template = self.prompt_template.name
+            model_params.prompt_template_content = self.prompt_template.raw_template
 
-        if self.draft_config:
-            draft_model_params = {
-                "name": self.draft_model_dir.name,
-                "rope_scale": self.draft_config.scale_pos_emb,
-                "rope_alpha": self.draft_config.scale_alpha_value,
-                "max_seq_len": self.draft_config.max_seq_len,
-                "cache_mode": self.draft_cache_mode,
-            }
+        model_card = ModelCard(
+            id=self.model_dir.name,
+            parameters=model_params,
+        )
 
-            model_params["draft"] = draft_model_params
-
-        return model_params
+        return model_card
 
     async def wait_for_jobs(self, skip_wait: bool = False):
         """Polling mechanism to wait for pending generation jobs."""
diff --git a/endpoints/Kobold/router.py b/endpoints/Kobold/router.py
index ea894ea..0e9c210 100644
--- a/endpoints/Kobold/router.py
+++ b/endpoints/Kobold/router.py
@@ -122,7 +122,7 @@ async def get_tokencount(data: TokenCountRequest) -> TokenCountResponse:
 async def get_max_length() -> MaxLengthResponse:
     """Fetches the max length of the model."""
 
-    max_length = model.container.get_model_parameters().get("max_seq_len")
+    max_length = model.container.model_info().parameters.max_seq_len
 
     return {"value": max_length}
 
diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py
index c2c209b..81805d7 100644
--- a/endpoints/core/utils/model.py
+++ b/endpoints/core/utils/model.py
@@ -64,30 +64,7 @@ async def get_current_model_list(model_type: str = "model"):
 def get_current_model():
     """Gets the current model with all parameters."""
 
-    model_params = model.container.get_model_parameters()
-    draft_model_params = model_params.pop("draft", {})
-
-    if draft_model_params:
-        model_params["draft"] = ModelCard(
-            id=unwrap(draft_model_params.get("name"), "unknown"),
-            parameters=ModelCardParameters.model_validate(draft_model_params),
-        )
-    else:
-        draft_model_params = None
-
-    model_card = ModelCard(
-        id=unwrap(model_params.pop("name", None), "unknown"),
-        parameters=ModelCardParameters.model_validate(model_params),
-        logging=config.logging,
-    )
-
-    if draft_model_params:
-        draft_card = ModelCard(
-            id=unwrap(draft_model_params.pop("name", None), "unknown"),
-            parameters=ModelCardParameters.model_validate(draft_model_params),
-        )
-
-        model_card.parameters.draft = draft_card
+    model_card = model.container.model_info()
 
     return model_card
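
For reference, the shapes that ModelCard and ModelCardParameters must expose for this change can be inferred from the constructor calls above. The sketch below is a minimal approximation, assuming Pydantic models like the real definitions in endpoints/core/types/model.py; the field names are taken from the diff, while types, defaults, and optionality are assumptions (the removed code also passes a logging field, elided here).

    # Minimal sketch inferred from the fields this diff passes to the
    # constructors; not the actual definitions in endpoints/core/types/model.py.
    from typing import Optional

    from pydantic import BaseModel


    class ModelCardParameters(BaseModel):
        max_seq_len: Optional[int] = None
        cache_size: Optional[int] = None
        rope_scale: Optional[float] = None
        rope_alpha: Optional[float] = None
        max_batch_size: Optional[int] = None
        cache_mode: Optional[str] = None
        chunk_size: Optional[int] = None
        use_vision: Optional[bool] = None
        prompt_template: Optional[str] = None
        prompt_template_content: Optional[str] = None
        # Nested card for the draft model, if one is loaded
        draft: Optional["ModelCard"] = None


    class ModelCard(BaseModel):
        id: str = "unknown"
        parameters: Optional[ModelCardParameters] = None


    # Resolve the forward reference now that ModelCard is defined
    ModelCardParameters.model_rebuild()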
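From the consumer side, the payoff is typed attribute access instead of stringly-typed dict lookups. A toy illustration, assuming the sketch above; DummyContainer and its attributes are hypothetical stand-ins for a real BaseModelContainer subclass:

    # Hypothetical backend showing the new model_info() contract in use.
    class DummyContainer:
        model_name = "dummy-7b"
        max_seq_len = 4096

        def model_info(self) -> ModelCard:
            return ModelCard(
                id=self.model_name,
                parameters=ModelCardParameters(max_seq_len=self.max_seq_len),
            )


    card = DummyContainer().model_info()
    # Old: container.get_model_parameters().get("max_seq_len")
    # New: typed access, validated by Pydantic at construction time
    assert card.parameters.max_seq_len == 4096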