Support more common tool variables in templates (tools, message.tool_calls) (#308)
* Add non-JSON versions of `tools` and `functions` to `template_vars`. Increases compatibility with vLLM templates, which use a non-JSON tools object.
* Add a list of tool template variables to the documentation.
* Use Jinja templates to provide `tools_json` and `functions_json`. This should be functionally equivalent, but the JSON won't be produced unless it's needed.
* Make `message.tool_calls` match the JSON from ToolCallProcessor.
* Log something when generating tool calls.
* Add a template for Qwen QwQ 32b.
* Only log if tool calls have been detected.
* API: Fix tool call variable assignments. Jinja functions do not run when they are referenced as variables, so use `json.dumps` instead. In addition, log the request ID when stating that a tool call was fired.
* Add `ToolCallProcessor.dump()` to get the list of processed dicts.
* Remove qwen_qwq_32b.jinja. This will be added to the following repository at a later date: https://github.com/theroyallab/llm-prompt-templates

---------

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
Co-authored-by: kingbri <8082010+kingbri1@users.noreply.github.com>
parent ccf23243c1
commit 436ce752da

3 changed files with 44 additions and 10 deletions
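One of the fixes above — "Jinja functions do not run when variables are called" — is easy to trip over: if a callable is exposed as a template variable, `{{ tools_json }}` only stringifies it, rendering the function's repr rather than invoking it. A minimal sketch of the failure mode and the `json.dumps` fix, using jinja2 directly with illustrative variable names (not TabbyAPI's actual wiring):

```python
import json
from jinja2 import Template

tools = [{"type": "function", "function": {"name": "get_weather"}}]

# Broken: expose a callable and hope the template invokes it.
# {{ tools_json }} only stringifies the variable, so this renders
# something like "<function <lambda> at 0x...>" instead of JSON.
print(Template("{{ tools_json }}").render(tools_json=lambda: json.dumps(tools)))

# Fixed (the approach this commit takes): serialize eagerly with
# json.dumps so the template receives a plain string.
print(Template("{{ tools_json }}").render(tools_json=json.dumps(tools, indent=2)))
```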
```diff
@@ -37,6 +37,11 @@ For example, if you are using a Llama 3.1 Family model you can simply modify you
 
 If loading via `/v1/model/load`, you would also need to specify a tool-supporting `prompt_template`.
 
+## Tool Template Variables
+
+- `tools`: Tools object.
+- `tools_json`: Tools object as a JSON string.
+
 ## Creating a Tool Calling Prompt Template
 
 Here's how to create a TabbyAPI tool calling prompt template:
```
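To make the distinction between the two new variables concrete, here is a hedged sketch of how a template might consume each one; the tool shape follows the common OpenAI-style layout and is an assumption, not something this diff specifies:

```python
import json
from jinja2 import Template

# Assumed OpenAI-style tool definition; the real schema may differ.
tools = [{"type": "function", "function": {"name": "get_weather"}}]

# `tools` stays a Python object, so vLLM-style templates can iterate it.
listing = Template("{% for tool in tools %}{{ tool.function.name }} {% endfor %}")
print(listing.render(tools=tools))  # -> "get_weather "

# `tools_json` is the same data pre-serialized, for templates that
# splice one JSON blob into the prompt.
blob = Template("Available tools:\n{{ tools_json }}")
print(blob.render(tools_json=json.dumps(tools, indent=2)))
```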
```diff
@@ -142,4 +147,4 @@ When creating your own tool calling `prompt_template`, it's best to reference th
 
 ## Support and Bug Reporting
 
 For bugs, please create a detailed issue with the model, prompt template, and conversation that caused it. Alternatively, join our [Discord](https://discord.gg/sYQxnuD7Fj) and ask for Storm.
```
```diff
@@ -234,6 +234,10 @@ async def format_messages_with_template(
         if message.tool_calls:
             message.tool_calls_json = ToolCallProcessor.to_json(message.tool_calls)
 
+            # The tools variable is inspectable in the template, so
+            # store the list of dicts rather than the ToolCallProcessor object.
+            message.tool_calls = ToolCallProcessor.dump(message.tool_calls)
+
     special_tokens_dict = model.container.get_special_tokens(
         add_bos_token, ban_eos_token
     )
```
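The comment in the hunk is the key point: after `ToolCallProcessor.dump()`, `message.tool_calls` holds plain dicts, so a template can inspect individual fields instead of only receiving pre-rendered JSON. A sketch, with the dict shape assumed to follow the usual OpenAI tool-call layout rather than taken from this diff:

```python
from jinja2 import Template

# Assumed shape of one dumped tool call; illustrative only.
message = {
    "role": "assistant",
    "tool_calls": [
        {"id": "call_0", "function": {"name": "get_weather", "arguments": "{}"}}
    ],
}

t = Template(
    "{% for call in message.tool_calls %}"
    "<tool_call>{{ call.function.name }}({{ call.function.arguments }})</tool_call>"
    "{% endfor %}"
)
print(t.render(message=message))
# -> <tool_call>get_weather({})</tool_call>
```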
```diff
@@ -252,11 +256,16 @@ async def apply_chat_template(
     Template stop strings can be overriden by sampler overrides if force is true.
     """
 
+    # Locally store tools dict
+    tools = data.model_dump()["tools"]
+
     try:
         data.template_vars.update(
             {
                 "add_generation_prompt": data.add_generation_prompt,
-                "tools_json": json.dumps(data.model_dump()["tools"], indent=2),
+                "tools": tools,
+                "tools_json": json.dumps(tools, indent=2),
+                "functions": data.functions,
                 "functions_json": json.dumps(data.functions, indent=2),
                 "tool_precursor": tool_precursor,
             }
```
```diff
@@ -460,6 +469,10 @@ async def generate_tool_calls(
 
     for idx, gen in enumerate(generations):
         if gen["stop_str"] in tool_data.tool_call_start:
+            logger.info(
+                f"Detected tool call in chat completion request {request.state.id}"
+            )
+
             if "text" in gen:
                 # non streaming, all generations will have the text they generated
                 pre_tool_prompt, mm_embeddings = await apply_chat_template(
```
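Schematically, the new log line fires only when a generation stopped on one of the template's tool-call openers. The opener strings below are placeholders, not TabbyAPI defaults:

```python
# Placeholder opener; in TabbyAPI this comes from the template's tool data.
tool_call_start = ["<tool_call>"]

generations = [
    {"stop_str": "<tool_call>", "text": "..."},    # fired a tool call
    {"stop_str": "</s>", "text": "plain answer"},  # ordinary completion
]

for idx, gen in enumerate(generations):
    if gen["stop_str"] in tool_call_start:
        print(f"Detected tool call in generation {idx}")
```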
```diff
@@ -18,6 +18,28 @@ class ToolCallProcessor:
 
         return [ToolCall(**tool_call) for tool_call in tool_calls]
 
+    @staticmethod
+    def dump(tool_calls: List[ToolCall]) -> List[dict]:
+        """
+        Convert ToolCall objects to a list of dictionaries.
+
+        Args:
+            tool_calls (List[ToolCall]): List of ToolCall objects to convert
+
+        Returns:
+            List[dict]: List of dictionaries representing the tool calls
+        """
+
+        # Don't use list comprehension here
+        # as that will fail rather than warn
+        dumped_tool_calls = []
+        for tool_call_obj in tool_calls:
+            try:
+                dumped_tool_calls.append(tool_call_obj.model_dump())
+            except (json.JSONDecodeError, AttributeError) as e:
+                logger.warning(f"Error processing tool call: {e}")
+        return dumped_tool_calls
+
     @staticmethod
     def to_json(tool_calls: List[ToolCall]) -> str:
         """
@@ -33,14 +55,8 @@ class ToolCallProcessor:
         if not tool_calls:
            return ""
 
-        # Don't use list comprehension here
-        # as that will fail rather than warn
-        dumped_tool_calls = []
-        for tool_call_obj in tool_calls:
-            try:
-                dumped_tool_calls.append(tool_call_obj.model_dump())
-            except (json.JSONDecodeError, AttributeError) as e:
-                logger.warning(f"Error processing tool call: {e}")
+        # Use the dump method to get the list of dictionaries
+        dumped_tool_calls = ToolCallProcessor.dump(tool_calls)
 
         # Serialize the dumped array
         return json.dumps(dumped_tool_calls, indent=2)
```
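A quick usage sketch of the refactored pair: `dump()` yields template-friendly dicts and `to_json()` now just serializes that result. The `ToolCall` schema below is a stand-in Pydantic model (the real one lives in TabbyAPI and may differ), and `ToolCallProcessor` is assumed to be importable from the project:

```python
from typing import List
from pydantic import BaseModel

# Stand-in models; TabbyAPI's actual ToolCall schema may differ.
class Function(BaseModel):
    name: str
    arguments: str

class ToolCall(BaseModel):
    id: str
    function: Function

calls: List[ToolCall] = [
    ToolCall(id="call_0", function=Function(name="get_weather", arguments="{}"))
]

print(ToolCallProcessor.dump(calls))     # list of plain dicts for templates
print(ToolCallProcessor.to_json(calls))  # same data as an indented JSON string
```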