Templates: Fix stop_string parsing

Template modules grab all set vars, including ones that use runtime
vars. If a template var is set to a runtime var and a module is created,
an UndefinedError fires.

Use make_module instead to pass runtime vars when creating a template
module.

Resolves #92

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-04-02 00:44:04 -04:00
parent 6ecce1604b
commit f9f8c97c6d
2 changed files with 22 additions and 28 deletions

View file

@ -9,7 +9,6 @@ from jinja2.sandbox import ImmutableSandboxedEnvironment
from loguru import logger
from packaging import version
from pydantic import BaseModel
from typing import Optional, Dict
class PromptTemplate(BaseModel):
@ -19,12 +18,7 @@ class PromptTemplate(BaseModel):
template: str
def get_prompt_from_template(
messages,
prompt_template: PromptTemplate,
add_generation_prompt: bool,
special_tokens: Optional[Dict[str, str]] = None,
):
def get_prompt_from_template(prompt_template: PromptTemplate, template_vars: dict):
"""Get a prompt from a template and a list of messages."""
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
raise ImportError(
@ -35,12 +29,8 @@ def get_prompt_from_template(
)
compiled_template = _compile_template(prompt_template.template)
rendered_template = compiled_template.render(
messages=messages,
add_generation_prompt=add_generation_prompt,
**special_tokens,
)
template_stop_strings = _get_template_stop_strings(compiled_template)
rendered_template = compiled_template.render(**template_vars)
template_stop_strings = _get_template_stop_strings(compiled_template, template_vars)
return rendered_template, template_stop_strings
@ -64,14 +54,15 @@ def _compile_template(template: str):
# TODO: Migrate to run during template load
def _get_template_stop_strings(prompt_template: Template):
def _get_template_stop_strings(prompt_template: Template, template_vars: dict):
"""Appends extra stop strings if present in a chat template."""
extra_stop_strings = []
template_module = prompt_template.make_module(template_vars)
if hasattr(prompt_template.module, "stop_strings"):
if isinstance(prompt_template.module.stop_strings, list):
extra_stop_strings += prompt_template.module.stop_strings
if hasattr(template_module, "stop_strings"):
if isinstance(template_module.stop_strings, list):
extra_stop_strings += template_module.stop_strings
else:
logger.warning(
"Skipping append of stopping strings from chat template "