Templates: Support bos_token and eos_token fields

These fields are commonly seen in HuggingFace-provided chat templates and aren't difficult to add. For feature parity, honor the add_bos_token and ban_eos_token request parameters when constructing the prompt.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 2bf8087de3
commit a14abfe21c
3 changed files with 21 additions and 3 deletions
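For context on why these fields matter: HuggingFace-distributed chat templates frequently reference bos_token and eos_token as bare Jinja variables, so they must be present in the render context or they silently render as empty strings under Jinja's default undefined handling. The template below is a hypothetical example in that style, not one shipped with this repo:

    # Hypothetical HuggingFace-style chat template; bos_token and
    # eos_token are plain Jinja variables the renderer must supply.
    EXAMPLE_TEMPLATE = (
        "{{ bos_token }}"
        "{% for message in messages %}"
        "{{ message['role'] }}: {{ message['content'] }}{{ eos_token }}"
        "{% endfor %}"
        "{% if add_generation_prompt %}assistant:{% endif %}"
    )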
main.py | 7 ++++++-

@@ -324,10 +324,15 @@ async def generate_chat_completion(request: Request, data: ChatCompletionRequest
             prompt = data.messages
         else:
             try:
+                special_tokens_dict = model_container.get_special_tokens(
+                    unwrap(data.add_bos_token, True),
+                    unwrap(data.ban_eos_token, False)
+                )
                 prompt = get_prompt_from_template(
                     data.messages,
                     model_container.prompt_template,
-                    data.add_generation_prompt
+                    data.add_generation_prompt,
+                    special_tokens_dict,
                 )
             except KeyError:
                 return HTTPException(
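The unwrap calls supply the API-level defaults: add_bos_token defaults to on and ban_eos_token to off. unwrap is the project's existing helper; a minimal sketch of its assumed null-coalescing behavior, which the defaults above rely on:

    def unwrap(value, default):
        # Assumed behavior of the existing helper: fall back to the
        # default only when the request field was omitted (None).
        return value if value is not None else default

    assert unwrap(None, True) is True    # add_bos_token defaults to True
    assert unwrap(False, True) is False  # an explicit False is respected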
model.py | 8 ++++++++

@@ -341,6 +341,14 @@ class ModelContainer:
         ids = torch.tensor([ids])
         return self.tokenizer.decode(ids, decode_special_tokens = unwrap(kwargs.get("decode_special_tokens"), True))[0]
 
+    def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
+        return {
+            "bos_token": self.tokenizer.bos_token if add_bos_token else "",
+            "eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
+            "pad_token": self.tokenizer.pad_token,
+            "unk_token": self.tokenizer.unk_token,
+        }
+
     def generate(self, prompt: str, **kwargs):
         generation = list(self.generate_gen(prompt, **kwargs))
         if generation:
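A quick sketch of what the new method returns, using a stand-in for the tokenizer (only the four attributes read above are stubbed; the real object is the model container's tokenizer):

    from types import SimpleNamespace

    # Stand-in tokenizer with typical Llama-style special tokens.
    tokenizer = SimpleNamespace(
        bos_token="<s>", eos_token="</s>", pad_token="</s>", unk_token="<unk>"
    )

    def get_special_tokens(add_bos_token: bool, ban_eos_token: bool):
        return {
            "bos_token": tokenizer.bos_token if add_bos_token else "",
            "eos_token": tokenizer.eos_token if not ban_eos_token else "",
            "pad_token": tokenizer.pad_token,
            "unk_token": tokenizer.unk_token,
        }

    # Banning EOS blanks the token so templates can't insert it:
    assert get_special_tokens(True, True)["eos_token"] == ""
    assert get_special_tokens(False, False)["bos_token"] == ""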
@@ -5,6 +5,7 @@ from importlib.metadata import version as package_version
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from packaging import version
 from pydantic import BaseModel
+from typing import Optional, Dict
 
 # Small replication of AutoTokenizer's chat template system for efficiency
 
@@ -12,7 +13,10 @@ class PromptTemplate(BaseModel):
     name: str
     template: str
 
-def get_prompt_from_template(messages, prompt_template: PromptTemplate, add_generation_prompt: bool):
+def get_prompt_from_template(messages,
+                             prompt_template: PromptTemplate,
+                             add_generation_prompt: bool,
+                             special_tokens: Optional[Dict[str, str]] = None):
     if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
         raise ImportError(
             "Parsing these chat completion messages requires jinja2 3.0.0 or greater. "
 
@@ -24,7 +28,8 @@ def get_prompt_from_template(messages, prompt_template: PromptTemplate, add_gene
     compiled_template = _compile_template(prompt_template.template)
     return compiled_template.render(
         messages = messages,
-        add_generation_prompt = add_generation_prompt
+        add_generation_prompt = add_generation_prompt,
+        **special_tokens,
     )
 
 # Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
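One interaction worth noting: special_tokens defaults to None, but the render call unpacks it with **special_tokens, and **None raises a TypeError in Python, so the parameter is effectively required; the single caller added in main.py always passes the dict from get_special_tokens. A self-contained sketch of the render step under those assumptions, coalescing defensively:

    from jinja2.sandbox import ImmutableSandboxedEnvironment

    env = ImmutableSandboxedEnvironment()
    template = env.from_string(
        "{{ bos_token }}{% for m in messages %}{{ m['content'] }}{% endfor %}"
    )

    special_tokens = {"bos_token": "<s>", "eos_token": "</s>"}
    prompt = template.render(
        messages=[{"role": "user", "content": "Hello"}],
        add_generation_prompt=True,
        **(special_tokens or {}),  # guard: **None would raise TypeError
    )
    assert prompt == "<s>Hello"

Passing the tokens as render keywords mirrors how the AutoTokenizer chat template system referenced in the comment above exposes special tokens to templates.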