From 90fb41a77a704c710e2ed528519e6d5aac8537ed Mon Sep 17 00:00:00 2001
From: kingbri
Date: Wed, 24 Jan 2024 23:36:35 -0500
Subject: [PATCH] Model: Fix prompt template initialization

The previous commit chained multiple try blocks in a way that
forced the user to provide a dummy prompt template.

Now, template loading is fallback based: run through a list of
functions and return as soon as one of them succeeds.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py | 58 ++++++++++++++++++++-----------------
 1 file changed, 32 insertions(+), 26 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index d0c9f44..ac939d4 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -158,32 +158,10 @@ class ExllamaV2Container:
             self.config.set_low_mem()
         """
 
-        # Set prompt template override if provided
-        prompt_template_name = kwargs.get("prompt_template")
-        if prompt_template_name:
-            logger.info("Loading prompt template with name " f"{prompt_template_name}")
-            # Read the template
-            try:
-                self.prompt_template = get_template_from_file(prompt_template_name)
-            except FileNotFoundError:
-                self.prompt_template = None
-
-            # Then try finding the template from the tokenizer_config.json
-            try:
-                self.prompt_template = get_template_from_model_json(
-                    pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
-                    "chat_template",
-                    "from_tokenizer_config",
-                )
-            except FileNotFoundError:
-                self.prompt_template = None
-
-            # If that fails, attempt fetching from model name
-            try:
-                template_match = find_template_from_model(model_directory)
-                self.prompt_template = get_template_from_file(template_match)
-            except (LookupError, FileNotFoundError):
-                self.prompt_template = None
+        # Try to set prompt template
+        self.prompt_template = self.find_prompt_template(
+            kwargs.get("prompt_template"), model_directory
+        )
 
         # Catch all for template lookup errors
         if self.prompt_template:
@@ -250,6 +228,34 @@ class ExllamaV2Container:
             self.draft_config.max_input_len = kwargs["chunk_size"]
             self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2
 
+    def find_prompt_template(self, prompt_template_name, model_directory):
+        """Tries to find a prompt template using various methods"""
+
+        logger.info("Loading prompt template with name " f"{prompt_template_name}")
+
+        find_template_functions = [
+            lambda: get_template_from_model_json(
+                pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
+                "chat_template",
+                "from_tokenizer_config",
+            ),
+            lambda: get_template_from_file(find_template_from_model(model_directory)),
+        ]
+
+        # Add lookup from prompt template name if provided
+        if prompt_template_name:
+            find_template_functions.insert(
+                0, lambda: get_template_from_file(prompt_template_name)
+            )
+
+        for func in find_template_functions:
+            try:
+                prompt_template = func()
+                if prompt_template is not None:
+                    return prompt_template
+            except (FileNotFoundError, LookupError):
+                continue
+
     def calculate_rope_alpha(self, base_seq_len):
         """Calculate the rope alpha value for a given sequence length."""
         ratio = self.config.max_seq_len / base_seq_len
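
For illustration, a minimal standalone sketch of the fallback-chain pattern
the new find_prompt_template method uses. The loader stubs below
(load_from_override, load_from_tokenizer_config) are hypothetical stand-ins,
not tabbyAPI's real get_template_from_file / get_template_from_model_json
helpers:

from typing import Callable, List, Optional

def first_successful(loaders: List[Callable[[], Optional[str]]]) -> Optional[str]:
    # Try each loader in priority order; swallow only the expected lookup
    # failures and fall through to None if every loader fails, mirroring
    # the loop at the end of find_prompt_template.
    for load in loaders:
        try:
            result = load()
            if result is not None:
                return result
        except (FileNotFoundError, LookupError):
            continue
    return None

# Hypothetical stand-ins for the real template loaders
def load_from_override() -> Optional[str]:
    raise FileNotFoundError("user-supplied template file is missing")

def load_from_tokenizer_config() -> Optional[str]:
    return "chatml"  # pretend tokenizer_config.json carried a chat_template

print(first_successful([load_from_override, load_from_tokenizer_config]))
# prints "chatml": the failed override is skipped, the next fallback wins

Two details of the patched method follow the same shape: the user override is
inserted at index 0 only when a name is supplied, so the tokenizer_config.json
lookup stays first otherwise; and because only FileNotFoundError and
LookupError are caught, the method returns None when every source fails, which
the caller handles in the "Catch all for template lookup errors" branch.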