Templates: Add stop_strings meta param

Adding the stop_strings var to chat templates will allow for the
template creator to specify stopping strings to add onto chat completions.

Thes get appended with existing stopping strings that are passed
in the API request. However, a sampler override with force: true will
override all stopping strings.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-03-27 22:07:43 -04:00
parent 277c540c98
commit dc456f4cc2
2 changed files with 29 additions and 3 deletions

View file

@ -4,8 +4,9 @@ import json
import pathlib
from functools import lru_cache
from importlib.metadata import version as package_version
from jinja2 import TemplateError
from jinja2 import Template, TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from loguru import logger
from packaging import version
from pydantic import BaseModel
from typing import Optional, Dict
@ -34,15 +35,19 @@ def get_prompt_from_template(
)
compiled_template = _compile_template(prompt_template.template)
return compiled_template.render(
rendered_template = compiled_template.render(
messages=messages,
add_generation_prompt=add_generation_prompt,
**special_tokens,
)
template_stop_strings = _get_template_stop_strings(compiled_template)
return rendered_template, template_stop_strings
# Inspired from
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
# TODO: Migrate to compile when template is loaded (removes the need for an lru_cache)
@lru_cache
def _compile_template(template: str):
"""Compiles a Jinja2 template"""
@ -58,6 +63,24 @@ def _compile_template(template: str):
return jinja_template
# TODO: Migrate to run during template load
def _get_template_stop_strings(prompt_template: Template):
"""Appends extra stop strings if present in a chat template."""
extra_stop_strings = []
if hasattr(prompt_template.module, "stop_strings"):
if isinstance(prompt_template.module.stop_strings, list):
extra_stop_strings += prompt_template.module.stop_strings
else:
logger.warning(
"Skipping append of stopping strings from chat template "
"because stop_strings isn't a list."
)
return extra_stop_strings
def get_all_templates():
"""Fetches all templates from the templates directory"""

View file

@ -508,7 +508,10 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
if isinstance(data.messages, str):
prompt = data.messages
else:
prompt = format_prompt_with_template(data)
# Compile the prompt and get any additional stop strings from the template
# Template stop strings can be overriden by sampler overrides if force is true
prompt, template_stop_strings = format_prompt_with_template(data)
data.stop += template_stop_strings
disable_request_streaming = unwrap(
config.developer_config().get("disable_request_streaming"), False