Templates: Support list style chat_template keys
HuggingFace updated transformers so that tokenizers can provide chat templates as a list. Update the template loader to support this new format. When the name of a template is given as the "prompt_template" value in config.yml, the loader now also looks for it inside the model's template list. In addition, log any template exception but continue loading the model, since a template error shouldn't shut down the application. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
5bb4995a7c
commit
46ac3beea9
3 changed files with 78 additions and 20 deletions
|
|
@ -4,12 +4,15 @@ import json
|
|||
import pathlib
|
||||
from functools import lru_cache
|
||||
from importlib.metadata import version as package_version
|
||||
from typing import Optional
|
||||
from jinja2 import Template, TemplateError
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
from loguru import logger
|
||||
from packaging import version
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.utils import unwrap
|
||||
|
||||
|
||||
class PromptTemplate(BaseModel):
|
||||
"""A template for chat completion prompts."""
|
||||
|
|
@ -18,6 +21,12 @@ class PromptTemplate(BaseModel):
|
|||
template: str
|
||||
|
||||
|
||||
class TemplateLoadError(Exception):
    """Raised when a prompt template cannot be found or loaded."""
|
||||
|
||||
|
||||
def get_prompt_from_template(prompt_template: PromptTemplate, template_vars: dict):
|
||||
"""Get a prompt from a template and a list of messages."""
|
||||
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
|
||||
|
|
@ -91,7 +100,7 @@ def find_template_from_model(model_path: pathlib.Path):
|
|||
if template_name in model_name.lower():
|
||||
return template_name
|
||||
else:
|
||||
raise LookupError("Could not find template from model name.")
|
||||
raise TemplateLoadError("Could not find template from model name.")
|
||||
|
||||
|
||||
def get_template_from_file(prompt_template_name: str):
|
||||
|
|
@ -105,18 +114,50 @@ def get_template_from_file(prompt_template_name: str):
|
|||
)
|
||||
else:
|
||||
# Let the user know if the template file isn't found
|
||||
raise FileNotFoundError(f'Template "{prompt_template_name}" not found.')
|
||||
raise TemplateLoadError(
|
||||
f'Chat template "{prompt_template_name}" not found in files.'
|
||||
)
|
||||
|
||||
|
||||
def get_template_from_model_json(
    json_path: pathlib.Path, key: str, name: Optional[str] = None
):
    """Get a chat template from a model JSON file.

    Two layouts of the value stored under ``key`` are handled:

    * list style: a list of objects carrying ``"name"`` and ``"template"``
      entries, from which one template is selected;
    * old style: any non-list value, which is used directly as the template.

    Args:
        json_path: Path to the JSON file to read.
        key: Top-level JSON key that holds the template value.
        name: Name of the template to select from a list-style value. When
            omitted, the first entry of the list is used.

    Returns:
        A PromptTemplate built from the selected template. Old-style values
        are always named "from_tokenizer_config".

    Raises:
        TemplateLoadError: If the file does not exist, cannot be parsed as
            JSON, holds no usable value under ``key``, or the list contains
            no matching entry with a template.
    """
    if not json_path.exists():
        raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')

    try:
        with open(json_path, "r", encoding="utf8") as config_file:
            model_config = json.load(config_file)
    except (json.JSONDecodeError, UnicodeDecodeError) as exc:
        # Report unreadable files through the same error type as every other
        # template load failure instead of leaking a parser exception.
        raise TemplateLoadError(
            f'Model JSON path "{json_path}" could not be parsed.'
        ) from exc

    # A JSON document whose root is not an object cannot contain the key.
    chat_template = model_config.get(key) if isinstance(model_config, dict) else None

    if not chat_template:
        raise TemplateLoadError(
            f'Could not find a value from the "{key}" key in the passed JSON. '
            "Check the tokenizer config?"
        )

    if not isinstance(chat_template, list):
        # Can safely assume the chat template is the old style
        return PromptTemplate(name="from_tokenizer_config", template=chat_template)

    # Handles the new list style of chat templates
    if name:
        # Ignore malformed (non-object) entries while searching by name.
        wrapped_template = next(
            (
                entry
                for entry in chat_template
                if isinstance(entry, dict) and entry.get("name") == name
            ),
            {},
        )
    else:
        # The list is non-empty here because empty values were rejected above.
        first_entry = chat_template[0]
        wrapped_template = first_entry if isinstance(first_entry, dict) else {}
        # NOTE(review): unwrap presumably returns the fallback when the value
        # is None - confirm against common.utils.
        name = unwrap(wrapped_template.get("name"), "from_tokenizer_config")

    selected_template = wrapped_template.get("template")
    if not selected_template:
        raise TemplateLoadError(
            f'Chat template with name "{name}" not found '
            "in model templates list."
        )

    return PromptTemplate(name=name, template=selected_template)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue