Templates: Switch to Jinja2

Jinja2 is a lightweight template parser that's used in Transformers
for parsing chat completions. It's much more efficient than Fastchat
and can be imported as part of requirements.

Also allows for unblocking Pydantic's version.

Users now have to provide their own template if needed. A separate
repo may be usable for common prompt template storage.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2023-12-17 00:41:42 -05:00 committed by Brian Dashore
parent 95fd0f075e
commit f631dd6ff7
14 changed files with 115 additions and 74 deletions

View file

@ -1,7 +1,7 @@
from uuid import uuid4
from time import time
from pydantic import BaseModel, Field
from typing import Union, List, Optional
from typing import Union, List, Optional, Dict
from OAI.types.common import UsageStats, CommonCompletionRequest
class ChatCompletionMessage(BaseModel):
@ -24,7 +24,7 @@ class ChatCompletionStreamChoice(BaseModel):
class ChatCompletionRequest(CommonCompletionRequest):
# Messages
# Take in a string as well even though it's not part of the OAI spec
messages: Union[str, List[ChatCompletionMessage]]
messages: Union[str, List[Dict[str, str]]]
prompt_template: Optional[str] = None
class ChatCompletionResponse(BaseModel):

View file

@ -10,18 +10,9 @@ from OAI.types.chat_completion import (
from OAI.types.common import UsageStats
from OAI.types.lora import LoraList, LoraCard
from OAI.types.model import ModelList, ModelCard
from packaging import version
from typing import Optional, List
from utils import unwrap
from typing import Optional
# Check fastchat
try:
import fastchat
from fastchat.model.model_adapter import get_conversation_template, get_conv_template
from fastchat.conversation import SeparatorStyle
_fastchat_available = True
except ImportError:
_fastchat_available = False
from utils import unwrap
def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
choice = CompletionRespChoice(
@ -110,45 +101,3 @@ def get_lora_list(lora_path: pathlib.Path):
lora_list.data.append(lora_card)
return lora_list
def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage], prompt_template: Optional[str] = None):
# TODO: Replace fastchat with in-house jinja templates
# Check if fastchat is available
if not _fastchat_available:
raise ModuleNotFoundError(
"Fastchat must be installed to parse these chat completion messages.\n"
"Please run the following command: pip install fschat[model_worker]"
)
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
raise ImportError(
"Parsing these chat completion messages requires fastchat 0.2.23 or greater. "
f"Current version: {fastchat.__version__}\n"
"Please upgrade fastchat by running the following command: "
"pip install -U fschat[model_worker]"
)
if prompt_template:
conv = get_conv_template(prompt_template)
else:
conv = get_conversation_template(model_path)
if conv.sep_style is None:
conv.sep_style = SeparatorStyle.LLAMA2
for message in messages:
msg_role = message.role
if msg_role == "system":
conv.set_system_message(message.content)
elif msg_role == "user":
conv.append_message(conv.roles[0], message.content)
elif msg_role == "assistant":
conv.append_message(conv.roles[1], message.content)
else:
raise ValueError(f"Unknown role: {msg_role}")
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
print(prompt)
return prompt

View file

@ -54,8 +54,6 @@ NOTE: For Flash Attention 2 to work on Windows, CUDA 12.x **must** be installed!
3. ROCm 5.6: `pip install -r requirements-amd.txt`
5. If you want the `/v1/chat/completions` endpoint to work with a list of messages, install fastchat by running `pip install fschat[model_worker]`
## Configuration
A config.yml file is required for overriding project defaults. If you are okay with the defaults, you don't need a config file!
@ -126,6 +124,12 @@ All routes require an API key except for the following which require an **admin*
- `/v1/model/unload`
## Chat Completions
`/v1/chat/completions` now uses Jinja2 for templating. Please read [Huggingface's documentation](https://huggingface.co/docs/transformers/main/chat_templating) for more information of how chat templates work.
Also make sure to set the template name in `config.yml` to the template's filename.
## Common Issues
- AMD cards will error out with flash attention installed, even if the config option is set to False. Run `pip uninstall flash_attn` to remove the wheel from your system.

View file

@ -56,9 +56,9 @@ model:
# Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
cache_mode: FP16
# Set the prompt template for this model. If empty, fastchat will automatically find the best template to use (default: None)
# Set the prompt template for this model. If empty, chat completions will be disabled. (default: alpaca)
# NOTE: Only works with chat completion message lists!
prompt_template:
prompt_template: alpaca
# Number of experts to use per token. Loads from the model's config.json if not specified (default: None)
# WARNING: Don't set this unless you know what you're doing!

19
main.py
View file

@ -27,10 +27,10 @@ from OAI.utils import (
create_completion_response,
get_model_list,
get_lora_list,
get_chat_completion_prompt,
create_chat_completion_response,
create_chat_completion_stream_chunk
)
from templating import get_prompt_from_template
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
app = FastAPI()
@ -76,6 +76,7 @@ async def list_models():
@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def get_current_model():
model_name = model_container.get_model_path().name
prompt_template = model_container.prompt_template
model_card = ModelCard(
id = model_name,
parameters = ModelCardParameters(
@ -83,7 +84,7 @@ async def get_current_model():
rope_alpha = model_container.config.scale_alpha_value,
max_seq_len = model_container.config.max_seq_len,
cache_mode = "FP8" if model_container.cache_fp8 else "FP16",
prompt_template = unwrap(model_container.prompt_template, "auto")
prompt_template = prompt_template.name if prompt_template else None
),
logging = gen_logging.config
)
@ -302,19 +303,21 @@ async def generate_completion(request: Request, data: CompletionRequest):
# Chat completions endpoint
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
if model_container.prompt_template is None:
return HTTPException(422, "This endpoint is disabled because a prompt template is not set.")
model_path = model_container.get_model_path()
if isinstance(data.messages, str):
prompt = data.messages
else:
# If the request specified prompt template isn't found, use the one from model container
# Otherwise, let fastchat figure it out
prompt_template = unwrap(data.prompt_template, model_container.prompt_template)
try:
prompt = get_chat_completion_prompt(model_path.name, data.messages, prompt_template)
prompt = get_prompt_from_template(data.messages, model_container.prompt_template)
except KeyError:
return HTTPException(400, f"Could not find a Conversation from prompt template '{prompt_template}'. Check your spelling?")
return HTTPException(
400,
f"Could not find a Conversation from prompt template '{model_container.prompt_template.name}'. Check your spelling?"
)
if data.stream:
const_id = f"chatcmpl-{uuid4().hex}"

View file

@ -17,6 +17,7 @@ from exllamav2.generator import(
from gen_logging import log_generation_params, log_prompt, log_response
from typing import List, Optional, Union
from templating import PromptTemplate
from utils import coalesce, unwrap
# Bytes to reserve on first device when loading with auto split
@ -31,7 +32,7 @@ class ModelContainer:
draft_cache: Optional[ExLlamaV2Cache] = None
tokenizer: Optional[ExLlamaV2Tokenizer] = None
generator: Optional[ExLlamaV2StreamingGenerator] = None
prompt_template: Optional[str] = None
prompt_template: Optional[PromptTemplate] = None
cache_fp8: bool = False
gpu_split_auto: bool = True
@ -103,7 +104,20 @@ class ModelContainer:
"""
# Set prompt template override if provided
self.prompt_template = kwargs.get("prompt_template")
prompt_template_name = kwargs.get("prompt_template")
if prompt_template_name:
try:
with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r") as raw_template:
self.prompt_template = PromptTemplate(
name = prompt_template_name,
template = raw_template.read()
)
except OSError:
print("Chat completions are disabled because the provided prompt template couldn't be found.")
self.prompt_template = None
else:
print("Chat completions are disabled because a provided prompt template couldn't be found.")
self.prompt_template = None
# Set num of experts per token if provided
num_experts_override = kwargs.get("num_experts_per_token")

View file

@ -12,3 +12,4 @@ pydantic < 2,>= 1
PyYAML
progress
uvicorn
jinja2

View file

@ -18,6 +18,7 @@ pydantic < 2,>= 1
PyYAML
progress
uvicorn
jinja2
# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.6/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"

View file

@ -18,6 +18,7 @@ pydantic < 2,>= 1
PyYAML
progress
uvicorn
jinja2
# Flash attention v2

7
templates/README.md Normal file
View file

@ -0,0 +1,7 @@
# Templates
NOTE: This folder will be replaced by a submodule or something similar in the future
These templates are examples from [Aphrodite Engine](https://github.com/PygmalionAI/aphrodite-engine/tree/main/examples)
Please look at [Huggingface's documentation](https://huggingface.co/docs/transformers/main/chat_templating) for making Jinja2 templates.

29
templates/alpaca.jinja Normal file
View file

@ -0,0 +1,29 @@
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
{% for message in messages %}
{% if message['role'] == 'user' %}
### Instruction:
{{ message['content']|trim -}}
{% if not loop.last %}
{% endif %}
{% elif message['role'] == 'assistant' %}
### Response:
{{ message['content']|trim -}}
{% if not loop.last %}
{% endif %}
{% elif message['role'] == 'user_context' %}
### Input:
{{ message['content']|trim -}}
{% if not loop.last %}
{% endif %}
{% endif %}
{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
### Response:
{% endif %}

2
templates/chatml.jinja Normal file
View file

@ -0,0 +1,2 @@
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}

30
templating.py Normal file
View file

@ -0,0 +1,30 @@
from functools import lru_cache
from importlib.metadata import version as package_version
from packaging import version
from jinja2.sandbox import ImmutableSandboxedEnvironment
from pydantic import BaseModel
# Small replication of AutoTokenizer's chat template system for efficiency
class PromptTemplate(BaseModel):
name: str
template: str
def get_prompt_from_template(messages, prompt_template: PromptTemplate):
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
raise ImportError(
"Parsing these chat completion messages requires fastchat 0.2.23 or greater. "
f"Current version: {version('jinja2')}\n"
"Please upgrade fastchat by running the following command: "
"pip install -U fschat[model_worker]"
)
compiled_template = _compile_template(prompt_template.template)
return compiled_template.render(messages = messages)
# Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
@lru_cache
def _compile_template(template: str):
jinja_env = ImmutableSandboxedEnvironment(trim_blocks = True, lstrip_blocks = True)
jinja_template = jinja_env.from_string(template)
return jinja_template

View file

@ -25,12 +25,12 @@ else:
print("Torch is not found in your environment.")
errored_packages.append("torch")
if find_spec("fastchat") is not None:
print(f"Fastchat on version {version('fschat')} successfully imported")
successful_packages.append("fastchat")
if find_spec("jinja2") is not None:
print(f"Jinja2 on version {version('jinja2')} successfully imported")
successful_packages.append("jinja2")
else:
print("Fastchat is not found in your environment. It isn't needed unless you're using chat completions with message arrays.")
errored_packages.append("fastchat")
print("Jinja2 is not found in your environment.")
errored_packages.append("jinja2")
print(
f"\nSuccessful imports: {', '.join(successful_packages)}",