From a51acb9db4f2cf6562cbc8ce1f78a2eed314d00d Mon Sep 17 00:00:00 2001
From: kingbri
Date: Sat, 17 Aug 2024 12:01:59 -0400
Subject: [PATCH] Templates: Switch to async Jinja engine

This prevents template rendering from blocking the event loop.

Signed-off-by: kingbri
---
 common/templating.py                   | 11 +++++------
 endpoints/OAI/router.py                |  2 +-
 endpoints/OAI/utils/chat_completion.py | 16 +++++++++-------
 3 files changed, 15 insertions(+), 14 deletions(-)

diff --git a/common/templating.py b/common/templating.py
index 021d1d4..47299ff 100644
--- a/common/templating.py
+++ b/common/templating.py
@@ -1,6 +1,5 @@
 """Small replication of AutoTokenizer's chat template system for efficiency"""
 
-from functools import lru_cache
 import json
 import pathlib
 from importlib.metadata import version as package_version
@@ -33,11 +32,11 @@ class PromptTemplate:
     raw_template: str
     template: Template
     environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
-        trim_blocks=True, lstrip_blocks=True
+        trim_blocks=True, lstrip_blocks=True, enable_async=True
     )
     metadata: Optional[TemplateMetadata] = None
 
-    def extract_metadata(self, template_vars: dict):
+    async def extract_metadata(self, template_vars: dict):
         """
         Returns deserialized template metadata from a chat template.
 
@@ -52,7 +51,7 @@ class PromptTemplate:
 
         template_metadata = TemplateMetadata()
 
-        template_module = self.template.make_module(template_vars)
+        template_module = await self.template.make_module_async(template_vars)
 
         if hasattr(template_module, "stop_strings"):
             if isinstance(template_module.stop_strings, list):
@@ -74,7 +73,7 @@ class PromptTemplate:
         self.metadata = template_metadata
         return template_metadata
 
-    def render(self, template_vars: dict):
+    async def render(self, template_vars: dict):
         """Get a prompt from a template and a list of messages."""
         if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
             raise ImportError(
@@ -84,7 +83,7 @@ class PromptTemplate:
                 "pip install --upgrade jinja2"
             )
 
-        rendered_template = self.template.render(**template_vars)
+        rendered_template = await self.template.render_async(**template_vars)
 
         return rendered_template
 
diff --git a/endpoints/OAI/router.py b/endpoints/OAI/router.py
index c1ee343..eb2445a 100644
--- a/endpoints/OAI/router.py
+++ b/endpoints/OAI/router.py
@@ -110,7 +110,7 @@ async def chat_completion_request(
     if isinstance(data.messages, str):
         prompt = data.messages
     else:
-        prompt = format_prompt_with_template(data)
+        prompt = await format_prompt_with_template(data)
 
     # Set an empty JSON schema if the request wants a JSON response
     if data.response_format.type == "json":
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index d924b5e..cb907f5 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -178,10 +178,10 @@ def _create_stream_chunk(
     return chunk
 
 
-def _append_template_metadata(data: ChatCompletionRequest):
+async def _append_template_metadata(data: ChatCompletionRequest):
     """Adding metadata is a one-time process."""
 
-    template_metadata = model.container.prompt_template.extract_metadata(
+    template_metadata = await model.container.prompt_template.extract_metadata(
         data.template_vars
     )
 
@@ -200,7 +200,7 @@ def _append_template_metadata(data: ChatCompletionRequest):
         data.stop.extend(template_metadata.tool_starts)
 
 
-def format_prompt_with_template(
+async def format_prompt_with_template(
     data: ChatCompletionRequest, tool_precursor: Optional[str] = None
 ):
     """
@@ -242,7 +242,7 @@ def format_prompt_with_template(
             }
         )
 
-    prompt = model.container.prompt_template.render(data.template_vars)
+    prompt = await model.container.prompt_template.render(data.template_vars)
 
     # Append response prefix if present
     if data.response_prefix:
@@ -261,7 +261,7 @@ def format_prompt_with_template(
         prompt = prompt.removeprefix(bos_token)
 
     # Add template metadata
-    _append_template_metadata(data)
+    await _append_template_metadata(data)
 
     return prompt
 
@@ -441,11 +441,13 @@ async def generate_tool_calls(
         if gen["stop_str"] in tool_data.tool_call_start:
             if "text" in gen:
                 # non streaming, all generations will have the text they generated
-                pre_tool_prompt = format_prompt_with_template(data, gen["text"])
+                pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
             elif current_generations is not None:
                 # streaming, we won't have text in the generation,
                 # we'll have to use the current_generations
-                pre_tool_prompt = format_prompt_with_template(data, current_generations)
+                pre_tool_prompt = await format_prompt_with_template(
+                    data, current_generations
+                )
 
             gen_tasks.append(
                 asyncio.create_task(
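
For reference, a minimal standalone sketch of the async Jinja2 pattern this patch adopts. It is illustrative only: the template string and variables are invented for the example, not taken from the repository.

    import asyncio

    from jinja2.sandbox import ImmutableSandboxedEnvironment

    # enable_async=True exposes render_async()/make_module_async(), letting
    # template rendering yield to the event loop instead of blocking it.
    environment = ImmutableSandboxedEnvironment(
        trim_blocks=True, lstrip_blocks=True, enable_async=True
    )


    async def main():
        # Invented template and variables, purely for illustration.
        template = environment.from_string("{{ greeting }}, {{ name }}!")
        rendered = await template.render_async(greeting="Hello", name="world")
        print(rendered)  # -> Hello, world!


    asyncio.run(main())

Awaiting render_async() this way keeps a long render from starving other requests served on the same event loop, which is the motivation stated in the commit message.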