Templates: Support list style chat_template keys

HuggingFace updated transformers so tokenizers can provide chat
templates as a list. Update to support this new format. Providing a
template name for the "prompt_template" value in config.yml will now
also search the template list.
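For reference, a rough sketch (not part of this commit) of the two shapes
the "chat_template" key can take in a tokenizer_config.json, written here
as Python dicts; the entry names ("default", "tool_use") and template
bodies are illustrative:

    # Old style: a single Jinja template string
    old_style = {"chat_template": "{% for message in messages %}...{% endfor %}"}

    # New list style: multiple named templates. A name given as
    # "prompt_template" in config.yml can select one entry by its "name" key.
    new_style = {
        "chat_template": [
            {"name": "default", "template": "{% for message in messages %}..."},
            {"name": "tool_use", "template": "{% for message in messages %}..."},
        ]
    }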

In addition, log if there's a template exception, but continue model
loading since a template error shouldn't shut down the application.
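The model-loading side lives in one of the other changed files and isn't
shown in this diff; a minimal sketch of the intended handling (the module
path and wrapper function below are assumptions, not taken from the
commit) might look like:

    from loguru import logger

    # Assumed import path for the templating helpers shown in the diff below
    from common.templating import TemplateLoadError, get_template_from_model_json

    def try_load_chat_template(json_path, name=None):
        """Attempt to load a chat template without ever aborting model loading."""
        try:
            return get_template_from_model_json(json_path, "chat_template", name)
        except TemplateLoadError as exc:
            # Log the failure and continue without a prompt template
            logger.warning(f"Chat template not loaded: {exc}")
            return None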

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-04-07 11:17:55 -04:00
parent 5bb4995a7c
commit 46ac3beea9
3 changed files with 78 additions and 20 deletions

@@ -4,12 +4,15 @@ import json
 import pathlib
 from functools import lru_cache
 from importlib.metadata import version as package_version
+from typing import Optional
 
 from jinja2 import Template, TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
+from loguru import logger
 from packaging import version
 from pydantic import BaseModel
+from common.utils import unwrap
 
 
 class PromptTemplate(BaseModel):
     """A template for chat completion prompts."""
@@ -18,6 +21,12 @@ class PromptTemplate(BaseModel):
     template: str
+
+
+class TemplateLoadError(Exception):
+    """Raised on prompt template load"""
+
+    pass
 
 
 def get_prompt_from_template(prompt_template: PromptTemplate, template_vars: dict):
     """Get a prompt from a template and a list of messages."""
     if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
@@ -91,7 +100,7 @@ def find_template_from_model(model_path: pathlib.Path):
         if template_name in model_name.lower():
             return template_name
     else:
-        raise LookupError("Could not find template from model name.")
+        raise TemplateLoadError("Could not find template from model name.")
 
 
 def get_template_from_file(prompt_template_name: str):
@@ -105,18 +114,50 @@ def get_template_from_file(prompt_template_name: str):
             )
     else:
         # Let the user know if the template file isn't found
-        raise FileNotFoundError(f'Template "{prompt_template_name}" not found.')
+        raise TemplateLoadError(
+            f'Chat template "{prompt_template_name}" not found in files.'
+        )
 
 
 # Get a template from a JSON file
 # Requires a key and template name
-def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
+def get_template_from_model_json(
+    json_path: pathlib.Path, key: str, name: Optional[str] = None
+):
     """Get a template from a JSON file. Requires a key and template name"""
-    if json_path.exists():
-        with open(json_path, "r", encoding="utf8") as config_file:
-            model_config = json.load(config_file)
-            chat_template = model_config.get(key)
-            if chat_template:
-                return PromptTemplate(name=name, template=chat_template)
-    else:
-        raise FileNotFoundError(f'Model JSON path "{json_path}" not found.')
+    if not json_path.exists():
+        raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
+
+    with open(json_path, "r", encoding="utf8") as config_file:
+        model_config = json.load(config_file)
+        chat_template = model_config.get(key)
+
+        if not chat_template:
+            raise TemplateLoadError(
+                "Could not find a value from chat_template key in the passed JSON. "
+                "Check the tokenizer config?"
+            )
+
+        if isinstance(chat_template, list):
+            # Handles the new list style of chat templates
+            if name:
+                wrapped_template = next(
+                    (x for x in chat_template if x.get("name") == name),
+                    {},
+                )
+            else:
+                wrapped_template = chat_template[0]
+                name = unwrap(wrapped_template.get("name"), "from_tokenizer_config")
+
+            selected_template = wrapped_template.get("template")
+
+            if selected_template:
+                return PromptTemplate(name=name, template=selected_template)
+            else:
+                raise TemplateLoadError(
+                    f'Chat template with name "{name}" not found '
+                    "in model templates list."
+                )
+        else:
+            # Can safely assume the chat template is the old style
+            return PromptTemplate(name="from_tokenizer_config", template=chat_template)
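A usage sketch for the updated function; the path and template names below
are illustrative, not taken from the commit:

    import pathlib

    tokenizer_config = pathlib.Path("models/my-model/tokenizer_config.json")

    # Pick the list entry whose "name" matches "tool_use"
    template = get_template_from_model_json(
        tokenizer_config, "chat_template", "tool_use"
    )

    # With no name, the first list entry is used and its "name" key
    # (or "from_tokenizer_config") becomes the template's name
    default_template = get_template_from_model_json(tokenizer_config, "chat_template")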