From a14abfe21cf7586f0dbee9555984f794dad38fcd Mon Sep 17 00:00:00 2001 From: kingbri Date: Fri, 22 Dec 2023 10:31:50 -0500 Subject: [PATCH] Templates: Support bos_token and eos_token fields These are commonly seen in huggingface provided chat templates and aren't that difficult to add in. For feature parity, honor the add_bos_token and ban_eos_token parameters when constructing the prompt. Signed-off-by: kingbri --- main.py | 7 ++++++- model.py | 8 ++++++++ templating.py | 9 +++++++-- 3 files changed, 21 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 1228404..e39ed1e 100644 --- a/main.py +++ b/main.py @@ -324,10 +324,15 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest prompt = data.messages else: try: + special_tokens_dict = model_container.get_special_tokens( + unwrap(data.add_bos_token, True), + unwrap(data.ban_eos_token, False) + ) prompt = get_prompt_from_template( data.messages, model_container.prompt_template, - data.add_generation_prompt + data.add_generation_prompt, + special_tokens_dict, ) except KeyError: return HTTPException( diff --git a/model.py b/model.py index 41d9889..767e9ad 100644 --- a/model.py +++ b/model.py @@ -341,6 +341,14 @@ class ModelContainer: ids = torch.tensor([ids]) return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0] + def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool): + return { + "bos_token": self.tokenizer.bos_token if add_bos_token else "", + "eos_token": self.tokenizer.eos_token if not ban_eos_token else "", + "pad_token": self.tokenizer.pad_token, + "unk_token": self.tokenizer.unk_token, + } + def generate(self, prompt: str, **kwargs): generation = list(self.generate_gen(prompt, **kwargs)) if generation: diff --git a/templating.py b/templating.py index 86bcee8..c7fde76 100644 --- a/templating.py +++ b/templating.py @@ -5,6 +5,7 @@ from importlib.metadata import version as package_version from 
jinja2.sandbox import ImmutableSandboxedEnvironment from packaging import version from pydantic import BaseModel +from typing import Optional, Dict # Small replication of AutoTokenizer's chat template system for efficiency @@ -12,7 +13,10 @@ class PromptTemplate(BaseModel): name: str template: str -def get_prompt_from_template(messages, prompt_template: PromptTemplate, add_generation_prompt: bool): +def get_prompt_from_template(messages, + prompt_template: PromptTemplate, + add_generation_prompt: bool, + special_tokens: Optional[Dict[str, str]] = None): if version.parse(package_version("jinja2")) < version.parse("3.0.0"): raise ImportError( "Parsing these chat completion messages requires jinja2 3.0.0 or greater. " @@ -24,7 +28,8 @@ def get_prompt_from_template(messages, prompt_template: PromptTemplate, add_gene compiled_template = _compile_template(prompt_template.template) return compiled_template.render( messages = messages, - add_generation_prompt = add_generation_prompt + add_generation_prompt = add_generation_prompt, + **(special_tokens or {}), ) # Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761