Templates: Revert to load metadata on runtime

Metadata is generated via a template's module. This requires a single
iteration through the template. If a template tries to access a passed
variable that doesn't exist, it will error.

Therefore, generate the metadata at runtime to prevent these errors
from happening. To optimize further, cache the metadata after the
first generation to prevent the expensive call of making a template
module.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-08-17 11:30:50 -04:00
parent 617ac12150
commit b4752c1e62
2 changed files with 19 additions and 6 deletions

View file

@ -1,5 +1,6 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""
from functools import lru_cache
import json
import pathlib
from importlib.metadata import version as package_version
@ -34,14 +35,24 @@ class PromptTemplate:
environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True
)
metadata: TemplateMetadata
metadata: Optional[TemplateMetadata] = None
def extract_metadata(self):
"""Returns deserialized template metadata from a chat template."""
def extract_metadata(self, template_vars: dict):
"""
Returns deserialized template metadata from a chat template.
NOTE: Requires all template vars to be passed in since the template
is run once to make a module and errors can result.
"""
# No need to extract new metadata if it already exists
# This might be removed if stored metadata becomes arbitrary
if self.metadata:
return self.metadata
template_metadata = TemplateMetadata()
template_module = self.template.make_module()
template_module = self.template.make_module(template_vars)
if hasattr(template_module, "stop_strings"):
if isinstance(template_module.stop_strings, list):
@ -60,6 +71,7 @@ class PromptTemplate:
if isinstance(template_module.tool_start_token, int):
template_metadata.tool_starts.append(template_module.tool_start_token)
self.metadata = template_metadata
return template_metadata
def render(self, template_vars: dict):
@ -93,7 +105,6 @@ class PromptTemplate:
self.name = name
self.raw_template = raw_template
self.template = self.compile(raw_template)
self.metadata = self.extract_metadata()
@classmethod
def from_file(self, prompt_template_name: str):