Templates: Migrate to class
Having many utility functions for initialization doesn't make much sense. Instead, handle everything related to template creation inside the class, which reduces the number of function imports. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
9f93505bc1
commit
cab789e685
4 changed files with 122 additions and 132 deletions
|
|
@ -34,8 +34,6 @@ from common.templating import (
|
|||
PromptTemplate,
|
||||
TemplateLoadError,
|
||||
find_template_from_model,
|
||||
get_template_from_model_json,
|
||||
get_template_from_file,
|
||||
)
|
||||
from common.transformers_utils import GenerationConfig
|
||||
from common.utils import coalesce, unwrap
|
||||
|
|
@ -276,18 +274,18 @@ class ExllamaV2Container:
|
|||
logger.info("Attempting to load a prompt template if present.")
|
||||
|
||||
find_template_functions = [
|
||||
lambda: get_template_from_model_json(
|
||||
lambda: PromptTemplate.from_model_json(
|
||||
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
|
||||
"chat_template",
|
||||
),
|
||||
lambda: get_template_from_file(find_template_from_model(model_directory)),
|
||||
lambda: PromptTemplate.from_file(find_template_from_model(model_directory)),
|
||||
]
|
||||
|
||||
# Add lookup from prompt template name if provided
|
||||
if prompt_template_name:
|
||||
find_template_functions[:0] = [
|
||||
lambda: get_template_from_file(prompt_template_name),
|
||||
lambda: get_template_from_model_json(
|
||||
lambda: PromptTemplate.from_file(prompt_template_name),
|
||||
lambda: PromptTemplate.from_model_json(
|
||||
pathlib.Path(self.config.model_dir) / "tokenizer_config.json",
|
||||
"chat_template",
|
||||
prompt_template_name,
|
||||
|
|
|
|||
|
|
@ -2,83 +2,143 @@
|
|||
|
||||
import json
|
||||
import pathlib
|
||||
from functools import lru_cache
|
||||
from importlib.metadata import version as package_version
|
||||
from typing import Optional
|
||||
from jinja2 import Template, TemplateError
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
from loguru import logger
|
||||
from packaging import version
|
||||
from pydantic import BaseModel
|
||||
|
||||
from common.utils import unwrap
|
||||
|
||||
|
||||
class PromptTemplate(BaseModel):
|
||||
"""A template for chat completion prompts."""
|
||||
|
||||
name: str
|
||||
template: str
|
||||
|
||||
|
||||
class TemplateLoadError(Exception):
    """Signals that a prompt template could not be located or loaded."""
|
||||
|
||||
|
||||
def get_prompt_from_template(prompt_template: PromptTemplate, template_vars: dict):
|
||||
"""Get a prompt from a template and a list of messages."""
|
||||
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
|
||||
raise ImportError(
|
||||
"Parsing these chat completion messages requires jinja2 3.0.0 "
|
||||
f"or greater. Current version: {package_version('jinja2')}\n"
|
||||
"Please upgrade jinja by running the following command: "
|
||||
"pip install --upgrade jinja2"
|
||||
)
|
||||
class PromptTemplate:
|
||||
"""A template for chat completion prompts."""
|
||||
|
||||
compiled_template = _compile_template(prompt_template.template)
|
||||
rendered_template = compiled_template.render(**template_vars)
|
||||
template_stop_strings = _get_template_stop_strings(compiled_template, template_vars)
|
||||
name: str
|
||||
raw_template: str
|
||||
template: Template
|
||||
environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
|
||||
trim_blocks=True, lstrip_blocks=True
|
||||
)
|
||||
|
||||
return rendered_template, template_stop_strings
|
||||
def stop_strings(self, template_vars: dict):
|
||||
"""Appends extra stop strings if present in a chat template."""
|
||||
|
||||
extra_stop_strings = []
|
||||
template_module = self.template.make_module(template_vars)
|
||||
|
||||
# Inspired from
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
|
||||
# TODO: Migrate to compile when template is loaded (removes the need for an lru_cache)
|
||||
@lru_cache
|
||||
def _compile_template(template: str):
|
||||
"""Compiles a Jinja2 template"""
|
||||
if hasattr(template_module, "stop_strings"):
|
||||
if isinstance(template_module.stop_strings, list):
|
||||
extra_stop_strings += template_module.stop_strings
|
||||
else:
|
||||
logger.warning(
|
||||
"Skipping append of stopping strings from chat template "
|
||||
"because stop_strings isn't a list."
|
||||
)
|
||||
|
||||
# Exception handler
|
||||
def raise_exception(message):
|
||||
raise TemplateError(message)
|
||||
return extra_stop_strings
|
||||
|
||||
jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
|
||||
jinja_env.globals["raise_exception"] = raise_exception
|
||||
|
||||
jinja_template = jinja_env.from_string(template)
|
||||
return jinja_template
|
||||
|
||||
|
||||
# TODO: Migrate to run during template load
|
||||
def _get_template_stop_strings(prompt_template: Template, template_vars: dict):
|
||||
"""Appends extra stop strings if present in a chat template."""
|
||||
|
||||
extra_stop_strings = []
|
||||
template_module = prompt_template.make_module(template_vars)
|
||||
|
||||
if hasattr(template_module, "stop_strings"):
|
||||
if isinstance(template_module.stop_strings, list):
|
||||
extra_stop_strings += template_module.stop_strings
|
||||
else:
|
||||
logger.warning(
|
||||
"Skipping append of stopping strings from chat template "
|
||||
"because stop_strings isn't a list."
|
||||
def render(self, template_vars: dict):
|
||||
"""Get a prompt from a template and a list of messages."""
|
||||
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
|
||||
raise ImportError(
|
||||
"Parsing these chat completion messages requires jinja2 3.0.0 "
|
||||
f"or greater. Current version: {package_version('jinja2')}\n"
|
||||
"Please upgrade jinja by running the following command: "
|
||||
"pip install --upgrade jinja2"
|
||||
)
|
||||
|
||||
return extra_stop_strings
|
||||
rendered_template = self.template.render(**template_vars)
|
||||
template_stop_strings = self.stop_strings(template_vars)
|
||||
|
||||
return rendered_template, template_stop_strings
|
||||
|
||||
def compile(self, template_str: str):
    """Turn raw jinja2 template text into a compiled template object.

    A ``raise_exception`` helper is registered in the globals of
    ``self.environment`` so that a template can abort rendering with a
    custom message. The compiled template is returned to the caller;
    this method does not assign it to any attribute itself.
    """

    # Callable exposed to templates for signalling errors
    def _abort(reason):
        raise TemplateError(reason)

    env = self.environment
    env.globals.update(raise_exception=_abort)

    return env.from_string(template_str)
|
||||
|
||||
def __init__(self, name: str, raw_template: str):
    """Set up a prompt template from its name and raw jinja text.

    The raw text is kept as ``raw_template`` and its compiled form,
    produced by ``self.compile``, is kept as ``template``.
    """

    self.name, self.raw_template = name, raw_template
    self.template = self.compile(raw_template)
|
||||
|
||||
@classmethod
def from_file(cls, prompt_template_name: str):
    """Create a template from a jinja file.

    Args:
        prompt_template_name: Template name without extension; the file
            read is ``templates/<prompt_template_name>.jinja``, relative
            to the current working directory.

    Returns:
        A new instance of ``cls`` named after the template and built
        from the file's contents.

    Raises:
        TemplateLoadError: If the template file does not exist.
    """

    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")

    # Let the user know if the template file isn't found
    if not template_path.exists():
        raise TemplateLoadError(
            f'Chat template "{prompt_template_name}" not found in files.'
        )

    with open(template_path, "r", encoding="utf8") as raw_template_stream:
        raw_template = raw_template_stream.read()

    # Construct through cls (not a hard-coded class name) so that a
    # subclass calling this constructor receives its own type.
    return cls(name=prompt_template_name, raw_template=raw_template)
|
||||
|
||||
@classmethod
def from_model_json(
    cls, json_path: pathlib.Path, key: str, name: Optional[str] = None
):
    """Create a template from a value stored in a JSON file.

    The value under ``key`` may be a single template string (old style)
    or a list of ``{"name": ..., "template": ...}`` entries (new style).

    Args:
        json_path: Path of the JSON file to read.
        key: Top-level key whose value holds the chat template(s).
        name: For list-style values, the entry to select by its "name"
            field. When omitted, the first entry in the list is used.

    Returns:
        A new instance of ``cls`` holding the selected template.

    Raises:
        TemplateLoadError: If the file is missing, the key has no value,
            or no usable template is found in a list-style value.
    """

    if not json_path.exists():
        raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')

    with open(json_path, "r", encoding="utf8") as config_file:
        model_config = json.load(config_file)

    chat_template = model_config.get(key)

    if not chat_template:
        # Report the key that was actually looked up rather than a
        # hard-coded name.
        raise TemplateLoadError(
            f"Could not find a value from {key} key in the passed JSON. "
            "Check the tokenizer config?"
        )

    if not isinstance(chat_template, list):
        # Can safely assume the chat template is the old style
        return cls(name="from_tokenizer_config", raw_template=chat_template)

    # Handles the new list style of chat templates
    if name:
        # An empty dict stands in for "no entry with that name"
        wrapped_template = next(
            (x for x in chat_template if x.get("name") == name),
            {},
        )
    else:
        # The list is non-empty here: an empty list is rejected above
        wrapped_template = chat_template[0]
        name = unwrap(wrapped_template.get("name"), "from_tokenizer_config")

    selected_template = wrapped_template.get("template")

    if not selected_template:
        raise TemplateLoadError(
            f'Chat template with name "{name}" not found '
            "in model templates list."
        )

    # Construct through cls so subclasses receive their own type
    return cls(name=name, raw_template=selected_template)
|
||||
|
||||
|
||||
def get_all_templates():
|
||||
|
|
@ -101,63 +161,3 @@ def find_template_from_model(model_path: pathlib.Path):
|
|||
return template_name
|
||||
else:
|
||||
raise TemplateLoadError("Could not find template from model name.")
|
||||
|
||||
|
||||
def get_template_from_file(prompt_template_name: str):
    """Load the jinja template stored under the given name.

    Reads ``templates/<prompt_template_name>.jinja`` and wraps its
    contents in a PromptTemplate. Raises TemplateLoadError when that
    file does not exist.
    """

    jinja_file = pathlib.Path(f"templates/{prompt_template_name}.jinja")

    # Bail out early so the user learns the template file is missing
    if not jinja_file.exists():
        raise TemplateLoadError(
            f'Chat template "{prompt_template_name}" not found in files.'
        )

    with open(jinja_file, "r", encoding="utf8") as stream:
        contents = stream.read()
        return PromptTemplate(name=prompt_template_name, template=contents)
|
||||
|
||||
|
||||
# Build a PromptTemplate out of a value stored in a JSON file.
# The caller supplies the key to read and, optionally, a template name.
def get_template_from_model_json(
    json_path: pathlib.Path, key: str, name: Optional[str] = None
):
    """Read a chat template from a JSON file.

    The value under ``key`` is either a single template string or a list
    of entries carrying "name" and "template" fields. For a list, the
    entry whose "name" equals ``name`` is chosen, or the first entry
    when ``name`` is not given. Raises TemplateLoadError when the file,
    the key's value, or a usable list entry is missing.
    """

    if not json_path.exists():
        raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')

    with open(json_path, "r", encoding="utf8") as handle:
        parsed = json.load(handle)
    template_value = parsed.get(key)

    if not template_value:
        raise TemplateLoadError(
            "Could not find a value from chat_template key in the passed JSON. "
            "Check the tokenizer config?"
        )

    # Old style: the value is the template itself
    if not isinstance(template_value, list):
        return PromptTemplate(name="from_tokenizer_config", template=template_value)

    # New style: a list of named template entries
    template_name = name
    if template_name:
        matches = (
            entry for entry in template_value if entry.get("name") == template_name
        )
        entry = next(matches, {})
    else:
        entry = template_value[0]
        template_name = unwrap(entry.get("name"), "from_tokenizer_config")

    body = entry.get("template")
    if not body:
        raise TemplateLoadError(
            f'Chat template with name "{template_name}" not found '
            "in model templates list."
        )

    return PromptTemplate(name=template_name, template=body)
|
||||
|
|
|
|||
|
|
@ -14,11 +14,7 @@ from common.concurrency import (
|
|||
generate_with_semaphore,
|
||||
)
|
||||
from common.networking import handle_request_error, run_with_request_disconnect
|
||||
from common.templating import (
|
||||
get_all_templates,
|
||||
get_prompt_from_template,
|
||||
get_template_from_file,
|
||||
)
|
||||
from common.templating import PromptTemplate, get_all_templates
|
||||
from common.utils import coalesce, unwrap
|
||||
from endpoints.OAI.types.auth import AuthPermissionResponse
|
||||
from endpoints.OAI.types.completion import CompletionRequest
|
||||
|
|
@ -224,8 +220,7 @@ async def switch_template(data: TemplateSwitchRequest):
|
|||
raise HTTPException(400, error_message)
|
||||
|
||||
try:
|
||||
template = get_template_from_file(data.name)
|
||||
model.container.prompt_template = template
|
||||
model.container.prompt_template = PromptTemplate.from_file(data.name)
|
||||
except FileNotFoundError as e:
|
||||
error_message = handle_request_error(
|
||||
f"The template name {data.name} doesn't exist. Check the spelling?",
|
||||
|
|
@ -402,9 +397,7 @@ async def encode_tokens(data: TokenEncodeRequest):
|
|||
**special_tokens_dict,
|
||||
}
|
||||
|
||||
text, _ = get_prompt_from_template(
|
||||
model.container.prompt_template, template_vars
|
||||
)
|
||||
text, _ = model.container.prompt_template.render(template_vars)
|
||||
|
||||
raw_tokens = model.container.encode_tokens(text, **data.get_params())
|
||||
tokens = unwrap(raw_tokens, [])
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from common.networking import (
|
|||
handle_request_disconnect,
|
||||
handle_request_error,
|
||||
)
|
||||
from common.templating import get_prompt_from_template
|
||||
from common.utils import unwrap
|
||||
from endpoints.OAI.types.chat_completion import (
|
||||
ChatCompletionLogprobs,
|
||||
|
|
@ -150,8 +149,8 @@ def format_prompt_with_template(data: ChatCompletionRequest):
|
|||
}
|
||||
)
|
||||
|
||||
prompt, template_stop_strings = get_prompt_from_template(
|
||||
model.container.prompt_template, data.template_vars
|
||||
prompt, template_stop_strings = model.container.prompt_template.render(
|
||||
data.template_vars
|
||||
)
|
||||
|
||||
# Append template stop strings
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue