* returning stop str if exists from gen * added chat template for firefunctionv2 * pulling tool vars from template * adding parsing for tool inputs/outputs * passing tool data from endpoint to chat template, adding tool_start to the stop list * loosened typing on the response tool call, leaning more on the user supplying a quality schema if they want a particular format * non streaming generation prototype * cleaning template * Continued work with type, ingestion into template, and chat template for fire func * Correction - streaming toolcall comes back as delta obj not inside chatcomprespchoice per chat_completion_chunk.py inside OAI lib. * Ruff Formatting * Moved stop string and tool updates out of prompt creation func Updated tool pydantic to match OAI Support for streaming Updated generate tool calls to use flag within chat_template and insert tool reminder * Llama 3.1 chat templates Updated fire func template * renamed llama3.1 to chatml_with_headers.. * update name of template * Support for calling a tool start token rather than the string. Simplified tool_params Warning when gen_settings are being overridden because user set temp to 0 Corrected schema and tools to correct types for function args. Str for some reason * draft groq tool use model template * changed headers to vars for readability (but mostly because some models are weird about newlines after headers, so this is an easier way to change globally) * Clean up comments and code in chat comp * Post processed tool call to meet OAI spec rather than forcing model to write json in a string in the middle of the call.
* changes example back to args as json rather than string of json * Standardize chat templates to each other * cleaning/rewording * stop elements can also be ints (tokens) * Cleaning/formatting * added special tokens for tools and tool_response as specified in description * Cleaning * removing aux templates - going to live in llm-prompt-templates repo instead * Tree: Format Signed-off-by: kingbri <bdashore3@proton.me> * Chat Completions: Don't include internal tool variables in OpenAPI Use SkipJsonSchema to suppress inclusion in the OpenAPI JSON. The location of these variables may need to be changed in the future. Signed-off-by: kingbri <bdashore3@proton.me> * Templates: Deserialize metadata on template load Since we're only looking for specific template variables that are static in the template, it makes more sense to render when the template is initialized. Signed-off-by: kingbri <bdashore3@proton.me> * Tools: Fix comments Adhere to the format style of comments in the rest of the project. Signed-off-by: kingbri <bdashore3@proton.me> --------- Co-authored-by: Ben Gitter <gitterbd@gmail.com> Signed-off-by: kingbri <bdashore3@proton.me>
180 lines
6.3 KiB
Python
180 lines
6.3 KiB
Python
"""Small replication of AutoTokenizer's chat template system for efficiency"""
|
|
|
|
import json
import pathlib
from importlib.metadata import version as package_version
from typing import Iterable, List, Optional, Union

from jinja2 import Template, TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from loguru import logger
from packaging import version

from common.utils import unwrap
|
|
|
|
|
|
class TemplateLoadError(Exception):
    """Signals that a prompt template could not be located or loaded."""
|
|
|
|
|
|
class TemplateMetadata:
    """Represents the parsed metadata from a template.

    Every instance owns its own lists. Previously the lists were class-level
    defaults shared by all instances, so in-place updates (``+=`` /
    ``append``) leaked metadata from one template into every other.
    """

    # Stop values collected from the template's ``stop_strings`` variable.
    stop_strings: List[str]
    # Tool-call start markers: a string (``tool_start``) or an int token id
    # (``tool_start_token``).
    tool_starts: List[Union[str, int]]

    def __init__(self):
        """Create empty, per-instance metadata lists."""

        self.stop_strings = []
        self.tool_starts = []
|
|
|
|
|
|
class PromptTemplate:
    """A template for chat completion prompts."""

    name: str
    raw_template: str
    template: Template
    # Class attribute: a single sandboxed environment shared by all instances.
    environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
        trim_blocks=True, lstrip_blocks=True
    )
    metadata: TemplateMetadata

    def extract_metadata(self):
        """Returns deserialized template metadata from a chat template.

        Reads the optional top-level template variables ``stop_strings``
        (must be a list), ``tool_start`` (must be a str) and
        ``tool_start_token`` (must be an int). Values of the wrong type are
        skipped with a warning.

        The lists are built locally and then assigned onto the returned
        ``TemplateMetadata``, so this never mutates list objects that could be
        shared with other templates.
        """

        stop_strings = []
        tool_starts = []

        template_module = self.template.make_module()

        if hasattr(template_module, "stop_strings"):
            if isinstance(template_module.stop_strings, list):
                stop_strings.extend(template_module.stop_strings)
            else:
                logger.warning(
                    "Skipping append of stopping strings from chat template "
                    "because stop_strings isn't a list."
                )

        if hasattr(template_module, "tool_start"):
            if isinstance(template_module.tool_start, str):
                tool_starts.append(template_module.tool_start)
            else:
                logger.warning(
                    "Skipping tool_start from chat template "
                    "because it isn't a string."
                )

        if hasattr(template_module, "tool_start_token"):
            if isinstance(template_module.tool_start_token, int):
                tool_starts.append(template_module.tool_start_token)
            else:
                logger.warning(
                    "Skipping tool_start_token from chat template "
                    "because it isn't an integer."
                )

        template_metadata = TemplateMetadata()
        # Assign (not mutate) so no shared list object is ever modified.
        template_metadata.stop_strings = stop_strings
        template_metadata.tool_starts = tool_starts

        return template_metadata

    def render(self, template_vars: dict):
        """Get a prompt from a template and a list of messages.

        Args:
            template_vars: Keyword variables passed to the jinja2 template.

        Returns:
            The rendered prompt string.

        Raises:
            ImportError: If the installed jinja2 is older than 3.0.0.
        """

        # Look the version up once; it is used both in the check and the message.
        jinja2_version = package_version("jinja2")
        if version.parse(jinja2_version) < version.parse("3.0.0"):
            raise ImportError(
                "Parsing these chat completion messages requires jinja2 3.0.0 "
                f"or greater. Current version: {jinja2_version}\n"
                "Please update jinja by running the following command: "
                "pip install --upgrade jinja2"
            )

        return self.template.render(**template_vars)

    def compile(self, template_str: str):
        """Compiles a jinja2 template from a string and returns it.

        Also registers a ``raise_exception`` global, so templates can abort
        rendering by raising a ``TemplateError`` with a custom message.
        """

        # Exception handler exposed to templates
        def raise_exception(message):
            raise TemplateError(message)

        # NOTE: ``environment`` is a class attribute, so this global is shared
        # by every PromptTemplate; re-registering it on each compile is harmless.
        self.environment.globals["raise_exception"] = raise_exception

        return self.environment.from_string(template_str)

    def __init__(self, name: str, raw_template: str):
        """Initializer for the PromptTemplate class.

        Compiles ``raw_template`` and extracts its metadata eagerly, so both
        are available as soon as the object exists.
        """

        self.name = name
        self.raw_template = raw_template
        self.template = self.compile(raw_template)
        self.metadata = self.extract_metadata()

    @classmethod
    def from_file(cls, prompt_template_name: str):
        """Get a template from a jinja file.

        Loads ``templates/<prompt_template_name>.jinja``, resolved relative
        to the current working directory.

        Raises:
            TemplateLoadError: If that path is not an existing regular file.
        """

        template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")
        if not template_path.is_file():
            # Let the user know if the template file isn't found
            raise TemplateLoadError(
                f'Chat template "{prompt_template_name}" not found in files.'
            )

        with open(template_path, "r", encoding="utf8") as raw_template_stream:
            return cls(
                name=prompt_template_name,
                raw_template=raw_template_stream.read(),
            )

    @classmethod
    def from_model_json(
        cls, json_path: pathlib.Path, key: str, name: Optional[str] = None
    ):
        """Get a template from a JSON file. Requires a key and template name.

        Args:
            json_path: Path to a JSON config, e.g. a tokenizer config.
            key: Top-level key holding the chat template.
            name: For list-style templates, the ``name`` of the entry to
                select. If omitted, the first entry is used. It is ignored
                when the value is a plain template string.

        Raises:
            TemplateLoadError: If the file, the key, or the named template
                entry is missing.
        """

        if not json_path.exists():
            raise TemplateLoadError(f'Model JSON path "{json_path}" not found.')

        with open(json_path, "r", encoding="utf8") as config_file:
            model_config = json.load(config_file)
            chat_template = model_config.get(key)

        if not chat_template:
            raise TemplateLoadError(
                "Could not find a value from chat_template key in the passed JSON. "
                "Check the tokenizer config?"
            )

        if isinstance(chat_template, list):
            # Handles the new list style of chat templates
            if name:
                wrapped_template = next(
                    (x for x in chat_template if x.get("name") == name),
                    {},
                )
            else:
                wrapped_template = chat_template[0]
                name = unwrap(wrapped_template.get("name"), "from_tokenizer_config")

            selected_template = wrapped_template.get("template")

            if selected_template:
                return cls(name=name, raw_template=selected_template)
            else:
                raise TemplateLoadError(
                    f'Chat template with name "{name}" not found '
                    "in model templates list."
                )
        else:
            # Can safely assume the chat template is the old style
            return cls(
                name="from_tokenizer_config",
                raw_template=chat_template,
            )
|
|
|
|
|
|
def get_all_templates():
    """Return an iterator over every ``*.jinja`` file in ``templates/``.

    The ``templates`` directory is resolved relative to the current working
    directory. The glob is lazy, so callers receive an iterator, not a list.
    """

    return pathlib.Path("templates").glob("*.jinja")
|
|
|
|
|
|
def find_template_from_model(
    model_path: pathlib.Path,
    template_files: Optional[Iterable[pathlib.Path]] = None,
):
    """Find a matching template name from a model path.

    A template matches when its lowercased file stem occurs anywhere in the
    lowercased final component of ``model_path``. Every candidate is checked
    before giving up. When several match, the longest (most specific) stem
    wins, with ties broken alphabetically, so the result does not depend on
    filesystem listing order.

    Args:
        model_path: Path to the model; only its last component is inspected.
        template_files: Optional template paths to search. Defaults to
            ``get_all_templates()``.

    Returns:
        The lowercased stem of the best-matching template.

    Raises:
        TemplateLoadError: If no template name occurs in the model name.
    """

    model_name = model_path.name.lower()
    if template_files is None:
        template_files = get_all_templates()

    # Check if the template name is present in the model name
    matches = [
        filepath.stem.lower()
        for filepath in template_files
        if filepath.stem.lower() in model_name
    ]

    if not matches:
        raise TemplateLoadError("Could not find template from model name.")

    # Prefer the most specific name, e.g. "chatml_with_headers" over "chatml".
    return min(matches, key=lambda template_name: (-len(template_name), template_name))
|