Templates: Switch to async jinja engine

This prevents template rendering from blocking the event loop.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri
Date:   2024-08-17 12:01:59 -04:00
Parent: b4752c1e62
Commit: a51acb9db4

3 changed files with 17 additions and 14 deletions
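
The core of the change is Jinja2's native async support: constructing the
environment with enable_async=True compiles templates into coroutines, so
rendering can be awaited instead of blocking the event loop. A minimal
standalone sketch of that mode (not the project's code):

    import asyncio
    from jinja2.sandbox import ImmutableSandboxedEnvironment

    # enable_async=True makes render_async (and make_module_async) available.
    env = ImmutableSandboxedEnvironment(
        trim_blocks=True, lstrip_blocks=True, enable_async=True
    )

    async def main():
        template = env.from_string("Hello, {{ name }}!")
        # Awaiting here yields control back to the event loop.
        print(await template.render_async(name="world"))

    asyncio.run(main())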

File 1 of 3 (PromptTemplate class):

@@ -1,6 +1,5 @@
 """Small replication of AutoTokenizer's chat template system for efficiency"""
-from functools import lru_cache
 import json
 import pathlib
 from importlib.metadata import version as package_version
@@ -33,11 +32,11 @@ class PromptTemplate:
     raw_template: str
     template: Template
     environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
-        trim_blocks=True, lstrip_blocks=True
+        trim_blocks=True, lstrip_blocks=True, enable_async=True
     )
     metadata: Optional[TemplateMetadata] = None

-    def extract_metadata(self, template_vars: dict):
+    async def extract_metadata(self, template_vars: dict):
         """
         Returns deserialized template metadata from a chat template.
@@ -52,7 +51,7 @@ class PromptTemplate:
         template_metadata = TemplateMetadata()

-        template_module = self.template.make_module(template_vars)
+        template_module = await self.template.make_module_async(template_vars)

         if hasattr(template_module, "stop_strings"):
             if isinstance(template_module.stop_strings, list):
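
For context: make_module_async evaluates the template as a coroutine and
exposes its top-level {% set %} assignments as module attributes, which is how
extract_metadata reads values like stop_strings. A small sketch of that
pattern (hypothetical template contents):

    import asyncio
    from jinja2.sandbox import ImmutableSandboxedEnvironment

    env = ImmutableSandboxedEnvironment(enable_async=True)

    async def main():
        template = env.from_string('{% set stop_strings = ["</s>"] %}')
        module = await template.make_module_async({"messages": []})
        print(getattr(module, "stop_strings", None))  # ["</s>"]

    asyncio.run(main())
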
@@ -74,7 +73,7 @@ class PromptTemplate:
         self.metadata = template_metadata
         return template_metadata

-    def render(self, template_vars: dict):
+    async def render(self, template_vars: dict):
         """Get a prompt from a template and a list of messages."""
         if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
             raise ImportError(
@@ -84,7 +83,7 @@ class PromptTemplate:
                 "pip install --upgrade jinja2"
             )

-        rendered_template = self.template.render(**template_vars)
+        rendered_template = await self.template.render_async(**template_vars)
         return rendered_template
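
For comparison, a project pinned to an older jinja2 could keep the synchronous
API and offload rendering to a worker thread instead; a rough sketch of that
alternative (not what this commit does):

    import asyncio

    async def render_in_thread(template, template_vars: dict) -> str:
        # Run the blocking Template.render in a thread so the loop stays free.
        return await asyncio.to_thread(template.render, **template_vars)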

File 2 of 3 (chat completion request handler):

@@ -110,7 +110,7 @@ async def chat_completion_request(
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
-        prompt = format_prompt_with_template(data)
+        prompt = await format_prompt_with_template(data)

     # Set an empty JSON schema if the request wants a JSON response
     if data.response_format.type == "json":
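
Note the knock-on effect visible here: once PromptTemplate.render is a
coroutine, every function on the call path must itself become async and await
its callee. A stripped-down, runnable sketch of the propagation (stand-in
bodies, simplified signatures):

    import asyncio

    # Stand-in for PromptTemplate.render after this commit.
    async def render(template_vars: dict) -> str:
        return f"rendered with {template_vars}"

    # Each caller becomes async and awaits downward.
    async def format_prompt_with_template(data: dict) -> str:
        return await render(data)

    async def chat_completion_request(data: dict) -> None:
        print(await format_prompt_with_template(data))

    asyncio.run(chat_completion_request({"messages": []}))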

File 3 of 3 (chat completion utilities):

@@ -178,10 +178,10 @@ def _create_stream_chunk(
     return chunk


-def _append_template_metadata(data: ChatCompletionRequest):
+async def _append_template_metadata(data: ChatCompletionRequest):
     """Adding metadata is a one-time process."""

-    template_metadata = model.container.prompt_template.extract_metadata(
+    template_metadata = await model.container.prompt_template.extract_metadata(
         data.template_vars
     )
@@ -200,7 +200,7 @@ def _append_template_metadata(data: ChatCompletionRequest):
         data.stop.extend(template_metadata.tool_starts)


-def format_prompt_with_template(
+async def format_prompt_with_template(
     data: ChatCompletionRequest, tool_precursor: Optional[str] = None
 ):
     """
@@ -242,7 +242,7 @@ def format_prompt_with_template(
         }
     )

-    prompt = model.container.prompt_template.render(data.template_vars)
+    prompt = await model.container.prompt_template.render(data.template_vars)

     # Append response prefix if present
     if data.response_prefix:
@@ -261,7 +261,9 @@ def format_prompt_with_template(
         prompt = prompt.removeprefix(bos_token)

     # Add template metadata
-    _append_template_metadata(data)
+    await _append_template_metadata(data)
+    print(prompt)
+    print(model.container.prompt_template.metadata.tool_starts)

     return prompt
@@ -441,11 +443,13 @@ async def generate_tool_calls(
         if gen["stop_str"] in tool_data.tool_call_start:
             if "text" in gen:
                 # non streaming, all generations will have the text they generated
-                pre_tool_prompt = format_prompt_with_template(data, gen["text"])
+                pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
             elif current_generations is not None:
                 # streaming, we wont have text in the generation,
                 # we'll have to use the current_generations
-                pre_tool_prompt = format_prompt_with_template(data, current_generations)
+                pre_tool_prompt = await format_prompt_with_template(
+                    data, current_generations
+                )

             gen_tasks.append(
                 asyncio.create_task(
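
The payoff shows up in code like gen_tasks above: because rendering is now
awaitable, tasks spawned with asyncio.create_task can interleave with it
instead of stalling behind a blocking render. A toy demonstration (stand-in
coroutines, not the project's code):

    import asyncio

    async def render_prompt(i: int) -> str:
        await asyncio.sleep(0)  # yield point, like await render_async(...)
        return f"prompt-{i}"

    async def main():
        tasks = [asyncio.create_task(render_prompt(i)) for i in range(3)]
        print(await asyncio.gather(*tasks))

    asyncio.run(main())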