Templates: Add stop_strings meta param
Adding the stop_strings var to chat templates allows the template creator to specify stopping strings to add onto chat completions. These are appended to any existing stopping strings that are passed in the API request. However, a sampler override with force: true will override all stopping strings. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
277c540c98
commit
dc456f4cc2
2 changed files with 29 additions and 3 deletions
|
|
@ -4,8 +4,9 @@ import json
|
|||
import pathlib
|
||||
from functools import lru_cache
|
||||
from importlib.metadata import version as package_version
|
||||
from jinja2 import TemplateError
|
||||
from jinja2 import Template, TemplateError
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
from loguru import logger
|
||||
from packaging import version
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, Dict
|
||||
|
|
@ -34,15 +35,19 @@ def get_prompt_from_template(
|
|||
)
|
||||
|
||||
compiled_template = _compile_template(prompt_template.template)
|
||||
return compiled_template.render(
|
||||
rendered_template = compiled_template.render(
|
||||
messages=messages,
|
||||
add_generation_prompt=add_generation_prompt,
|
||||
**special_tokens,
|
||||
)
|
||||
template_stop_strings = _get_template_stop_strings(compiled_template)
|
||||
|
||||
return rendered_template, template_stop_strings
|
||||
|
||||
|
||||
# Inspired from
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
|
||||
# TODO: Migrate to compile when template is loaded (removes the need for an lru_cache)
|
||||
@lru_cache
|
||||
def _compile_template(template: str):
|
||||
"""Compiles a Jinja2 template"""
|
||||
|
|
@ -58,6 +63,24 @@ def _compile_template(template: str):
|
|||
return jinja_template
|
||||
|
||||
|
||||
# TODO: Migrate to run during template load
|
||||
def _get_template_stop_strings(prompt_template: Template):
    """Collect the stop strings declared by a chat template, if any.

    The compiled template's module is checked for a ``stop_strings``
    variable. When that variable is a list, a new list holding its items
    is returned. When it is missing, or present but not a list, an empty
    list is returned (a warning is logged in the non-list case).
    """

    template_module = prompt_template.module

    # Templates that never define stop_strings contribute nothing.
    if not hasattr(template_module, "stop_strings"):
        return []

    declared = template_module.stop_strings
    if not isinstance(declared, list):
        logger.warning(
            "Skipping append of stopping strings from chat template "
            "because stop_strings isn't a list."
        )
        return []

    # Return a fresh list so callers never share the template's own object.
    return list(declared)
|
||||
|
||||
|
||||
def get_all_templates():
|
||||
"""Fetches all templates from the templates directory"""
|
||||
|
||||
|
|
|
|||
|
|
@ -508,7 +508,10 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
|
|||
if isinstance(data.messages, str):
|
||||
prompt = data.messages
|
||||
else:
|
||||
prompt = format_prompt_with_template(data)
|
||||
# Compile the prompt and get any additional stop strings from the template
|
||||
# Template stop strings can be overridden by sampler overrides if force is true
|
||||
prompt, template_stop_strings = format_prompt_with_template(data)
|
||||
data.stop += template_stop_strings
|
||||
|
||||
disable_request_streaming = unwrap(
|
||||
config.developer_config().get("disable_request_streaming"), False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue