Templates: Migrate to class

Having many utility functions for initialization doesn't make much sense.
Instead, handle anything regarding template creation inside the
class which reduces the amount of function imports.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-04-21 23:28:14 -04:00
parent 9f93505bc1
commit cab789e685
4 changed files with 122 additions and 132 deletions

View file

@@ -34,8 +34,6 @@ from common.templating import (
PromptTemplate,
TemplateLoadError,
find_template_from_model,
get_template_from_model_json,
get_template_from_file,
)
from common.transformers_utils import GenerationConfig
from common.utils import coalesce, unwrap
@@ -276,18 +274,18 @@ class ExllamaV2Container:
logger.info("Attempting to load a prompt template if present.")
find_template_functions = [
lambda: get_template_from_model_json(
lambda: PromptTemplate.from_model_json(
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
"chat_template",
),
lambda: get_template_from_file(find_template_from_model(model_directory)),
lambda: PromptTemplate.from_file(find_template_from_model(model_directory)),
]
# Add lookup from prompt template name if provided
if prompt_template_name:
find_template_functions[:0] = [
lambda: get_template_from_file(prompt_template_name),
lambda: get_template_from_model_json(
lambda: PromptTemplate.from_file(prompt_template_name),
lambda: PromptTemplate.from_model_json(
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
"chat_template",
prompt_template_name,

View file

@@ -2,83 +2,143 @@
import json
import pathlib
from functools import lru_cache
from importlib.metadata import version as package_version
from typing import Optional
from jinja2 import Template, TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from loguru import logger
from packaging import version
from pydantic import BaseModel
from common.utils import unwrap
class PromptTemplate(BaseModel):
    """A template for chat completion prompts."""

    # Lookup name of the template (e.g. jinja filename or tokenizer-config key)
    name: str
    # Raw Jinja2 template source, compiled later by the module-level helpers
    template: str
class TemplateLoadError(Exception):
    """Raised when a prompt template fails to load."""

    pass
def get_prompt_from_template(prompt_template: PromptTemplate, template_vars: dict):
"""Get a prompt from a template and a list of messages."""
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
raise ImportError(
"Parsing these chat completion messages requires jinja2 3.0.0 "
f"or greater. Current version: {package_version('jinja2')}\n"
"Please upgrade jinja by running the following command: "
"pip install --upgrade jinja2"
)
class PromptTemplate:
"""A template for chat completion prompts."""
compiled_template = _compile_template(prompt_template.template)
rendered_template = compiled_template.render(**template_vars)
template_stop_strings = _get_template_stop_strings(compiled_template, template_vars)
name: str
raw_template: str
template: Template
environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True
)
return rendered_template, template_stop_strings
def stop_strings(self, template_vars: dict):
"""Appends extra stop strings if present in a chat template."""
extra_stop_strings = []
template_module = self.template.make_module(template_vars)
# Inspired from
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
# TODO: Migrate to compile when template is loaded (removes the need for an lru_cache)
@lru_cache
def _compile_template(template: str):
"""Compiles a Jinja2 template"""
if hasattr(template_module, "stop_strings"):
if isinstance(template_module.stop_strings, list):
extra_stop_strings += template_module.stop_strings
else:
logger.warning(
"Skipping append of stopping strings from chat template "
"because stop_strings isn't a list."
)
# Exception handler
def raise_exception(message):
raise TemplateError(message)
return extra_stop_strings
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
jinja_env.globals["raise_exception"] = raise_exception
jinja_template = jinja_env.from_string(template)
return jinja_template
# TODO: Migrate to run during template load
def _get_template_stop_strings(prompt_template: Template, template_vars: dict):
"""Appends extra stop strings if present in a chat template."""
extra_stop_strings = []
template_module = prompt_template.make_module(template_vars)
if hasattr(template_module, "stop_strings"):
if isinstance(template_module.stop_strings, list):
extra_stop_strings += template_module.stop_strings
else:
logger.warning(
"Skipping append of stopping strings from chat template "
"because stop_strings isn't a list."
def render(self, template_vars: dict):
"""Get a prompt from a template and a list of messages."""
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
raise ImportError(
"Parsing these chat completion messages requires jinja2 3.0.0 "
f"or greater. Current version: {package_version('jinja2')}\n"
"Please upgrade jinja by running the following command: "
"pip install --upgrade jinja2"
)
return extra_stop_strings
rendered_template = self.template.render(**template_vars)
template_stop_strings = self.stop_strings(template_vars)
return rendered_template, template_stop_strings
def compile(self, template_str: str):
    """Compiles and stores a jinja2 template"""

    env = self.environment

    # Expose a hook so template authors can abort rendering with an error
    def raise_exception(message):
        raise TemplateError(message)

    env.globals["raise_exception"] = raise_exception

    return env.from_string(template_str)
def __init__(self, name: str, raw_template: str):
    """Initializer for the PromptTemplate class.

    Args:
        name: Lookup name for this template.
        raw_template: Raw Jinja2 template source.
    """

    self.name = name
    self.raw_template = raw_template
    # Compile once up front so later render calls reuse the compiled template
    self.template = self.compile(raw_template)
@classmethod
def from_file(cls, prompt_template_name: str):
    """Get a template from a jinja file.

    Args:
        prompt_template_name: Filename stem (without the .jinja suffix)
            looked up under the templates/ directory.

    Raises:
        TemplateLoadError: If no matching template file exists.
    """

    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
    if template_path.exists():
        with open(template_path, "r", encoding="utf8") as raw_template_stream:
            # Use cls (not a hard-coded class name) so subclasses
            # construct instances of themselves
            return cls(
                name=prompt_template_name,
                raw_template=raw_template_stream.read(),
            )
    else:
        # Let the user know if the template file isn't found
        raise TemplateLoadError(
            f'Chat template "{prompt_template_name}" not found in files.'
        )
@classmethod
def from_model_json(
    cls, json_path: pathlib.Path, key: str, name: Optional[str] = None
):
    """Get a template from a JSON file. Requires a key and template name

    Args:
        json_path: Path to the model's JSON config (e.g. tokenizer_config.json).
        key: Key in the JSON that holds the chat template value.
        name: Optional template name used when the key holds a list of
            wrapped templates (the new list style).

    Raises:
        TemplateLoadError: If the path, key, or named template is missing.
    """

    if not json_path.exists():
        raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')

    with open(json_path, "r", encoding="utf8") as config_file:
        model_config = json.load(config_file)
        chat_template = model_config.get(key)

        if not chat_template:
            raise TemplateLoadError(
                "Could not find a value from chat_template key in the passed JSON. "
                "Check the tokenizer config?"
            )

        if isinstance(chat_template, list):
            # Handles the new list style of chat templates
            if name:
                wrapped_template = next(
                    (x for x in chat_template if x.get("name") == name),
                    {},
                )
            else:
                wrapped_template = chat_template[0]
                name = unwrap(wrapped_template.get("name"), "from_tokenizer_config")

            selected_template = wrapped_template.get("template")

            if selected_template:
                # Use cls (not a hard-coded class name) so subclasses
                # construct instances of themselves
                return cls(name=name, raw_template=selected_template)
            else:
                raise TemplateLoadError(
                    f'Chat template with name "{name}" not found '
                    "in model templates list."
                )
        else:
            # Can safely assume the chat template is the old style
            return cls(
                name="from_tokenizer_config",
                raw_template=chat_template,
            )
def get_all_templates():
@@ -101,63 +161,3 @@ def find_template_from_model(model_path: pathlib.Path):
return template_name
else:
raise TemplateLoadError("Could not find template from model name.")
def get_template_from_file(prompt_template_name: str):
    """Get a template from a jinja file."""

    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")

    # Guard clause: let the user know if the template file isn't found
    if not template_path.exists():
        raise TemplateLoadError(
            f'Chat template "{prompt_template_name}" not found in files.'
        )

    with open(template_path, "r", encoding="utf8") as raw_template:
        return PromptTemplate(
            name=prompt_template_name, template=raw_template.read()
        )
# Get a template from a JSON file
# Requires a key and template name
def get_template_from_model_json(
    json_path: pathlib.Path, key: str, name: Optional[str] = None
):
    """Get a template from a JSON file. Requires a key and template name"""

    if not json_path.exists():
        raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')

    with open(json_path, "r", encoding="utf8") as config_file:
        model_config = json.load(config_file)
        chat_template = model_config.get(key)

        if not chat_template:
            raise TemplateLoadError(
                "Could not find a value from chat_template key in the passed JSON. "
                "Check the tokenizer config?"
            )

        # Can safely assume a non-list chat template is the old style
        if not isinstance(chat_template, list):
            return PromptTemplate(name="from_tokenizer_config", template=chat_template)

        # Handles the new list style of chat templates
        if name:
            wrapped_template = next(
                (x for x in chat_template if x.get("name") == name),
                {},
            )
        else:
            wrapped_template = chat_template[0]
            name = unwrap(wrapped_template.get("name"), "from_tokenizer_config")

        selected_template = wrapped_template.get("template")
        if not selected_template:
            raise TemplateLoadError(
                f'Chat template with name "{name}" not found '
                "in model templates list."
            )

        return PromptTemplate(name=name, template=selected_template)

View file

@@ -14,11 +14,7 @@ from common.concurrency import (
generate_with_semaphore,
)
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import (
get_all_templates,
get_prompt_from_template,
get_template_from_file,
)
from common.templating import PromptTemplate, get_all_templates
from common.utils import coalesce, unwrap
from endpoints.OAI.types.auth import AuthPermissionResponse
from endpoints.OAI.types.completion import CompletionRequest
@@ -224,8 +220,7 @@ async def switch_template(data: TemplateSwitchRequest):
raise HTTPException(400, error_message)
try:
template = get_template_from_file(data.name)
model.container.prompt_template = template
model.container.prompt_template = PromptTemplate.from_file(data.name)
except FileNotFoundError as e:
error_message = handle_request_error(
f"The template name {data.name} doesn't exist. Check the spelling?",
@@ -402,9 +397,7 @@ async def encode_tokens(data: TokenEncodeRequest):
**special_tokens_dict,
}
text, _ = get_prompt_from_template(
model.container.prompt_template, template_vars
)
text, _ = model.container.prompt_template.render(template_vars)
raw_tokens = model.container.encode_tokens(text, **data.get_params())
tokens = unwrap(raw_tokens, [])

View file

@@ -15,7 +15,6 @@ from common.networking import (
handle_request_disconnect,
handle_request_error,
)
from common.templating import get_prompt_from_template
from common.utils import unwrap
from endpoints.OAI.types.chat_completion import (
ChatCompletionLogprobs,
@@ -150,8 +149,8 @@ def format_prompt_with_template(data: ChatCompletionRequest):
}
)
prompt, template_stop_strings = get_prompt_from_template(
model.container.prompt_template, data.template_vars
prompt, template_stop_strings = model.container.prompt_template.render(
data.template_vars
)
# Append template stop strings