Templates: Switch to async jinja engine

This prevents any possible blocking of the event loop due to template
rendering.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-08-17 12:01:59 -04:00
parent b4752c1e62
commit a51acb9db4
3 changed files with 17 additions and 14 deletions

View file

@ -178,10 +178,10 @@ def _create_stream_chunk(
return chunk
def _append_template_metadata(data: ChatCompletionRequest):
async def _append_template_metadata(data: ChatCompletionRequest):
"""Adding metadata is a one-time process."""
template_metadata = model.container.prompt_template.extract_metadata(
template_metadata = await model.container.prompt_template.extract_metadata(
data.template_vars
)
@ -200,7 +200,7 @@ def _append_template_metadata(data: ChatCompletionRequest):
data.stop.extend(template_metadata.tool_starts)
def format_prompt_with_template(
async def format_prompt_with_template(
data: ChatCompletionRequest, tool_precursor: Optional[str] = None
):
"""
@ -242,7 +242,7 @@ def format_prompt_with_template(
}
)
prompt = model.container.prompt_template.render(data.template_vars)
prompt = await model.container.prompt_template.render(data.template_vars)
# Append response prefix if present
if data.response_prefix:
@ -261,7 +261,9 @@ def format_prompt_with_template(
prompt = prompt.removeprefix(bos_token)
# Add template metadata
_append_template_metadata(data)
await _append_template_metadata(data)
print(prompt)
print(model.container.prompt_template.metadata.tool_starts)
return prompt
@ -441,11 +443,13 @@ async def generate_tool_calls(
if gen["stop_str"] in tool_data.tool_call_start:
if "text" in gen:
# non streaming, all generations will have the text they generated
pre_tool_prompt = format_prompt_with_template(data, gen["text"])
pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
elif current_generations is not None:
# streaming, we won't have text in the generation,
# we'll have to use the current_generations
pre_tool_prompt = format_prompt_with_template(data, current_generations)
pre_tool_prompt = await format_prompt_with_template(
data, current_generations
)
gen_tasks.append(
asyncio.create_task(