From 46ac3beea9c9e52343d39349245d9f353c9628c4 Mon Sep 17 00:00:00 2001
From: kingbri
Date: Sun, 7 Apr 2024 11:17:55 -0400
Subject: [PATCH] Templates: Support list style chat_template keys

HuggingFace updated transformers to provide tokenizer templates as a
list. Update the template lookup to support this new format.

Setting a template's name as the "prompt_template" value in config.yml
will now also search inside the template list.

In addition, log template load exceptions, but continue model loading,
since a template error shouldn't shut down the application.

Signed-off-by: kingbri
---
 backends/exllamav2/model.py | 31 +++++++++++++-----
 common/templating.py        | 63 ++++++++++++++++++++++++++++++-------
 config_sample.yml           |  4 ++-
 3 files changed, 78 insertions(+), 20 deletions(-)

diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 10fcb01..5e10845 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -4,6 +4,7 @@ import gc
 import pathlib
 import threading
 import time
+import traceback
 
 import torch
 from exllamav2 import (
@@ -30,6 +31,7 @@ from common.gen_logging import (
 )
 from common.templating import (
     PromptTemplate,
+    TemplateLoadError,
     find_template_from_model,
     get_template_from_model_json,
     get_template_from_file,
@@ -194,7 +196,7 @@ class ExllamaV2Container:
         # Catch all for template lookup errors
         if self.prompt_template:
             logger.info(
-                f"Using template {self.prompt_template.name} " "for chat completions."
+                f'Using template "{self.prompt_template.name}" for chat completions.'
             )
         else:
             logger.warning(
@@ -259,23 +261,36 @@
             lambda: get_template_from_model_json(
                 pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
                 "chat_template",
-                "from_tokenizer_config",
             ),
             lambda: get_template_from_file(find_template_from_model(model_directory)),
         ]
 
         # Add lookup from prompt template name if provided
         if prompt_template_name:
-            find_template_functions.insert(
-                0, lambda: get_template_from_file(prompt_template_name)
-            )
+            find_template_functions[:0] = [
+                lambda: get_template_from_file(prompt_template_name),
+                lambda: get_template_from_model_json(
+                    pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
+                    "chat_template",
+                    prompt_template_name,
+                ),
+            ]
 
-        for func in find_template_functions:
+        # Continue on exception since functions are tried as they fail
+        for template_func in find_template_functions:
             try:
-                prompt_template = func()
+                prompt_template = template_func()
                 if prompt_template is not None:
                     return prompt_template
-            except (FileNotFoundError, LookupError):
+            except TemplateLoadError as e:
+                logger.warning(f"TemplateLoadError: {str(e)}")
+                continue
+            except Exception:
+                logger.error(traceback.format_exc())
+                logger.warning(
+                    "An unexpected error happened when trying to load the template. "
+                    "Trying other methods."
+                )
                 continue
 
     def calculate_rope_alpha(self, base_seq_len):
diff --git a/common/templating.py b/common/templating.py
index d193845..6bd2a88 100644
--- a/common/templating.py
+++ b/common/templating.py
@@ -4,12 +4,15 @@ import json
 import pathlib
 from functools import lru_cache
 from importlib.metadata import version as package_version
+from typing import Optional
 
 from jinja2 import Template, TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger
 from packaging import version
 from pydantic import BaseModel
+from common.utils import unwrap
+
 
 class PromptTemplate(BaseModel):
     """A template for chat completion prompts."""
@@ -18,6 +21,12 @@
     template: str
 
 
+class TemplateLoadError(Exception):
+    """Raised on prompt template load"""
+
+    pass
+
+
 def get_prompt_from_template(prompt_template: PromptTemplate, template_vars: dict):
     """Get a prompt from a template and a list of messages."""
     if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
@@ -91,7 +100,7 @@ def find_template_from_model(model_path: pathlib.Path):
         if template_name in model_name.lower():
             return template_name
     else:
-        raise LookupError("Could not find template from model name.")
+        raise TemplateLoadError("Could not find template from model name.")
 
 
 def get_template_from_file(prompt_template_name: str):
@@ -105,18 +114,50 @@
         )
     else:
         # Let the user know if the template file isn't found
-        raise FileNotFoundError(f'Template "{prompt_template_name}" not found.')
+        raise TemplateLoadError(
+            f'Chat template "{prompt_template_name}" not found in files.'
+        )
 
 
 # Get a template from a JSON file
 # Requires a key and template name
-def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
+def get_template_from_model_json(
+    json_path: pathlib.Path, key: str, name: Optional[str] = None
+):
     """Get a template from a JSON file. Requires a key and template name"""
-    if json_path.exists():
-        with open(json_path, "r", encoding="utf8") as config_file:
-            model_config = json.load(config_file)
-            chat_template = model_config.get(key)
-            if chat_template:
-                return PromptTemplate(name=name, template=chat_template)
-    else:
-        raise FileNotFoundError(f'Model JSON path "{json_path}" not found.')
+    if not json_path.exists():
+        raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
+
+    with open(json_path, "r", encoding="utf8") as config_file:
+        model_config = json.load(config_file)
+        chat_template = model_config.get(key)
+
+        if not chat_template:
+            raise TemplateLoadError(
+                "Could not find a value from chat_template key in the passed JSON. "
+                "Check the tokenizer config?"
+            )
+
+        if isinstance(chat_template, list):
+            # Handles the new list style of chat templates
+            if name:
+                wrapped_template = next(
+                    (x for x in chat_template if x.get("name") == name),
+                    {},
+                )
+            else:
+                wrapped_template = chat_template[0]
+                name = unwrap(wrapped_template.get("name"), "from_tokenizer_config")
+
+            selected_template = wrapped_template.get("template")
+
+            if selected_template:
+                return PromptTemplate(name=name, template=selected_template)
+            else:
+                raise TemplateLoadError(
+                    f'Chat template with name "{name}" not found '
+                    "in model templates list."
+ ) + else: + # Can safely assume the chat template is the old style + return PromptTemplate(name="from_tokenizer_config", template=chat_template) diff --git a/config_sample.yml b/config_sample.yml index 94f7ad5..11c6555 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -107,7 +107,9 @@ model: # Possible values FP16, FP8, Q4. (default: FP16) #cache_mode: FP16 - # Set the prompt template for this model. If empty, chat completions will be disabled. (default: Empty) + # Set the prompt template for this model. If empty, attempts to look for the model's chat template. (default: None) + # If a model contains multiple templates in its tokenizer_config.json, set prompt_template to the name + # of the template you want to use. # NOTE: Only works with chat completion message lists! #prompt_template: