* add github workflows for pylint and yapf * yapf * docstrings for auth * fix auth.py * fix generators.py * fix gen_logging.py * fix main.py * fix model.py * fix templating.py * fix utils.py * update formatting.sh to include subdirs for pylint * fix model_test.py * fix wheel_test.py * rename utils to utils_oai * fix OAI/utils_oai.py * fix completion.py * fix token.py * fix lora.py * fix common.py * add pylintrc and fix model.py * finish up pylint * fix attribute error * main.py formatting * add formatting batch script * Main: Remove unnecessary global Linter suggestion. Signed-off-by: kingbri <bdashore3@proton.me> * switch to ruff * Formatting + Linting: Add ruff.toml Signed-off-by: kingbri <bdashore3@proton.me> * Formatting + Linting: Switch scripts to use ruff Also remove the file and recent file change functions from both scripts. Signed-off-by: kingbri <bdashore3@proton.me> * Tree: Format and lint Signed-off-by: kingbri <bdashore3@proton.me> * Scripts + Workflows: Format Signed-off-by: kingbri <bdashore3@proton.me> * Tree: Remove pylint flags We use ruff now Signed-off-by: kingbri <bdashore3@proton.me> * Tree: Format Signed-off-by: kingbri <bdashore3@proton.me> * Formatting: Line length is 88 Use the same value as Black. Signed-off-by: kingbri <bdashore3@proton.me> * Tree: Format Update to new line length rules. Signed-off-by: kingbri <bdashore3@proton.me> --------- Authored-by: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Co-authored-by: kingbri <bdashore3@proton.me>
89 lines
3 KiB
Python
89 lines
3 KiB
Python
"""Small replication of AutoTokenizer's chat template system for efficiency"""
|
|
import json
|
|
import pathlib
|
|
from functools import lru_cache
|
|
from importlib.metadata import version as package_version
|
|
|
|
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
|
from packaging import version
|
|
from pydantic import BaseModel
|
|
from typing import Optional, Dict
|
|
|
|
|
|
class PromptTemplate(BaseModel):
    """A template for chat completion prompts.

    Attributes:
        name: Identifier for the template (for file-based templates this is
            the requested template name).
        template: Raw jinja2 source that is compiled and rendered to build
            the prompt.
    """

    # Identifier used to refer to this template.
    name: str
    # Raw jinja2 template source text.
    template: str
|
|
|
|
|
|
def get_prompt_from_template(
    messages,
    prompt_template: PromptTemplate,
    add_generation_prompt: bool,
    special_tokens: Optional[Dict[str, str]] = None,
):
    """Get a prompt from a template and a list of messages.

    Args:
        messages: Chat messages, exposed to the template as ``messages``.
        prompt_template: Template whose ``template`` source is compiled and
            rendered.
        add_generation_prompt: Exposed to the template as
            ``add_generation_prompt``.
        special_tokens: Optional extra variables passed to the template as
            keyword arguments. ``None`` (the default) means no extra variables.

    Returns:
        The rendered prompt string.

    Raises:
        ImportError: If the installed jinja2 is older than 3.0.0.
    """
    jinja_version = package_version("jinja2")
    if version.parse(jinja_version) < version.parse("3.0.0"):
        raise ImportError(
            "Parsing these chat completion messages requires jinja2 3.0.0 "
            f"or greater. Current version: {jinja_version}\n"
            "Please upgrade jinja by running the following command: "
            "pip install --upgrade jinja2"
        )

    compiled_template = _compile_template(prompt_template.template)

    # special_tokens defaults to None, and ``**None`` raises TypeError, so fall
    # back to an empty mapping when no extra variables were supplied.
    return compiled_template.render(
        messages=messages,
        add_generation_prompt=add_generation_prompt,
        **(special_tokens or {}),
    )
|
|
|
|
|
|
# Inspired by
# https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
@lru_cache
def _compile_template(template: str):
    """Compile a chat template string into a sandboxed jinja2 template.

    Results are memoized with ``lru_cache``, so repeated calls with the same
    template source reuse the compiled template instead of recompiling it.

    To match the Hugging Face reference implementation, the environment
    enables the ``loopcontrols`` extension (``{% break %}`` / ``{% continue %}``)
    and exposes a ``raise_exception`` global, which HF-style chat templates may
    call to reject malformed conversations.

    Args:
        template: Raw jinja2 template source.

    Returns:
        The compiled jinja2 template.
    """
    # Local imports: jinja2 is already a dependency of this module.
    from jinja2.exceptions import TemplateError
    from jinja2.ext import loopcontrols

    def raise_exception(message):
        # Surface template-side validation failures as a regular jinja2
        # error instead of an "undefined" error about a missing helper.
        raise TemplateError(message)

    jinja_env = ImmutableSandboxedEnvironment(
        trim_blocks=True, lstrip_blocks=True, extensions=[loopcontrols]
    )
    jinja_env.globals["raise_exception"] = raise_exception
    return jinja_env.from_string(template)
|
|
|
|
|
|
def find_template_from_model(model_path: pathlib.Path):
    """Find a matching template name from a model path.

    Scans ``templates/*.jinja`` (relative to the current working directory)
    and returns the stem of the first template file whose name appears,
    case-insensitively, in the model directory's name.

    Candidates are checked in sorted order, so the result does not depend on
    filesystem enumeration order. The stem is returned with its original
    casing, so ``templates/<name>.jinja`` still resolves on case-sensitive
    filesystems.

    Args:
        model_path: Path to the model directory; only its final component
            is inspected.

    Returns:
        The matching template name, or ``None`` if no template matches.
    """
    model_name = model_path.name.lower()
    template_directory = pathlib.Path("templates")

    for filepath in sorted(template_directory.glob("*.jinja")):
        # Compare case-insensitively, but return the real on-disk stem.
        if filepath.stem.lower() in model_name:
            return filepath.stem

    return None
|
|
|
|
|
|
def get_template_from_file(prompt_template_name: str):
    """Get a template from a jinja file.

    Looks for ``templates/<prompt_template_name>.jinja`` relative to the
    current working directory.

    Args:
        prompt_template_name: Template file stem; also used as the
            template's ``name``.

    Returns:
        A ``PromptTemplate`` holding the file contents, or ``None`` if no
        such regular file exists.
    """
    template_path = pathlib.Path(f"templates/{prompt_template_name}.jinja")

    # is_file() rather than exists(): a directory with a ".jinja" name passes
    # exists() and would then make open() raise.
    if not template_path.is_file():
        return None

    with open(template_path, "r", encoding="utf8") as raw_template:
        return PromptTemplate(name=prompt_template_name, template=raw_template.read())
|
|
|
|
|
|
# Get a template from a JSON file
|
|
# Requires a key and template name
|
|
def get_template_from_model_json(json_path: pathlib.Path, key: str, name: str):
|
|
"""Get a template from a JSON file. Requires a key and template name"""
|
|
if json_path.exists():
|
|
with open(json_path, "r", encoding="utf8") as config_file:
|
|
model_config = json.load(config_file)
|
|
chat_template = model_config.get(key)
|
|
if chat_template:
|
|
return PromptTemplate(name=name, template=chat_template)
|
|
|
|
return None
|