Templates: Switch to async jinja engine
This prevents any possible blocking of the event loop due to template rendering.

Signed-off-by: kingbri <bdashore3@proton.me>
parent b4752c1e62
commit a51acb9db4

3 changed files with 17 additions and 14 deletions
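For context on the change below: with enable_async=True, Jinja2 compiles templates into coroutine functions, so a render is awaited on the event loop rather than blocking it while a large chat template expands. A minimal standalone sketch of the pattern (not the project's actual module layout):

    import asyncio
    from jinja2.sandbox import ImmutableSandboxedEnvironment

    # Same flags as the diff below; enable_async=True makes compiled
    # templates expose render_async() and make_module_async() coroutines.
    environment = ImmutableSandboxedEnvironment(
        trim_blocks=True, lstrip_blocks=True, enable_async=True
    )

    async def main():
        template = environment.from_string(
            "{% for m in messages %}{{ m }}\n{% endfor %}"
        )
        # render_async is a coroutine and must be awaited inside a running loop
        print(await template.render_async(messages=["hello", "world"]))

    asyncio.run(main())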
@@ -1,6 +1,5 @@
 """Small replication of AutoTokenizer's chat template system for efficiency"""

-from functools import lru_cache
 import json
 import pathlib
 from importlib.metadata import version as package_version
@@ -33,11 +32,11 @@ class PromptTemplate:
     raw_template: str
     template: Template
     environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
-        trim_blocks=True, lstrip_blocks=True
+        trim_blocks=True, lstrip_blocks=True, enable_async=True
     )
     metadata: Optional[TemplateMetadata] = None

-    def extract_metadata(self, template_vars: dict):
+    async def extract_metadata(self, template_vars: dict):
         """
         Returns deserialized template metadata from a chat template.
@@ -52,7 +51,7 @@ class PromptTemplate:
         template_metadata = TemplateMetadata()

-        template_module = self.template.make_module(template_vars)
+        template_module = await self.template.make_module_async(template_vars)

         if hasattr(template_module, "stop_strings"):
             if isinstance(template_module.stop_strings, list):
@@ -74,7 +73,7 @@ class PromptTemplate:
         self.metadata = template_metadata
         return template_metadata

-    def render(self, template_vars: dict):
+    async def render(self, template_vars: dict):
         """Get a prompt from a template and a list of messages."""
         if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
             raise ImportError(
@@ -84,7 +83,7 @@ class PromptTemplate:
                 "pip install --upgrade jinja2"
             )

-        rendered_template = self.template.render(**template_vars)
+        rendered_template = await self.template.render_async(**template_vars)

         return rendered_template
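Both async counterparts used above are standard Jinja2 APIs: Template.render_async() mirrors render(), and Template.make_module_async() mirrors make_module(), returning a module object whose top-level {% set %} variables become attributes, which is how extract_metadata can read stop_strings. A small illustrative sketch with a made-up template:

    import asyncio
    from jinja2.sandbox import ImmutableSandboxedEnvironment

    environment = ImmutableSandboxedEnvironment(enable_async=True)
    # Hypothetical template: a top-level {% set %} is exported on the module
    template = environment.from_string(
        '{% set stop_strings = ["</s>"] %}Hello {{ name }}'
    )

    async def extract():
        # make_module_async evaluates the template body asynchronously and
        # exposes exported variables as attributes on the module object
        module = await template.make_module_async({"name": "world"})
        return getattr(module, "stop_strings", None)

    print(asyncio.run(extract()))  # ['</s>']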
@@ -110,7 +110,7 @@ async def chat_completion_request(
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
-        prompt = format_prompt_with_template(data)
+        prompt = await format_prompt_with_template(data)

     # Set an empty JSON schema if the request wants a JSON response
     if data.response_format.type == "json":
@@ -178,10 +178,10 @@ def _create_stream_chunk(
     return chunk


-def _append_template_metadata(data: ChatCompletionRequest):
+async def _append_template_metadata(data: ChatCompletionRequest):
     """Adding metadata is a one-time process."""

-    template_metadata = model.container.prompt_template.extract_metadata(
+    template_metadata = await model.container.prompt_template.extract_metadata(
         data.template_vars
     )
@@ -200,7 +200,7 @@ def _append_template_metadata(data: ChatCompletionRequest):
         data.stop.extend(template_metadata.tool_starts)


-def format_prompt_with_template(
+async def format_prompt_with_template(
     data: ChatCompletionRequest, tool_precursor: Optional[str] = None
 ):
     """
@@ -242,7 +242,7 @@ def format_prompt_with_template(
         }
     )

-    prompt = model.container.prompt_template.render(data.template_vars)
+    prompt = await model.container.prompt_template.render(data.template_vars)

     # Append response prefix if present
     if data.response_prefix:
@@ -261,7 +261,9 @@ def format_prompt_with_template(
         prompt = prompt.removeprefix(bos_token)

     # Add template metadata
-    _append_template_metadata(data)
+    await _append_template_metadata(data)
+    print(prompt)
+    print(model.container.prompt_template.metadata.tool_starts)

     return prompt
@@ -441,11 +443,13 @@ async def generate_tool_calls(
         if gen["stop_str"] in tool_data.tool_call_start:
             if "text" in gen:
                 # non streaming, all generations will have the text they generated
                 pre_tool_prompt = format_prompt_with_template(data, gen["text"])
                 pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
             elif current_generations is not None:
                 # streaming, we wont have text in the generation,
                 # we'll have to use the current_generations
-                pre_tool_prompt = format_prompt_with_template(data, current_generations)
+                pre_tool_prompt = await format_prompt_with_template(
+                    data, current_generations
+                )

         gen_tasks.append(
             asyncio.create_task(
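The hunks above are the mechanical fallout of the switch: once render() and extract_metadata() are coroutines, every caller up the stack has to be async and await them, all the way to the request handler. A hypothetical miniature of that chain (the names mirror the diff; the bodies are stand-ins):

    import asyncio

    async def render(template_vars: dict) -> str:
        await asyncio.sleep(0)  # stand-in for template.render_async(...)
        return "rendered prompt"

    async def format_prompt_with_template(data: dict) -> str:
        # await propagates upward: a sync caller can no longer use this helper
        return await render(data)

    async def chat_completion_request(data: dict) -> str:
        return await format_prompt_with_template(data)

    print(asyncio.run(chat_completion_request({"messages": []})))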