Templates: Switch to async jinja engine

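This prevents any possible blocking of the event loop due to template
rendering: the sandboxed Jinja environment is now created with
enable_async=True, and extract_metadata() and render() become coroutines
that await make_module_async() and render_async() respectively.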

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-08-17 12:01:59 -04:00
parent b4752c1e62
commit a51acb9db4
3 changed files with 17 additions and 14 deletions

View file

@ -1,6 +1,5 @@
"""Small replication of AutoTokenizer's chat template system for efficiency"""
from functools import lru_cache
import json
import pathlib
from importlib.metadata import version as package_version
@ -33,11 +32,11 @@ class PromptTemplate:
raw_template: str
template: Template
environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True
trim_blocks=True, lstrip_blocks=True, enable_async=True
)
metadata: Optional[TemplateMetadata] = None
def extract_metadata(self, template_vars: dict):
async def extract_metadata(self, template_vars: dict):
"""
Returns deserialized template metadata from a chat template.
@ -52,7 +51,7 @@ class PromptTemplate:
template_metadata = TemplateMetadata()
template_module = self.template.make_module(template_vars)
template_module = await self.template.make_module_async(template_vars)
if hasattr(template_module, "stop_strings"):
if isinstance(template_module.stop_strings, list):
@ -74,7 +73,7 @@ class PromptTemplate:
self.metadata = template_metadata
return template_metadata
def render(self, template_vars: dict):
async def render(self, template_vars: dict):
"""Get a prompt from a template and a list of messages."""
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
raise ImportError(
@ -84,7 +83,7 @@ class PromptTemplate:
"pip install --upgrade jinja2"
)
rendered_template = self.template.render(**template_vars)
rendered_template = await self.template.render_async(**template_vars)
return rendered_template