import json
import pathlib
from functools import lru_cache
from importlib.metadata import version as package_version
from jinja2.sandbox import ImmutableSandboxedEnvironment
from packaging import version
from pydantic import BaseModel
from typing import Optional, Dict

# Small replication of AutoTokenizer's chat template system for efficiency


class PromptTemplate(BaseModel):
    """A named Jinja2 chat template."""

    # Identifier for the template (a .jinja file stem or a caller-supplied name)
    name: str
    # Raw Jinja2 source text of the template
    template: str


def get_prompt_from_template(
    messages,
    prompt_template: PromptTemplate,
    add_generation_prompt: bool,
    special_tokens: Optional[Dict[str, str]] = None,
):
    """Render chat messages into a prompt string using a Jinja2 template.

    Args:
        messages: Exposed to the template as the ``messages`` variable.
        prompt_template: Template whose ``template`` source is compiled and rendered.
        add_generation_prompt: Exposed to the template as ``add_generation_prompt``.
        special_tokens: Optional extra template variables, passed by name as
            keyword arguments to ``render``. ``None`` (the default) means no
            extra variables.

    Returns:
        The rendered prompt string.

    Raises:
        ImportError: If the installed jinja2 is older than 3.0.0.
    """
    # Look the version up once. It is used both for the comparison and for the
    # error message. `version` is the packaging.version module (not callable),
    # so the installed version must come from importlib.metadata.
    jinja2_version = package_version("jinja2")
    if version.parse(jinja2_version) < version.parse("3.0.0"):
        raise ImportError(
            "Parsing these chat completion messages requires jinja2 3.0.0 or greater. "
            f"Current version: {jinja2_version}\n"
            "Please upgrade jinja by running the following command: "
            "pip install --upgrade jinja2"
        )

    compiled_template = _compile_template(prompt_template.template)
    # special_tokens defaults to None, and `**None` raises TypeError,
    # so fall back to an empty mapping.
    return compiled_template.render(
        messages = messages,
        add_generation_prompt = add_generation_prompt,
        **(special_tokens or {}),
    )


# Inspired by https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
@lru_cache
def _compile_template(template: str):
    """Compile Jinja2 template source in a sandboxed environment.

    Results are memoized per template string by ``lru_cache``, so repeated
    renders of the same template skip recompilation.
    The bare ``@lru_cache`` form uses the default maxsize of 128.
    """
    # trim_blocks / lstrip_blocks strip the whitespace around block tags,
    # so template formatting does not leak into the rendered prompt.
    jinja_env = ImmutableSandboxedEnvironment(trim_blocks = True, lstrip_blocks = True)
    jinja_template = jinja_env.from_string(template)
    return jinja_template


def find_template_from_model(model_path: pathlib.Path):
    """Find a template name that matches a model path.

    Scans ``templates/*.jinja`` and returns the lowercased file stem of the
    first template whose name is a substring of the model path's final
    component (case-insensitive). Returns ``None`` if nothing matches.
    """
    model_name = model_path.name
    # NOTE(review): "templates" is resolved relative to the current working
    # directory, and glob order is not guaranteed. If several templates match,
    # which one wins may vary between filesystems.
    template_directory = pathlib.Path("templates")
    for filepath in template_directory.glob("*.jinja"):
        template_name = filepath.stem.lower()

        # Check if the template name is present in the model name
        if template_name in model_name.lower():
            return template_name

    return None


def get_template_from_file(prompt_template_name: str):
    """Load ``templates/<prompt_template_name>.jinja`` as a PromptTemplate.

    Returns ``None`` if the file does not exist.
    """
    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
    if template_path.exists():
        with open(template_path, "r", encoding = "utf8") as raw_template:
            return PromptTemplate(
                name = prompt_template_name,
                template = raw_template.read()
            )

    return None


def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
    """Load a chat template stored under ``key`` in a JSON file.

    Args:
        json_path: Path to the JSON file (for example, a model config).
        key: Top-level key whose value is the template source.
        name: Name assigned to the resulting PromptTemplate.

    Returns:
        A PromptTemplate, or ``None`` if the file is missing or the key is
        absent or has a falsy value.
    """
    if json_path.exists():
        with open(json_path, "r", encoding = "utf8") as config_file:
            model_config = json.load(config_file)
            chat_template = model_config.get(key)
            if chat_template:
                return PromptTemplate(
                    name = name,
                    template = chat_template
                )

    return None