Templates: Support list-style chat_template keys
HuggingFace updated transformers so that a tokenizer can provide its chat templates as a list. Update the template lookup to support this new format: setting a template name as the "prompt_template" value in config.yml now also searches inside the template list. In addition, log template exceptions but continue model loading, since a bad template shouldn't shut down the application.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 5bb4995a7c · commit 46ac3beea9
3 changed files with 78 additions and 20 deletions
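For context: with this transformers change, the chat_template key in tokenizer_config.json can hold either a single Jinja string (the old style) or a list of named template entries (the new style). A minimal sketch of both shapes and of picking an entry by name, not taken from the commit; the names and template bodies below are made up:

```python
# Old style: chat_template is a single Jinja string
old_style = {
    "chat_template": "{% for m in messages %}{{ m['content'] }}{% endfor %}"
}

# New style: chat_template is a list of named templates
# (names and bodies here are illustrative)
new_style = {
    "chat_template": [
        {
            "name": "default",
            "template": "{% for m in messages %}{{ m['content'] }}{% endfor %}",
        },
        {"name": "rag", "template": "{{ documents }}\n{{ messages[-1]['content'] }}"},
    ]
}

# The old style needs no selection step; it is already the template
assert isinstance(old_style["chat_template"], str)

# Selecting a list entry by name, mirroring the lookup added in this commit
name = "rag"
wrapped = next((x for x in new_style["chat_template"] if x.get("name") == name), {})
print(wrapped.get("template"))
```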
Changes in the ExllamaV2 container module:

```diff
@@ -4,6 +4,7 @@ import gc
 import pathlib
 import threading
 import time
+import traceback
 
 import torch
 from exllamav2 import (
@@ -30,6 +31,7 @@ from common.gen_logging import (
 )
 from common.templating import (
     PromptTemplate,
+    TemplateLoadError,
     find_template_from_model,
     get_template_from_model_json,
     get_template_from_file,
@@ -194,7 +196,7 @@ class ExllamaV2Container:
         # Catch all for template lookup errors
         if self.prompt_template:
             logger.info(
-                f"Using template {self.prompt_template.name} " "for chat completions."
+                f'Using template "{self.prompt_template.name}" for chat completions.'
             )
         else:
             logger.warning(
@@ -259,23 +261,36 @@ class ExllamaV2Container:
             lambda: get_template_from_model_json(
                 pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
                 "chat_template",
-                "from_tokenizer_config",
             ),
             lambda: get_template_from_file(find_template_from_model(model_directory)),
         ]
 
         # Add lookup from prompt template name if provided
         if prompt_template_name:
-            find_template_functions.insert(
-                0, lambda: get_template_from_file(prompt_template_name)
-            )
+            find_template_functions[:0] = [
+                lambda: get_template_from_file(prompt_template_name),
+                lambda: get_template_from_model_json(
+                    pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
+                    "chat_template",
+                    prompt_template_name,
+                ),
+            ]
 
-        for func in find_template_functions:
+        # Continue on exception since functions are tried as they fail
+        for template_func in find_template_functions:
             try:
-                prompt_template = func()
+                prompt_template = template_func()
                 if prompt_template is not None:
                     return prompt_template
-            except (FileNotFoundError, LookupError):
+            except TemplateLoadError as e:
+                logger.warning(f"TemplateLoadError: {str(e)}")
                 continue
+            except Exception:
+                logger.error(traceback.format_exc())
+                logger.warning(
+                    "An unexpected error happened when trying to load the template. "
+                    "Trying other methods."
+                )
+                continue
 
     def calculate_rope_alpha(self, base_seq_len):
```
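The slice assignment above prepends both user-driven lookups at once: with `find_template_functions[:0] = [...]`, the file lookup for the named template stays ahead of the named tokenizer-config lookup, and both run before the defaults. A standalone sketch of the same fallback pattern, with hypothetical loader names standing in for the real lookup functions:

```python
from typing import Callable, List, Optional

class TemplateLoadError(Exception):
    """Raised when a single lookup method fails (mirrors common.templating)."""

def from_file() -> Optional[str]:
    # Hypothetical loader that fails with a known, recoverable error
    raise TemplateLoadError("no template file on disk")

def from_tokenizer_config() -> Optional[str]:
    # Hypothetical loader; pretend this was read from tokenizer_config.json
    return "{{ messages }}"

def find_template(user_loaders: List[Callable[[], Optional[str]]]) -> Optional[str]:
    loaders: List[Callable[[], Optional[str]]] = [from_tokenizer_config]
    # Slice assignment prepends the user lookups while keeping their order
    loaders[:0] = user_loaders
    for loader in loaders:
        try:
            template = loader()
            if template is not None:
                return template
        except TemplateLoadError as e:
            # Known failure: warn and try the next lookup method
            print(f"TemplateLoadError: {e}")
            continue
    return None

print(find_template([from_file]))  # warns, then falls back to from_tokenizer_config
```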
Changes in common.templating:

```diff
@@ -4,12 +4,15 @@ import json
 import pathlib
 from functools import lru_cache
 from importlib.metadata import version as package_version
+from typing import Optional
 
 from jinja2 import Template, TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger
 from packaging import version
 from pydantic import BaseModel
 
+from common.utils import unwrap
+
 
 class PromptTemplate(BaseModel):
     """A template for chat completion prompts."""
@@ -18,6 +21,12 @@ class PromptTemplate(BaseModel):
     template: str
 
 
+class TemplateLoadError(Exception):
+    """Raised on prompt template load"""
+
+    pass
+
+
 def get_prompt_from_template(prompt_template: PromptTemplate, template_vars: dict):
     """Get a prompt from a template and a list of messages."""
     if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
@@ -91,7 +100,7 @@ def find_template_from_model(model_path: pathlib.Path):
         if template_name in model_name.lower():
             return template_name
     else:
-        raise LookupError("Could not find template from model name.")
+        raise TemplateLoadError("Could not find template from model name.")
 
 
 def get_template_from_file(prompt_template_name: str):
@@ -105,18 +114,50 @@ def get_template_from_file(prompt_template_name: str):
         )
     else:
         # Let the user know if the template file isn't found
-        raise FileNotFoundError(f'Template "{prompt_template_name}" not found.')
+        raise TemplateLoadError(
+            f'Chat template "{prompt_template_name}" not found in files.'
+        )
 
 
-# Get a template from a JSON file
-# Requires a key and template name
-def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
-    if json_path.exists():
-        with open(json_path, "r", encoding="utf8") as config_file:
-            model_config = json.load(config_file)
-            chat_template = model_config.get(key)
-            if chat_template:
-                return PromptTemplate(name=name, template=chat_template)
-    else:
-        raise FileNotFoundError(f'Model JSON path "{json_path}" not found.')
+def get_template_from_model_json(
+    json_path: pathlib.Path, key: str, name: Optional[str] = None
+):
+    """Get a template from a JSON file. Requires a key and template name"""
+    if not json_path.exists():
+        raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')
+
+    with open(json_path, "r", encoding="utf8") as config_file:
+        model_config = json.load(config_file)
+        chat_template = model_config.get(key)
+
+        if not chat_template:
+            raise TemplateLoadError(
+                "Could not find a value from chat_template key in the passed JSON. "
+                "Check the tokenizer config?"
+            )
+
+        if isinstance(chat_template, list):
+            # Handles the new list style of chat templates
+            if name:
+                wrapped_template = next(
+                    (x for x in chat_template if x.get("name") == name),
+                    {},
+                )
+            else:
+                wrapped_template = chat_template[0]
+                name = unwrap(wrapped_template.get("name"), "from_tokenizer_config")
+
+            selected_template = wrapped_template.get("template")
+
+            if selected_template:
+                return PromptTemplate(name=name, template=selected_template)
+            else:
+                raise TemplateLoadError(
+                    f'Chat template with name "{name}" not found '
+                    "in model templates list."
+                )
+        else:
+            # Can safely assume the chat template is the old style
+            return PromptTemplate(name="from_tokenizer_config", template=chat_template)
```
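A quick way to exercise the new list handling end to end is to point the loader at a throwaway tokenizer_config.json. Since common.templating isn't importable outside the project, this sketch inlines a compressed stand-in for get_template_from_model_json (the project's unwrap(value, default) helper is approximated here with `or`):

```python
import json
import pathlib
import tempfile

def pick_template(json_path: pathlib.Path, key: str, name=None):
    """Compressed stand-in for get_template_from_model_json (illustrative only)."""
    model_config = json.loads(json_path.read_text(encoding="utf8"))
    chat_template = model_config.get(key)
    if isinstance(chat_template, list):
        if name:
            wrapped = next((x for x in chat_template if x.get("name") == name), {})
        else:
            # No name given: take the first entry and default its name
            wrapped = chat_template[0]
            name = wrapped.get("name") or "from_tokenizer_config"
        return name, wrapped.get("template")
    # Old style: the value is the template string itself
    return "from_tokenizer_config", chat_template

with tempfile.TemporaryDirectory() as tmp:
    config_path = pathlib.Path(tmp) / "tokenizer_config.json"
    config_path.write_text(
        json.dumps({
            "chat_template": [
                {"name": "default", "template": "{{ messages }}"},
                {"name": "tools", "template": "{{ tools }}"},
            ]
        }),
        encoding="utf8",
    )
    print(pick_template(config_path, "chat_template"))           # ('default', '{{ messages }}')
    print(pick_template(config_path, "chat_template", "tools"))  # ('tools', '{{ tools }}')
```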
Changes in the sample config (model section):

```diff
@@ -107,7 +107,9 @@ model:
   # Possible values FP16, FP8, Q4. (default: FP16)
   #cache_mode: FP16
 
-  # Set the prompt template for this model. If empty, chat completions will be disabled. (default: Empty)
+  # Set the prompt template for this model. If empty, attempts to look for the model's chat template. (default: None)
+  # If a model contains multiple templates in its tokenizer_config.json, set prompt_template to the name
+  # of the template you want to use.
   # NOTE: Only works with chat completion message lists!
   #prompt_template:
```