[WIP] OpenAI Tools Support/Function calling (#154)
* returning stop str if exists from gen
* added chat template for firefunctionv2
* pulling tool vars from template
* adding parsing for tool inputs/outputs
* passing tool data from endpoint to chat template, adding tool_start to the stop list
* loosened typing on the response tool call, leaning more on the user supplying a quality schema if they want a particular format
* non-streaming generation prototype
* cleaning template
* continued work with typing, ingestion into the template, and the chat template for fire func
* correction - a streaming tool call comes back as a delta obj, not inside ChatCompRespChoice, per chat_completion_chunk.py inside the OAI lib
* Ruff formatting
* moved stop string and tool updates out of the prompt creation func; updated the tool pydantic to match OAI; support for streaming; updated generate tool calls to use a flag within the chat_template and insert a tool reminder
* Llama 3.1 chat templates; updated fire func template
* renamed llama3.1 to chatml_with_headers..
* updated name of template
* support for calling a tool start token rather than the string; simplified tool_params; warning when gen_settings are being overridden because the user set temp to 0; corrected schema and tools to the correct types for function args (str, for some reason)
* draft groq tool use model template
* changed headers to vars for readability (but mostly because some models are weird about newlines after headers, so this is an easier way to change globally)
* clean up comments and code in chat comp
* post-processed the tool call to meet the OAI spec rather than forcing the model to write JSON in a string in the middle of the call
* changed example back to args as JSON rather than a string of JSON
* standardize chat templates to each other
* cleaning/rewording
* stop elements can also be ints (tokens)
* cleaning/formatting
* added special tokens for tools and tool_response as specified in description
* cleaning
* removing aux templates - going to live in the llm-prompt-templates repo instead
* Tree: Format
  Signed-off-by: kingbri <bdashore3@proton.me>
* Chat Completions: Don't include internal tool variables in OpenAPI
  Use SkipJsonSchema to suppress inclusion in the OpenAPI JSON. The location of these variables may need to be changed in the future.
  Signed-off-by: kingbri <bdashore3@proton.me>
* Templates: Deserialize metadata on template load
  Since we're only looking for specific template variables that are static in the template, it makes more sense to render them when the template is initialized.
  Signed-off-by: kingbri <bdashore3@proton.me>
* Tools: Fix comments
  Adhere to the format style of comments in the rest of the project.
  Signed-off-by: kingbri <bdashore3@proton.me>

---------

Co-authored-by: Ben Gitter <gitterbd@gmail.com>
Signed-off-by: kingbri <bdashore3@proton.me>
parent 9cc0e70098
commit 70b9fc95de
8 changed files with 336 additions and 23 deletions
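
For reference, a minimal client-side sketch of how the tools support added here could be exercised. The server URL, API key, model name, and tool definition below are hypothetical; the request shape follows the standard OAI chat completions client that this endpoint targets.

import openai  # assumes the official openai client is installed

client = openai.OpenAI(base_url="http://127.0.0.1:5000/v1", api_key="<api key>")

# Hypothetical tool definition in the OAI "function" format
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        },
    }
]

response = client.chat.completions.create(
    model="any-loaded-model",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
)

# If the model decided to call a tool, arguments come back as a JSON string
print(response.choices[0].message.tool_calls)
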
.gitignore (vendored): 1 change
@@ -192,6 +192,7 @@ templates/*
 !templates/place_your_templates_here.txt
 !templates/alpaca.jinja
 !templates/chatml.jinja
+!templates/chatml_with_headers_tool_calling.jinja
 
 # Sampler overrides folder
 sampler_overrides/*

@@ -852,6 +852,7 @@ class ExllamaV2Container:
     "text": "",
     "prompt_tokens": 0,
     "generation_tokens": 0,
+    "tool_calls": None,
     "offset": [],
     "token_probs": {},
     "logprobs": [],

@@ -864,6 +865,7 @@ class ExllamaV2Container:
     joined_generation["finish_reason"] = finish_reason_gen.get(
         "finish_reason"
     )
+    joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
 else:
     joined_generation["finish_reason"] = "stop"
 
@@ -1068,6 +1070,15 @@ class ExllamaV2Container:
     gen_settings.top_p = 0
     gen_settings.typical = 0
 
+    logger.warning(
+        "".join(
+            [
+                "Temperature is set to 0. Overriding temp, ",
+                "top_k, top_p, and typical to 1.0, 1, 0, and 0.",
+            ]
+        )
+    )
+
 # Store the gen settings for logging purposes
 gen_settings_log_dict = vars(gen_settings)

@@ -1227,9 +1238,17 @@ class ExllamaV2Container:
 log_response(request_id, full_response)
 
 eos_reason = result.get("eos_reason")
-finish_reason = (
-    "length" if eos_reason == "max_new_tokens" else "stop"
-)
+
+stop_str = None
+if eos_reason == "max_new_tokens":
+    finish_reason = "length"
+else:
+    finish_reason = "stop"
+    # Grab stop string if stop was the reason
+    if eos_reason == "stop_token":
+        stop_str = result.get("eos_triggering_token_str")
+    elif eos_reason == "stop_string":
+        stop_str = result.get("eos_triggering_string")
 
 # Save the final result for metrics logging
 metrics_result = result

@@ -1239,6 +1258,7 @@ class ExllamaV2Container:
     "prompt_tokens": generation.get("prompt_tokens"),
     "generated_tokens": generation.get("generated_tokens"),
     "finish_reason": finish_reason,
+    "stop_str": stop_str,
 }
 
 yield generation

@@ -33,7 +33,7 @@ class BaseSamplerRequest(BaseModel):
     examples=[512],
 )
 
-stop: Optional[Union[str, List[str]]] = Field(
+stop: Optional[Union[str, List[Union[str, int]]]] = Field(
     default_factory=lambda: get_default_sampler_value("stop", []),
     validation_alias=AliasChoices("stop", "stop_sequence"),
     description="Aliases: stop_sequence",

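With the widened stop type, a request can now mix literal stop strings and raw token ids. A small hypothetical payload:

payload = {
    "prompt": "Write a haiku about rivers.",
    "max_tokens": 200,
    # stop accepts strings and/or integer token ids
    "stop": ["<|im_end|>", 32000],
}
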
@@ -3,7 +3,7 @@
 import json
 import pathlib
 from importlib.metadata import version as package_version
-from typing import Optional
+from typing import List, Optional
 from jinja2 import Template, TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger

@@ -18,6 +18,13 @@ class TemplateLoadError(Exception):
     pass
 
 
+class TemplateMetadata:
+    """Represents the parsed metadata from a template."""
+
+    stop_strings: List[str] = []
+    tool_starts: List[str] = []
+
+
 class PromptTemplate:
     """A template for chat completion prompts."""
 
@@ -27,23 +34,33 @@ class PromptTemplate:
     environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
         trim_blocks=True, lstrip_blocks=True
     )
+    metadata: TemplateMetadata
 
-    def stop_strings(self, template_vars: dict):
-        """Appends extra stop strings if present in a chat template."""
+    def extract_metadata(self):
+        """Returns deserialized template metadata from a chat template."""
 
-        extra_stop_strings = []
-        template_module = self.template.make_module(template_vars)
+        template_metadata = TemplateMetadata()
+
+        template_module = self.template.make_module()
 
         if hasattr(template_module, "stop_strings"):
             if isinstance(template_module.stop_strings, list):
-                extra_stop_strings += template_module.stop_strings
+                template_metadata.stop_strings += template_module.stop_strings
             else:
                 logger.warning(
                     "Skipping append of stopping strings from chat template "
                     "because stop_strings isn't a list."
                 )
 
-        return extra_stop_strings
+        if hasattr(template_module, "tool_start"):
+            if isinstance(template_module.tool_start, str):
+                template_metadata.tool_starts.append(template_module.tool_start)
+
+        if hasattr(template_module, "tool_start_token"):
+            if isinstance(template_module.tool_start_token, int):
+                template_metadata.tool_starts.append(template_module.tool_start_token)
+
+        return template_metadata
 
     def render(self, template_vars: dict):
         """Get a prompt from a template and a list of messages."""

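For context on how extract_metadata works: jinja2 exposes top-level {% set %} variables of a template as attributes on the module returned by make_module(), so templates can declare metadata without rendering anything. A standalone sketch (not part of the diff):

from jinja2.sandbox import ImmutableSandboxedEnvironment

env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
template = env.from_string(
    '{% set stop_strings = ["<|im_end|>"] %}'
    '{% set tool_start = "<|tool_start|>" %}'
    "{{ messages }}"
)

# Top-level {% set %} variables become module attributes
module = template.make_module()
print(module.stop_strings)  # ['<|im_end|>']
print(module.tool_start)    # <|tool_start|>
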
@@ -56,9 +73,8 @@ class PromptTemplate:
         )
 
         rendered_template = self.template.render(**template_vars)
-        template_stop_strings = self.stop_strings(template_vars)
 
-        return rendered_template, template_stop_strings
+        return rendered_template
 
     def compile(self, template_str: str):
         """Compiles and stores a jinja2 template"""

@@ -77,6 +93,7 @@ class PromptTemplate:
         self.name = name
         self.raw_template = raw_template
         self.template = self.compile(raw_template)
+        self.metadata = self.extract_metadata()
 
     @classmethod
     def from_file(self, prompt_template_name: str):

@@ -1,9 +1,11 @@
 from pydantic import BaseModel, Field
+from pydantic.json_schema import SkipJsonSchema
 from time import time
 from typing import Union, List, Optional, Dict
 from uuid import uuid4
 
 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
+from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
 
 
 class ChatCompletionLogprob(BaseModel):

@@ -19,12 +21,16 @@ class ChatCompletionLogprobs(BaseModel):
 class ChatCompletionMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = None
 
 
 class ChatCompletionRespChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
     index: int = 0
     finish_reason: Optional[str] = None
+
+    # let's us understand why it stopped and if we need to generate a tool_call
+    stop_str: Optional[str] = None
     message: ChatCompletionMessage
     logprobs: Optional[ChatCompletionLogprobs] = None

@@ -42,12 +48,28 @@ class ChatCompletionRequest(CommonCompletionRequest):
     # Messages
     # Take in a string as well even though it's not part of the OAI spec
-    # support messages.content as a list of dict
-    messages: Union[str, List[Dict[str, Union[str, List[Dict[str, str]]]]]]
+
+    # WIP this can probably be tightened, or maybe match the OAI lib type
+    # in openai\types\chat\chat_completion_message_param.py
+    messages: Union[str, List[Dict]]
     prompt_template: Optional[str] = None
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
     response_prefix: Optional[str] = None
+
+    # tools is follows the format OAI schema, functions is more flexible
+    # both are available in the chat template.
+
+    tools: Optional[List[ToolSpec]] = None
+    functions: Optional[List[Dict]] = None
+
+    # Typically collected from Chat Template.
+    # Don't include this in the OpenAPI docs
+    # TODO: Use these custom parameters
+    tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = None
+    tool_call_end: SkipJsonSchema[Optional[str]] = None
+    tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
 
 
 class ChatCompletionResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")

endpoints/OAI/types/tools.py (new file): 58 additions
@@ -0,0 +1,58 @@
+from pydantic import BaseModel
+from typing import Dict, Literal
+
+tool_call_schema = {
+    "$schema": "http://json-schema.org/draft-07/schema#",
+    "type": "array",
+    "items": {
+        "type": "object",
+        "properties": {
+            "id": {"type": "string"},
+            "function": {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "arguments": {
+                        # Converted to OAI's string in post process
+                        "type": "object"
+                    },
+                },
+                "required": ["name", "arguments"],
+            },
+            "type": {"type": "string", "enum": ["function"]},
+        },
+        "required": ["id", "function", "type"],
+    },
+}
+
+
+class Function(BaseModel):
+    """Represents a description of a tool function."""
+
+    name: str
+    description: str
+    parameters: Dict[str, object]
+
+
+class ToolSpec(BaseModel):
+    """Wrapper for an inner tool function."""
+
+    function: Function
+    type: Literal["function"]
+
+
+class Tool(BaseModel):
+    """Represents an OAI tool description."""
+
+    name: str
+
+    # Makes more sense to be a dict, but OAI knows best
+    arguments: str
+
+
+class ToolCall(BaseModel):
+    """Represents an OAI tool description."""
+
+    id: str
+    function: Tool
+    type: Literal["function"]

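As a quick illustration of the new types, an OAI-style tool definition validates straight into ToolSpec. The tool below is hypothetical; the snippet assumes the models above are importable from endpoints/OAI/types/tools.py.

from endpoints.OAI.types.tools import ToolSpec

# Hypothetical tool definition as a client would send it
tool_dict = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Look up the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}

spec = ToolSpec(**tool_dict)
print(spec.function.name)  # get_weather
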
@@ -5,6 +5,7 @@ import pathlib
 from asyncio import CancelledError
 from copy import deepcopy
 from typing import List, Optional
+import json
 
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError

@@ -30,6 +31,7 @@ from endpoints.OAI.types.chat_completion import (
 )
 from endpoints.OAI.types.common import UsageStats
 from endpoints.OAI.utils.completion import _stream_collector
+from endpoints.OAI.types.tools import ToolCall
 
 
 def _create_response(

@@ -46,6 +48,10 @@ def _create_response(
         role="assistant", content=unwrap(generation.get("text"), "")
     )
 
+    tool_calls = generation["tool_calls"]
+    if tool_calls:
+        message.tool_calls = postprocess_tool_call(tool_calls)
+
     logprob_response = None
 
     token_probs = unwrap(generation.get("token_probs"), {})

@@ -72,6 +78,7 @@ def _create_response(
     choice = ChatCompletionRespChoice(
         index=index,
         finish_reason=generation.get("finish_reason"),
+        stop_str=generation.get("stop_str"),
         message=message,
         logprobs=logprob_response,
     )

@@ -119,7 +126,16 @@ def _create_stream_chunk(
         finish_reason=generation.get("finish_reason"),
     )
 
+    # lets check if we have tool calls since we are at the end of the generation
+    if "tool_calls" in generation:
+        tool_calls = generation["tool_calls"]
+        message = ChatCompletionMessage(
+            tool_calls=postprocess_tool_call(tool_calls)
+        )
+        choice.delta = message
+
     choices.append(choice)
 
 else:
     message = ChatCompletionMessage(
         role="assistant", content=unwrap(generation.get("text"), "")

@@ -162,7 +178,29 @@ def _create_stream_chunk(
     return chunk
 
 
-def format_prompt_with_template(data: ChatCompletionRequest):
+def _append_template_metadata(data: ChatCompletionRequest):
+    """Adding metadata is a one-time process."""
+
+    template_metadata = model.container.prompt_template.metadata
+
+    # Stop strings
+    if isinstance(data.stop, str):
+        data.stop = [data.stop] + template_metadata.stop_strings
+    else:
+        data.stop += template_metadata.stop_strings
+
+    # Tool call start strings
+    if template_metadata.tool_starts:
+        if data.tool_call_start is None:
+            data.tool_call_start = template_metadata.tool_starts
+
+        # Append to stop strings to halt for a tool call generation
+        data.stop.extend(template_metadata.tool_starts)
+
+
+def format_prompt_with_template(
+    data: ChatCompletionRequest, tool_precursor: Optional[str] = None
+):
     """
     Compile the prompt and get any additional stop strings from the template.
     Template stop strings can be overriden by sampler overrides if force is true.

@@ -187,18 +225,22 @@ def format_prompt_with_template(
             "",
         )
 
+        if "tool_calls" in message:
+            message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2)
+
     # Overwrite any protected vars with their values
     data.template_vars.update(
         {
             "messages": data.messages,
             "add_generation_prompt": data.add_generation_prompt,
+            "tools_json": json.dumps(data.model_dump()["tools"], indent=2),
+            "functions_json": json.dumps(data.functions, indent=2),
+            "tool_precursor": tool_precursor,
             **special_tokens_dict,
         }
     )
 
-    prompt, template_stop_strings = model.container.prompt_template.render(
-        data.template_vars
-    )
+    prompt = model.container.prompt_template.render(data.template_vars)
 
     # Append response prefix if present
     if data.response_prefix:

@@ -216,11 +258,8 @@ def format_prompt_with_template(
     if bos_token and prompt.startswith(bos_token):
         prompt = prompt.removeprefix(bos_token)
 
-    # Append template stop strings
-    if isinstance(data.stop, str):
-        data.stop = [data.stop] + template_stop_strings
-    else:
-        data.stop += template_stop_strings
+    # Add template metadata
+    _append_template_metadata(data)
 
     return prompt

@@ -271,6 +310,9 @@ async def stream_generate_chat_completion(
 
         gen_tasks.append(gen_task)
 
+    # We need to keep track of the text generated so we can resume the tool calls
+    current_generation_text = ""
+
     # Consumer loop
     while True:
         if disconnect_task.done():

@@ -280,6 +322,19 @@ async def stream_generate_chat_completion(
         )
 
         generation = await gen_queue.get()
+        # lets only append the text if we need it for tool calls later
+        if data.tool_call_start and "text" in generation:
+            current_generation_text += generation["text"]
+
+        # check if we are running a tool model, and that we are at stop
+        if data.tool_call_start and "stop_str" in generation:
+            generations = await generate_tool_calls(
+                data,
+                [generation],
+                request,
+                current_generations=current_generation_text,
+            )
+            generation = generations[0]  # We only have one generation in this case
 
         # Stream collector will push an exception to the queue if it fails
         if isinstance(generation, Exception):

@@ -344,6 +399,11 @@ async def generate_chat_completion(
     )
 
     generations = await asyncio.gather(*gen_tasks)
+
+    # Let's not waste our time if we arn't running a tool model
+    if data.tool_call_start:
+        generations = await generate_tool_calls(data, generations, request)
+
     response = _create_response(request.state.id, generations, model_path.name)
 
     logger.info(f"Finished chat completion request {request.state.id}")

@@ -358,3 +418,54 @@ async def generate_chat_completion(
 
     # Server error if there's a generation exception
     raise HTTPException(503, error_message) from exc
+
+
+async def generate_tool_calls(
+    data: ChatCompletionRequest,
+    generations: List[str],
+    request: Request,
+    current_generations: str = None,
+):
+    gen_tasks: List[asyncio.Task] = []
+    tool_idx: List[int] = []
+
+    # Copy to make sure the parent JSON schema doesn't get modified
+    # FIXME: May not be necessary depending on how the codebase evolves
+    tool_data = deepcopy(data)
+    tool_data.json_schema = tool_data.tool_call_schema
+    gen_params = tool_data.to_gen_params()
+
+    for idx, gen in enumerate(generations):
+        if gen["stop_str"] in tool_data.tool_call_start:
+            if "text" in gen:
+                # non streaming, all generations will have the text they generated
+                pre_tool_prompt = format_prompt_with_template(data, gen["text"])
+            elif current_generations is not None:
+                # streaming, we wont have text in the generation,
+                # we'll have to use the current_generations
+                pre_tool_prompt = format_prompt_with_template(data, current_generations)
+
+            gen_tasks.append(
+                asyncio.create_task(
+                    model.container.generate(
+                        pre_tool_prompt, request.state.id, **gen_params
+                    )
+                )
+            )
+            tool_idx.append(idx)
+
+    tool_calls = await asyncio.gather(*gen_tasks)
+    for outer_idx in range(0, len(tool_idx)):
+        gen_idx = tool_idx[outer_idx]
+        generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
+
+    return generations
+
+
+def postprocess_tool_call(call_str: str) -> List[ToolCall]:
+    tool_calls = json.loads(call_str)
+    for tool_call in tool_calls:
+        tool_call["function"]["arguments"] = json.dumps(
+            tool_call["function"]["arguments"]
+        )
+    return [ToolCall(**tool_call) for tool_call in tool_calls]

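To show what postprocess_tool_call does with the model's constrained output: the generation matches tool_call_schema (arguments as a JSON object), and the post-processing re-serializes arguments into the string form the OAI spec expects. A sketch with a made-up call; it assumes postprocess_tool_call from the diff above is in scope.

import json

# Hypothetical model output that satisfies tool_call_schema
call_str = json.dumps(
    [
        {
            "id": "call_0",
            "function": {"name": "get_weather", "arguments": {"city": "Paris"}},
            "type": "function",
        }
    ]
)

tool_calls = postprocess_tool_call(call_str)

# arguments is now a JSON string, matching OAI's tool call shape
print(tool_calls[0].function.arguments)  # {"city": "Paris"}
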
templates/chatml_with_headers_tool_calling.jinja (new file): 84 additions
@@ -0,0 +1,84 @@
+{# Metadata #}
+{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %}
+{% set message_roles = ['system', 'user', 'assistant', 'tool'] %}
+{% set tool_start = "<|tool_start|>" %}
+{% set tool_end = "<|tool_end|>" %}
+{%- set start_header = "<|start_header_id|>" -%}
+{%- set end_header = "<|end_header_id|>\n" -%}
+
+{%- set example_tool_call -%}[
+{
+"id": "tool_id_1342",
+"function": {
+"arguments": "arg_name": 3,
+"name": "tool_name"
+},
+"type": "function"
+},
+{
+"id": "example_id_13f42",
+"function": {
+"arguments": "example_arg": 1.0, "another_example_arg": true,
+"name": "another_tool_name"
+},
+"type": "function"
+}
+]
+{%- endset -%}
+
+{%- set inital_system_prompt -%}You are an assistant that has access to the following set of tools, to call a tool:
+1. Prefix calls with '{{ tool_start }}' and end calls with '{{ tool_end }}'
+2. Ensure you use the correct type for arguments. For example, if the argument is a string, ensure it is enclosed in quotes, otherwise, it should not be.
+3. Generate all calls using the following json tool call format. Here is a multi tool call example:
+
+{{ tool_start }}{{ example_tool_call }}{{ tool_end }}
+
+Here are the tools available for you to call:
+{{ tools_json }}
+{%- endset -%}
+
+{%- set tool_reminder -%}Available Tools:
+{{ tools_json }}
+
+Tool Call Format Example:
+{{ tool_start }}{{ example_tool_call }}
+
+Prefix & Suffix: Begin tool calls with {{ tool_start }} and end with {{ tool_end }}.
+Argument Types: Use correct data types for arguments (e.g., strings in quotes, numbers without).
+{%- endset -%}
+
+{# Template #}
+
+{%- for message in messages -%}
+{%- set role = message['role'] | lower -%}
+{%- if role not in message_roles -%}
+{{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles | join(', ') + ' are supported.') }}
+{%- endif -%}
+
+{%- set content = message['content'] | default('', true) | trim -%}
+{%- if loop.first -%}
+{{ bos_token }}{{ start_header }}{{ role }}{{ end_header }}
+{{ inital_system_prompt }}
+
+{{ content }}{{ eos_token }}
+{%- endif -%}
+
+{%- if not loop.first -%}
+{{ start_header }}{{ role }}{{ end_header }}
+{{ content }}
+{%- if 'tool_calls_json' in message and message['tool_calls_json'] -%}
+{{ tool_start }}{{ message['tool_calls_json']}}{{ tool_end }}
+{%- endif -%}
+{{ eos_token }}
+
+{%- endif -%}
+{%- endfor -%}
+
+{%- if tool_precursor -%}
+{{ start_header }}system{{ end_header }}
+{{ tool_reminder }}{{ eos_token }}
+{{ start_header }}assistant{{ end_header }}
+{{ tool_precursor }}{{ tool_start }}
+{%- else -%}
+{{ start_header }}assistant{{ end_header }}
+{%- endif -%}