Templates: Switch to Jinja2
Jinja2 is a lightweight template parser that's used in Transformers for parsing chat completions. It's much more efficient than Fastchat and can be imported as part of requirements. Also allows for unblocking Pydantic's version. Users now have to provide their own template if needed. A separate repo may be usable for common prompt template storage. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
95fd0f075e
commit
f631dd6ff7
14 changed files with 115 additions and 74 deletions
|
|
@ -1,7 +1,7 @@
|
|||
from uuid import uuid4
|
||||
from time import time
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Union, List, Optional
|
||||
from typing import Union, List, Optional, Dict
|
||||
from OAI.types.common import UsageStats, CommonCompletionRequest
|
||||
|
||||
class ChatCompletionMessage(BaseModel):
|
||||
|
|
@ -24,7 +24,7 @@ class ChatCompletionStreamChoice(BaseModel):
|
|||
class ChatCompletionRequest(CommonCompletionRequest):
|
||||
# Messages
|
||||
# Take in a string as well even though it's not part of the OAI spec
|
||||
messages: Union[str, List[ChatCompletionMessage]]
|
||||
messages: Union[str, List[Dict[str, str]]]
|
||||
prompt_template: Optional[str] = None
|
||||
|
||||
class ChatCompletionResponse(BaseModel):
|
||||
|
|
|
|||
55
OAI/utils.py
55
OAI/utils.py
|
|
@ -10,18 +10,9 @@ from OAI.types.chat_completion import (
|
|||
from OAI.types.common import UsageStats
|
||||
from OAI.types.lora import LoraList, LoraCard
|
||||
from OAI.types.model import ModelList, ModelCard
|
||||
from packaging import version
|
||||
from typing import Optional, List
|
||||
from utils import unwrap
|
||||
from typing import Optional
|
||||
|
||||
# Check fastchat
|
||||
try:
|
||||
import fastchat
|
||||
from fastchat.model.model_adapter import get_conversation_template, get_conv_template
|
||||
from fastchat.conversation import SeparatorStyle
|
||||
_fastchat_available = True
|
||||
except ImportError:
|
||||
_fastchat_available = False
|
||||
from utils import unwrap
|
||||
|
||||
def create_completion_response(text: str, prompt_tokens: int, completion_tokens: int, model_name: Optional[str]):
|
||||
choice = CompletionRespChoice(
|
||||
|
|
@ -110,45 +101,3 @@ def get_lora_list(lora_path: pathlib.Path):
|
|||
lora_list.data.append(lora_card)
|
||||
|
||||
return lora_list
|
||||
|
||||
def get_chat_completion_prompt(model_path: str, messages: List[ChatCompletionMessage], prompt_template: Optional[str] = None):
|
||||
|
||||
# TODO: Replace fastchat with in-house jinja templates
|
||||
# Check if fastchat is available
|
||||
if not _fastchat_available:
|
||||
raise ModuleNotFoundError(
|
||||
"Fastchat must be installed to parse these chat completion messages.\n"
|
||||
"Please run the following command: pip install fschat[model_worker]"
|
||||
)
|
||||
if version.parse(fastchat.__version__) < version.parse("0.2.23"):
|
||||
raise ImportError(
|
||||
"Parsing these chat completion messages requires fastchat 0.2.23 or greater. "
|
||||
f"Current version: {fastchat.__version__}\n"
|
||||
"Please upgrade fastchat by running the following command: "
|
||||
"pip install -U fschat[model_worker]"
|
||||
)
|
||||
|
||||
if prompt_template:
|
||||
conv = get_conv_template(prompt_template)
|
||||
else:
|
||||
conv = get_conversation_template(model_path)
|
||||
|
||||
if conv.sep_style is None:
|
||||
conv.sep_style = SeparatorStyle.LLAMA2
|
||||
|
||||
for message in messages:
|
||||
msg_role = message.role
|
||||
if msg_role == "system":
|
||||
conv.set_system_message(message.content)
|
||||
elif msg_role == "user":
|
||||
conv.append_message(conv.roles[0], message.content)
|
||||
elif msg_role == "assistant":
|
||||
conv.append_message(conv.roles[1], message.content)
|
||||
else:
|
||||
raise ValueError(f"Unknown role: {msg_role}")
|
||||
|
||||
conv.append_message(conv.roles[1], None)
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
print(prompt)
|
||||
return prompt
|
||||
|
|
|
|||
|
|
@ -54,8 +54,6 @@ NOTE: For Flash Attention 2 to work on Windows, CUDA 12.x **must** be installed!
|
|||
|
||||
3. ROCm 5.6: `pip install -r requirements-amd.txt`
|
||||
|
||||
5. If you want the `/v1/chat/completions` endpoint to work with a list of messages, install fastchat by running `pip install fschat[model_worker]`
|
||||
|
||||
## Configuration
|
||||
|
||||
A config.yml file is required for overriding project defaults. If you are okay with the defaults, you don't need a config file!
|
||||
|
|
@ -126,6 +124,12 @@ All routes require an API key except for the following which require an **admin*
|
|||
|
||||
- `/v1/model/unload`
|
||||
|
||||
## Chat Completions
|
||||
|
||||
`/v1/chat/completions` now uses Jinja2 for templating. Please read [Huggingface's documentation](https://huggingface.co/docs/transformers/main/chat_templating) for more information of how chat templates work.
|
||||
|
||||
Also make sure to set the template name in `config.yml` to the template's filename.
|
||||
|
||||
## Common Issues
|
||||
|
||||
- AMD cards will error out with flash attention installed, even if the config option is set to False. Run `pip uninstall flash_attn` to remove the wheel from your system.
|
||||
|
|
|
|||
|
|
@ -56,9 +56,9 @@ model:
|
|||
# Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
|
||||
cache_mode: FP16
|
||||
|
||||
# Set the prompt template for this model. If empty, fastchat will automatically find the best template to use (default: None)
|
||||
# Set the prompt template for this model. If empty, chat completions will be disabled. (default: alpaca)
|
||||
# NOTE: Only works with chat completion message lists!
|
||||
prompt_template:
|
||||
prompt_template: alpaca
|
||||
|
||||
# Number of experts to use per token. Loads from the model's config.json if not specified (default: None)
|
||||
# WARNING: Don't set this unless you know what you're doing!
|
||||
|
|
|
|||
19
main.py
19
main.py
|
|
@ -27,10 +27,10 @@ from OAI.utils import (
|
|||
create_completion_response,
|
||||
get_model_list,
|
||||
get_lora_list,
|
||||
get_chat_completion_prompt,
|
||||
create_chat_completion_response,
|
||||
create_chat_completion_stream_chunk
|
||||
)
|
||||
from templating import get_prompt_from_template
|
||||
from utils import get_generator_error, get_sse_packet, load_progress, unwrap
|
||||
|
||||
app = FastAPI()
|
||||
|
|
@ -76,6 +76,7 @@ async def list_models():
|
|||
@app.get("/v1/internal/model/info", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
async def get_current_model():
|
||||
model_name = model_container.get_model_path().name
|
||||
prompt_template = model_container.prompt_template
|
||||
model_card = ModelCard(
|
||||
id = model_name,
|
||||
parameters = ModelCardParameters(
|
||||
|
|
@ -83,7 +84,7 @@ async def get_current_model():
|
|||
rope_alpha = model_container.config.scale_alpha_value,
|
||||
max_seq_len = model_container.config.max_seq_len,
|
||||
cache_mode = "FP8" if model_container.cache_fp8 else "FP16",
|
||||
prompt_template = unwrap(model_container.prompt_template, "auto")
|
||||
prompt_template = prompt_template.name if prompt_template else None
|
||||
),
|
||||
logging = gen_logging.config
|
||||
)
|
||||
|
|
@ -302,19 +303,21 @@ async def generate_completion(request: Request, data: CompletionRequest):
|
|||
# Chat completions endpoint
|
||||
@app.post("/v1/chat/completions", dependencies=[Depends(check_api_key), Depends(_check_model_container)])
|
||||
async def generate_chat_completion(request: Request, data: ChatCompletionRequest):
|
||||
if model_container.prompt_template is None:
|
||||
return HTTPException(422, "This endpoint is disabled because a prompt template is not set.")
|
||||
|
||||
model_path = model_container.get_model_path()
|
||||
|
||||
if isinstance(data.messages, str):
|
||||
prompt = data.messages
|
||||
else:
|
||||
# If the request specified prompt template isn't found, use the one from model container
|
||||
# Otherwise, let fastchat figure it out
|
||||
prompt_template = unwrap(data.prompt_template, model_container.prompt_template)
|
||||
|
||||
try:
|
||||
prompt = get_chat_completion_prompt(model_path.name, data.messages, prompt_template)
|
||||
prompt = get_prompt_from_template(data.messages, model_container.prompt_template)
|
||||
except KeyError:
|
||||
return HTTPException(400, f"Could not find a Conversation from prompt template '{prompt_template}'. Check your spelling?")
|
||||
return HTTPException(
|
||||
400,
|
||||
f"Could not find a Conversation from prompt template '{model_container.prompt_template.name}'. Check your spelling?"
|
||||
)
|
||||
|
||||
if data.stream:
|
||||
const_id = f"chatcmpl-{uuid4().hex}"
|
||||
|
|
|
|||
18
model.py
18
model.py
|
|
@ -17,6 +17,7 @@ from exllamav2.generator import(
|
|||
|
||||
from gen_logging import log_generation_params, log_prompt, log_response
|
||||
from typing import List, Optional, Union
|
||||
from templating import PromptTemplate
|
||||
from utils import coalesce, unwrap
|
||||
|
||||
# Bytes to reserve on first device when loading with auto split
|
||||
|
|
@ -31,7 +32,7 @@ class ModelContainer:
|
|||
draft_cache: Optional[ExLlamaV2Cache] = None
|
||||
tokenizer: Optional[ExLlamaV2Tokenizer] = None
|
||||
generator: Optional[ExLlamaV2StreamingGenerator] = None
|
||||
prompt_template: Optional[str] = None
|
||||
prompt_template: Optional[PromptTemplate] = None
|
||||
|
||||
cache_fp8: bool = False
|
||||
gpu_split_auto: bool = True
|
||||
|
|
@ -103,7 +104,20 @@ class ModelContainer:
|
|||
"""
|
||||
|
||||
# Set prompt template override if provided
|
||||
self.prompt_template = kwargs.get("prompt_template")
|
||||
prompt_template_name = kwargs.get("prompt_template")
|
||||
if prompt_template_name:
|
||||
try:
|
||||
with open(pathlib.Path(f"templates/{prompt_template_name}.jinja"), "r") as raw_template:
|
||||
self.prompt_template = PromptTemplate(
|
||||
name = prompt_template_name,
|
||||
template = raw_template.read()
|
||||
)
|
||||
except OSError:
|
||||
print("Chat completions are disabled because the provided prompt template couldn't be found.")
|
||||
self.prompt_template = None
|
||||
else:
|
||||
print("Chat completions are disabled because a provided prompt template couldn't be found.")
|
||||
self.prompt_template = None
|
||||
|
||||
# Set num of experts per token if provided
|
||||
num_experts_override = kwargs.get("num_experts_per_token")
|
||||
|
|
|
|||
|
|
@ -12,3 +12,4 @@ pydantic < 2,>= 1
|
|||
PyYAML
|
||||
progress
|
||||
uvicorn
|
||||
jinja2
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ pydantic < 2,>= 1
|
|||
PyYAML
|
||||
progress
|
||||
uvicorn
|
||||
jinja2
|
||||
|
||||
# Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
|
||||
https://github.com/Dao-AILab/flash-attention/releases/download/v2.3.6/flash_attn-2.3.6+cu118torch2.1cxx11abiFALSE-cp310-cp310-linux_x86_64.whl; platform_system == "Linux" and platform_machine == "x86_64" and python_version == "3.10"
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ pydantic < 2,>= 1
|
|||
PyYAML
|
||||
progress
|
||||
uvicorn
|
||||
jinja2
|
||||
|
||||
# Flash attention v2
|
||||
|
||||
|
|
|
|||
7
templates/README.md
Normal file
7
templates/README.md
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Templates
|
||||
|
||||
NOTE: This folder will be replaced by a submodule or something similar in the future
|
||||
|
||||
These templates are examples from [Aphrodite Engine](https://github.com/PygmalionAI/aphrodite-engine/tree/main/examples)
|
||||
|
||||
Please look at [Huggingface's documentation](https://huggingface.co/docs/transformers/main/chat_templating) for making Jinja2 templates.
|
||||
29
templates/alpaca.jinja
Normal file
29
templates/alpaca.jinja
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
{{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }}
|
||||
|
||||
{% for message in messages %}
|
||||
{% if message['role'] == 'user' %}
|
||||
### Instruction:
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'assistant' %}
|
||||
### Response:
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
|
||||
{% endif %}
|
||||
{% elif message['role'] == 'user_context' %}
|
||||
### Input:
|
||||
{{ message['content']|trim -}}
|
||||
{% if not loop.last %}
|
||||
|
||||
|
||||
{% endif %}
|
||||
{% endif %}
|
||||
{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}
|
||||
### Response:
|
||||
{% endif %}
|
||||
2
templates/chatml.jinja
Normal file
2
templates/chatml.jinja
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if (loop.last and add_generation_prompt) or not loop.last %}{{ '<|im_end|>' + '\n'}}{% endif %}{% endfor %}
|
||||
{% if add_generation_prompt and messages[-1]['role'] != 'assistant' %}{{ '<|im_start|>assistant\n' }}{% endif %}
|
||||
30
templating.py
Normal file
30
templating.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from functools import lru_cache
|
||||
from importlib.metadata import version as package_version
|
||||
from packaging import version
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Small replication of AutoTokenizer's chat template system for efficiency
|
||||
|
||||
class PromptTemplate(BaseModel):
|
||||
name: str
|
||||
template: str
|
||||
|
||||
def get_prompt_from_template(messages, prompt_template: PromptTemplate):
|
||||
if version.parse(package_version("jinja2")) < version.parse("3.0.0"):
|
||||
raise ImportError(
|
||||
"Parsing these chat completion messages requires fastchat 0.2.23 or greater. "
|
||||
f"Current version: {version('jinja2')}\n"
|
||||
"Please upgrade fastchat by running the following command: "
|
||||
"pip install -U fschat[model_worker]"
|
||||
)
|
||||
|
||||
compiled_template = _compile_template(prompt_template.template)
|
||||
return compiled_template.render(messages = messages)
|
||||
|
||||
# Inspired from https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L1761
|
||||
@lru_cache
|
||||
def _compile_template(template: str):
|
||||
jinja_env = ImmutableSandboxedEnvironment(trim_blocks = True, lstrip_blocks = True)
|
||||
jinja_template = jinja_env.from_string(template)
|
||||
return jinja_template
|
||||
|
|
@ -25,12 +25,12 @@ else:
|
|||
print("Torch is not found in your environment.")
|
||||
errored_packages.append("torch")
|
||||
|
||||
if find_spec("fastchat") is not None:
|
||||
print(f"Fastchat on version {version('fschat')} successfully imported")
|
||||
successful_packages.append("fastchat")
|
||||
if find_spec("jinja2") is not None:
|
||||
print(f"Jinja2 on version {version('jinja2')} successfully imported")
|
||||
successful_packages.append("jinja2")
|
||||
else:
|
||||
print("Fastchat is not found in your environment. It isn't needed unless you're using chat completions with message arrays.")
|
||||
errored_packages.append("fastchat")
|
||||
print("Jinja2 is not found in your environment.")
|
||||
errored_packages.append("jinja2")
|
||||
|
||||
print(
|
||||
f"\nSuccessful imports: {', '.join(successful_packages)}",
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue