Templates: Revert to load metadata on runtime
Metadata is generated via a template's module. This requires a single iteration through the template. If a template tries to access a passed variable that doesn't exist, it will error. Therefore, generate the metadata at runtime to prevent these errors from happening. To optimize further, cache the metadata after the first generation to prevent the expensive call of making a template module. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
617ac12150
commit
b4752c1e62
2 changed files with 19 additions and 6 deletions
|
|
@ -1,5 +1,6 @@
|
|||
"""Small replication of AutoTokenizer's chat template system for efficiency"""
|
||||
|
||||
from functools import lru_cache
|
||||
import json
|
||||
import pathlib
|
||||
from importlib.metadata import version as package_version
|
||||
|
|
@ -34,14 +35,24 @@ class PromptTemplate:
|
|||
environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
|
||||
trim_blocks=True, lstrip_blocks=True
|
||||
)
|
||||
metadata: TemplateMetadata
|
||||
metadata: Optional[TemplateMetadata] = None
|
||||
|
||||
def extract_metadata(self):
|
||||
"""Returns deserialized template metadata from a chat template."""
|
||||
def extract_metadata(self, template_vars: dict):
|
||||
"""
|
||||
Returns deserialized template metadata from a chat template.
|
||||
|
||||
NOTE: Requires all template vars to be passed in since the template
|
||||
is run once to make a module and errors can result.
|
||||
"""
|
||||
|
||||
# No need to extract new metadata if it already exists
|
||||
# This might be removed if stored metadata becomes arbitrary
|
||||
if self.metadata:
|
||||
return self.metadata
|
||||
|
||||
template_metadata = TemplateMetadata()
|
||||
|
||||
template_module = self.template.make_module()
|
||||
template_module = self.template.make_module(template_vars)
|
||||
|
||||
if hasattr(template_module, "stop_strings"):
|
||||
if isinstance(template_module.stop_strings, list):
|
||||
|
|
@ -60,6 +71,7 @@ class PromptTemplate:
|
|||
if isinstance(template_module.tool_start_token, int):
|
||||
template_metadata.tool_starts.append(template_module.tool_start_token)
|
||||
|
||||
self.metadata = template_metadata
|
||||
return template_metadata
|
||||
|
||||
def render(self, template_vars: dict):
|
||||
|
|
@ -93,7 +105,6 @@ class PromptTemplate:
|
|||
self.name = name
|
||||
self.raw_template = raw_template
|
||||
self.template = self.compile(raw_template)
|
||||
self.metadata = self.extract_metadata()
|
||||
|
||||
@classmethod
|
||||
def from_file(self, prompt_template_name: str):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue