From dd55b99af5e687cd38082bbebd68ded900c838cb Mon Sep 17 00:00:00 2001 From: kingbri Date: Thu, 29 Aug 2024 22:49:20 -0400 Subject: [PATCH] Model: Store directory paths Storing a pathlib type makes it easier to manipulate the model directory path in the long run without constantly fetching it from the config. Signed-off-by: kingbri --- backends/exllamav2/model.py | 21 ++++++++------------- common/model.py | 2 +- endpoints/OAI/router.py | 4 ++-- endpoints/core/utils/model.py | 11 ++++++----- 4 files changed, 17 insertions(+), 21 deletions(-) diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 501b3ab..d7e3bf9 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -63,6 +63,10 @@ except ImportError: class ExllamaV2Container: """The model container class for ExLlamaV2 models.""" + # Model directories + model_dir: pathlib.Path = pathlib.Path("models") + draft_model_dir: pathlib.Path = pathlib.Path("models") + # Exl2 vars config: Optional[ExLlamaV2Config] = None draft_config: Optional[ExLlamaV2Config] = None @@ -110,6 +114,7 @@ class ExllamaV2Container: # Initialize config self.config = ExLlamaV2Config() + self.model_dir = model_directory self.config.model_dir = str(model_directory.resolve()) # Make the max seq len 4096 before preparing the config @@ -142,6 +147,7 @@ class ExllamaV2Container: ) draft_model_path = draft_model_path / draft_model_name + self.draft_model_dir = draft_model_path self.draft_config.model_dir = str(draft_model_path.resolve()) self.draft_config.prepare() @@ -403,20 +409,9 @@ class ExllamaV2Container: alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio**2 return alpha - def get_model_path(self, is_draft: bool = False): - """Get the path for this model.""" - - if is_draft and not self.draft_config: - return None - - model_path = pathlib.Path( - self.draft_config.model_dir if is_draft else self.config.model_dir - ) - return model_path - def get_model_parameters(self): model_params = { - "name": self.get_model_path().name, + "name": self.model_dir.name, "rope_scale": self.config.scale_pos_emb, "rope_alpha": self.config.scale_alpha_value, "max_seq_len": self.config.max_seq_len, @@ -431,7 +426,7 @@ class ExllamaV2Container: if self.draft_config: draft_model_params = { - "name": self.get_model_path(is_draft=True).name, + "name": self.draft_model_dir.name, "rope_scale": self.draft_config.scale_pos_emb, "rope_alpha": self.draft_config.scale_alpha_value, "max_seq_len": self.draft_config.max_seq_len, diff --git a/common/model.py b/common/model.py index 0bfbab2..d1b49af 100644 --- a/common/model.py +++ b/common/model.py @@ -57,7 +57,7 @@ async def load_model_gen(model_path: pathlib.Path, **kwargs): # Check if the model is already loaded if container and container.model: - loaded_model_name = container.get_model_path().name + loaded_model_name = container.model_dir.name if loaded_model_name == model_path.name and container.model_loaded: raise ValueError( diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py index eb2445a..66bc759 100644 --- a/endpoints/OAI/router.py +++ b/endpoints/OAI/router.py @@ -52,7 +52,7 @@ async def completion_request( If stream = true, this returns an SSE stream. """ - model_path = model.container.get_model_path() + model_path = model.container.model_dir if isinstance(data.prompt, list): data.prompt = "\n".join(data.prompt) @@ -105,7 +105,7 @@ async def chat_completion_request( raise HTTPException(422, error_message) - model_path = model.container.get_model_path() + model_path = model.container.model_dir if isinstance(data.messages, str): prompt = data.messages diff --git a/endpoints/core/utils/model.py b/endpoints/core/utils/model.py index fc61337..15f85b8 100644 --- a/endpoints/core/utils/model.py +++ b/endpoints/core/utils/model.py @@ -43,11 +43,12 @@ async def get_current_model_list(model_type: str = "model"): model_path = None # Make sure the model container exists - if model_type == "model" or model_type == "draft": - if model.container: - model_path = model.container.get_model_path(model_type == "draft") - elif model_type == "embedding": - if model.embeddings_container: + match model_type: + case "model": + model_path = model.container.model_dir + case "draft": + model_path = model.container.draft_model_dir + case "embedding": model_path = model.embeddings_container.model_dir if model_path: