Merge pull request #364 from theroyallab/tool-calls

Streamline tool calling
Brian 2025-07-11 11:34:10 -04:00 committed by GitHub
commit 2419d2d0a3
7 changed files with 155 additions and 205 deletions


@@ -4,6 +4,8 @@ import traceback
import aiofiles
import json
import pathlib
from dataclasses import dataclass, field
from datetime import datetime
from importlib.metadata import version as package_version
from typing import List, Optional
from jinja2 import Template, TemplateError
@@ -11,7 +13,6 @@ from jinja2.ext import loopcontrols
from jinja2.sandbox import ImmutableSandboxedEnvironment
from loguru import logger
from packaging import version
from datetime import datetime
from common.utils import unwrap
@@ -23,11 +24,12 @@ class TemplateLoadError(Exception):
pass
@dataclass
class TemplateMetadata:
"""Represents the parsed metadata from a template."""
stop_strings: List[str] = []
tool_starts: List[str] = []
stop_strings: List[str] = field(default_factory=list)
tool_start: Optional[str] = None
class PromptTemplate:
@@ -72,11 +74,7 @@ class PromptTemplate:
if hasattr(template_module, "tool_start"):
if isinstance(template_module.tool_start, str):
template_metadata.tool_starts.append(template_module.tool_start)
if hasattr(template_module, "tool_start_token"):
if isinstance(template_module.tool_start_token, int):
template_metadata.tool_starts.append(template_module.tool_start_token)
template_metadata.tool_start = template_module.tool_start
self.metadata = template_metadata
return template_metadata
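The switch to `field(default_factory=list)` is worth noting: `dataclasses` rejects bare mutable defaults such as `[]` at class-definition time, and `default_factory` also guarantees each instance gets its own list. A minimal sketch of the pattern, reusing the fields from the diff (the demo code around the class is illustrative):

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class TemplateMetadata:
    """Represents the parsed metadata from a template."""

    # A bare `= []` default would raise ValueError when the dataclass is
    # created; default_factory builds a fresh list for every instance.
    stop_strings: List[str] = field(default_factory=list)
    tool_start: Optional[str] = None


# Each instance owns its own list, so stop strings collected for one
# template cannot bleed into another.
first, second = TemplateMetadata(), TemplateMetadata()
first.stop_strings.append("<|im_end|>")
assert second.stop_strings == []
```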


@@ -32,7 +32,7 @@ For example, if you are using a Llama 3.1 Family model you can simply modify you
```yaml
model:
...
prompt_template: chatml_with_headers_tool_calling
prompt_template: tool_calls/chatml_with_headers
```
If loading via `/v1/model/load`, you would also need to specify a tool-supporting `prompt_template`.
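For example, a load request might look like the sketch below; only the `/v1/model/load` path and the `prompt_template` key come from this document, while the model name, port, and admin-key header are assumptions:

```python
# Hypothetical sketch of loading a model with a tool-supporting template.
import httpx

payload = {
    "model_name": "Llama-3.1-8B-Instruct-exl2",  # illustrative model folder
    "prompt_template": "tool_calls/chatml_with_headers",
}

response = httpx.post(
    "http://127.0.0.1:5000/v1/model/load",       # default host/port assumed
    headers={"x-admin-key": "your-admin-key"},   # admin auth header assumed
    json=payload,
    timeout=None,                                # model loads can take a while
)
print(response.status_code)
```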
@@ -40,7 +40,6 @@ If loading via `/v1/model/load`, you would also need to specify a tool-supportin
## Tool Template Variables
- `tools`: Tools object.
- `tools_json`: Tools object as a JSON string.
## Creating a Tool Calling Prompt Template
@@ -56,29 +55,36 @@ Here's how to create a TabbyAPI tool calling prompt template:
```jinja
{# Metadata #}
{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %}
{% set message_roles = ['system', 'user', 'assistant', 'tool'] %}
{% set tool_start = "<|tool_start|>" %}
{% set tool_end = "<|tool_end|>" %}
{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%}
{%- set tool_start = "<|tool_start|>" -%}
{# Optional Metadata #}
{%- set tool_end = "<|tool_end|>" -%}
```
`tool_start` and `tool_end` should be selected based on which model you decide to use. For example, [Groq's tool calling models](https://huggingface.co/Groq/Llama-3-Groq-70B-Tool-Use) expect `<tool_call>` and `</tool_call>`, while the [Llama3 FireFunctionV2](https://huggingface.co/fireworks-ai/llama-3-firefunction-v2) model expects only `functools` to start the call, with no `tool_end`.
2. Define an `initial_system_prompt`:
While the name of your `inital_system_prompt` can vary, it's purpose does not. This inital prompt is typically a simple instruction set followed by accessing the `tools_json` variable. This will contain the function specification the user provided to the `tools` endpoint in their client when the chat completion request. Inside the template we can call this like so: `{{ tools_json }}`.
While the name of your `inital_system_prompt` can vary, its purpose does not. This initial prompt is typically a simple instruction set followed by accessing the `tools` template variable.
Note: Depending on the model you are using, it's possible your model may expect a special set of tokens to surround the function specifications. Feel free to surround `tools_json` with these tokens.
This variable contains the function specifications the user provided in the `tools` field of their client's chat completion request. Inside the template we can reference it like so: `{{ tools | tojson }}`.
> [!NOTE]
> Depending on the model you are using, it's possible your model may expect a special set of tokens to surround the function specifications. Feel free to surround the tools JSON with these tokens.
> [!NOTE]
> To get a JSON representation of the tools variable, use `| tojson(indent=2)` in the assignment
```jinja
{% set initial_system_prompt %}
Your instructions here...
Available functions:
{{ tools_json }}
{{ tools | tojson(indent=2) }}
{% endset %}
```
You'll then want to make sure to provide this to the model in the first message it recieves. Here is a simple example:
You'll then want to make sure to provide this to the model in the first message it receives. Here is a simple example:
```jinja
{%- if loop.first -%}
@@ -88,11 +94,11 @@ Here's how to create a TabbyAPI tool calling prompt template:
{{ content }}{{ eos_token }}
```
3. Handle messages with the `tool` role:
4. Handle messages with the `tool` role:
After a tool call is made, a *well-behaved* client will respond to the model with a new message containing the role `tool`. This is a response to a tool call containing the results of its execution.
The simplest implementation of this will be to ensure your `message_roles` list within your prompt template contains `tool`. Further customization may be required for models that expect specific tokens surrounding tool reponses. An example of this customization is the Groq family of models from above. They expect special tokens surrounding their tool responses such as:
The simplest implementation of this will be to ensure your `message_roles` list within your prompt template contains `tool`. Further customization may be required for models that expect specific tokens surrounding tool responses. An example of this customization is the Groq family of models from above. They expect special tokens surrounding their tool responses such as:
```jinja
{% if role == 'tool' %}
@@ -100,51 +106,29 @@ Here's how to create a TabbyAPI tool calling prompt template:
{% endif %}
```
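For context, the `tool` message your template receives generally follows the OpenAI-style shape sketched below; the id and content values are made up for illustration:

```python
# Illustrative OpenAI-style tool-result message. tool_call_id ties the
# result back to the assistant's earlier tool call; content carries the
# tool's output, often as a JSON string.
tool_result_message = {
    "role": "tool",
    "tool_call_id": "tool_id_1342",
    "content": '{"temperature_c": 21, "condition": "cloudy"}',
}
```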
4. Preserve tool calls from prior messages:
5. Preserve tool calls from prior messages:
When creating a tool calling `prompt_template`, ensure you handle previous tool calls from the model gracefully. Each `message` object within `messages` exposed within the `prompt_template` could also contain `tool_calls_json`. This field will contain tool calls made by the assistant in previous turns, and must be handled appropriatly so that the model understands what previous actions it has taken (and can properly identify what tool response ID belongs to which call).
When creating a tool calling `prompt_template`, ensure you handle previous tool calls from the model gracefully. Each `message` object within `messages` exposed within the `prompt_template` could also contain `tool_calls`.
This will require using the `tool_start` (and possibly `tool_end`) from above to wrap the `tool_call_json` like so:
This field will contain tool calls made by the assistant in previous turns, and must be handled appropriately so that the model understands what previous actions it has taken (and can properly identify what tool response ID belongs to which call).
This will require using the `tool_start` (and possibly `tool_end`) from above to wrap the `tool_call` object like so:
```jinja
{% if 'tool_calls_json' in message and message['tool_calls_json'] %}
{{ tool_start }}{{ message['tool_calls_json'] }}{{ tool_end }}
{% if 'tool_calls' in message and message['tool_calls'] %}
{{ tool_start }}{{ message['tool_calls'] | tojson(indent=2) }}{{ tool_end }}
{% endif %}
```
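For reference, an assistant message carrying a prior tool call generally looks like the sketch below before the template serializes `message['tool_calls']` with `| tojson`; the tool name and arguments are illustrative, and whether `arguments` arrives as a parsed object or a JSON string depends on how the server prepares the request:

```python
# Illustrative assistant message from an earlier turn. The template's
# tool_calls branch wraps this list in tool_start/tool_end after
# serializing it with `| tojson(indent=2)`.
assistant_message = {
    "role": "assistant",
    "content": "",
    "tool_calls": [
        {
            "id": "tool_id_1342",
            "type": "function",
            "function": {
                "name": "get_current_weather",   # hypothetical tool
                "arguments": {"city": "Paris"},  # may instead be a JSON string
            },
        }
    ],
}
```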
5. Handle tool call generation:
6. Add the generation prompt check at the end:
```jinja
{% set tool_reminder %}
Available Tools:
{{ tools_json }}
Tool Call Format Example:
{{ tool_start }}{{ example_tool_call }}
Prefix & Suffix: Begin tool calls with {{ tool_start }} and end with {{ tool_end }}.
Argument Types: Use correct data types for arguments (e.g., strings in quotes, numbers without).
{% endset %}
{% if tool_precursor %}
{{ start_header }}system{{ end_header }}
{{ tool_reminder }}{{ eos_token }}
{{ start_header }}assistant{{ end_header }}
{{ tool_precursor }}{{ tool_start }}
{% else %}
{% if add_generation_prompt %}
{{ start_header }}assistant{{ end_header }}
{% endif %}
```
This clever bit of temporal manipulation allows us to slip in a reminder as a system message right before the model generates a tool call, but after it writes the `tool_start` token. This is possible due to TabbyAPI revisiting the `prompt_template` after a `tool_start` token is detected. Here's how it works:
- We detect `tool_precursor`, which signals the model is about to generate a tool call.
- We then inject a system message with our `tool_reminder`.
- Finally, we initialize an assistant message using `tool_precursor` as the content.
This creates the illusion that the model just happened to remember the available tools and proper formatting right before generating the tool call. It's like giving the model a little nudge at exactly the right moment, enhancing its performance without altering what the user sees.
When creating your own tool calling `prompt_template`, it's best to reference the default `chatml_with_headers_tool_calling.jinja` template as a starting point.
> [!NOTE]
> When creating your own tool calling template, it's best to reference the default `chatml_with_headers` template as a starting point.
## Support and Bug Reporting
For bugs, please create a detailed issue with the model, prompt template, and conversation that caused it. Alternatively, join our [Discord](https://discord.gg/sYQxnuD7Fj) and ask for Storm.


@@ -1,11 +1,10 @@
from pydantic import AliasChoices, BaseModel, Field, field_validator
from pydantic.json_schema import SkipJsonSchema
from time import time
from typing import Literal, Union, List, Optional, Dict
from uuid import uuid4
from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
from endpoints.OAI.types.tools import ToolSpec, ToolCall
class ChatCompletionLogprob(BaseModel):
@@ -32,7 +31,7 @@ class ChatCompletionMessage(BaseModel):
role: str = "user"
content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None
tool_calls: Optional[List[ToolCall]] = None
tool_calls_json: SkipJsonSchema[Optional[str]] = None
tool_call_id: Optional[str] = None
class ChatCompletionRespChoice(BaseModel):
@@ -56,7 +55,7 @@ class ChatCompletionStreamChoice(BaseModel):
# Inherited from common request
class ChatCompletionRequest(CommonCompletionRequest):
messages: List[ChatCompletionMessage] = Field(default_factory=list)
messages: List[ChatCompletionMessage]
prompt_template: Optional[str] = None
add_generation_prompt: Optional[bool] = True
template_vars: Optional[dict] = Field(
@@ -73,15 +72,6 @@ class ChatCompletionRequest(CommonCompletionRequest):
tools: Optional[List[ToolSpec]] = None
functions: Optional[List[Dict]] = None
# Typically collected from Chat Template.
# Don't include this in the OpenAPI docs
# TODO: Use these custom parameters
tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = Field(
default_factory=list
)
tool_call_end: SkipJsonSchema[Optional[str]] = None
tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
# Chat completions requests do not have a BOS token preference. Backend
# respects the tokenization config for the individual model.
add_bos_token: Optional[bool] = None


@@ -1,29 +1,6 @@
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing import Dict, Literal
tool_call_schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {"type": "string"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
# Converted to OAI's string in post process
"type": "object"
},
},
"required": ["name", "arguments"],
},
"type": {"type": "string", "enum": ["function"]},
},
"required": ["id", "function", "type"],
},
}
from uuid import uuid4
class Function(BaseModel):
@@ -53,6 +30,6 @@ class Tool(BaseModel):
class ToolCall(BaseModel):
"""Represents an OAI tool description."""
id: str
id: str = Field(default_factory=lambda: str(uuid4()).replace("-", "")[:9])
function: Tool
type: Literal["function"]
type: Literal["function"] = "function"
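As a small aside on the new defaults, the `default_factory` above yields a short random identifier; a standalone sketch of the same expression:

```python
# Mirrors ToolCall.id's default_factory: take a UUID4, strip the dashes,
# and keep the first nine characters.
from uuid import uuid4


def make_tool_call_id() -> str:
    return str(uuid4()).replace("-", "")[:9]


print(make_tool_call_id())  # e.g. "3fa85f642"
```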


@@ -1,7 +1,6 @@
"""Chat completion utilities for OAI server."""
import asyncio
import json
import pathlib
from asyncio import CancelledError
from typing import List, Optional
@@ -30,7 +29,7 @@ from endpoints.OAI.types.chat_completion import (
)
from endpoints.OAI.types.common import UsageStats
from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector
from endpoints.OAI.utils.tools import ToolCallProcessor
from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA
def _create_response(
@@ -71,12 +70,11 @@ def _create_response(
logprob_response = ChatCompletionLogprobs(content=collected_token_probs)
# Initialize finish_reason with a default value or from generation data
finish_reason = generation.get("finish_reason", "stop")
# If a tool call is present, mark the finish reason as such
# Set finish reason
if message.tool_calls:
finish_reason = "tool_calls"
else:
finish_reason = generation.get("finish_reason", "stop")
choice = ChatCompletionRespChoice(
index=index,
@@ -153,7 +151,6 @@ def _create_stream_chunk(
choice.finish_reason = "tool_calls"
choices.append(choice)
else:
message = ChatCompletionMessage(
role="assistant", content=unwrap(generation.get("text"), "")
@@ -207,15 +204,11 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars:
if isinstance(data.stop, str):
data.stop = [data.stop] + template_metadata.stop_strings
else:
data.stop += template_metadata.stop_strings
data.stop.extend(template_metadata.stop_strings)
# Tool call start strings
if template_metadata.tool_starts:
if data.tool_call_start is None:
data.tool_call_start = template_metadata.tool_starts
# Append to stop strings to halt for a tool call generation
data.stop.extend(template_metadata.tool_starts)
# if a tool start is present, append it to stopping strings
if template_metadata.tool_start:
data.stop.append(template_metadata.tool_start)
async def format_messages_with_template(
@@ -242,13 +235,6 @@ async def format_messages_with_template(
# Convert the message content into a concatenated string
message.content = concatenated_content
if message.tool_calls:
message.tool_calls_json = ToolCallProcessor.to_json(message.tool_calls)
# The tools variable is inspectable in the template, so
# store the list of dicts rather than the ToolCallProcessor object.
message.tool_calls = ToolCallProcessor.dump(message.tool_calls)
message_dicts.append(message.model_dump(exclude_none=True))
# Get all special tokens
@@ -260,9 +246,7 @@ async def format_messages_with_template(
return prompt, mm_embeddings, template_vars
async def apply_chat_template(
data: ChatCompletionRequest, tool_precursor: Optional[str] = None
):
async def apply_chat_template(data: ChatCompletionRequest):
"""
Compile the prompt and get any additional stop strings from the template.
Template stop strings can be overriden by sampler overrides if force is true.
@@ -276,10 +260,7 @@ async def apply_chat_template(
{
"add_generation_prompt": data.add_generation_prompt,
"tools": tools,
"tools_json": json.dumps(tools, indent=2),
"functions": data.functions,
"functions_json": json.dumps(data.functions, indent=2),
"tool_precursor": tool_precursor,
}
)
@@ -337,6 +318,7 @@ async def stream_generate_chat_completion(
abort_event = asyncio.Event()
gen_queue = asyncio.Queue()
gen_tasks: List[asyncio.Task] = []
tool_start = model.container.prompt_template.metadata.tool_start
disconnect_task = asyncio.create_task(request_disconnect_loop(request))
try:
@@ -360,7 +342,7 @@ async def stream_generate_chat_completion(
gen_tasks.append(gen_task)
# We need to keep track of the text generated so we can resume the tool calls
# Text accumulation for tool calls
current_generation_text = ""
# Consumer loop
@@ -372,19 +354,23 @@ async def stream_generate_chat_completion(
)
generation = await gen_queue.get()
# lets only append the text if we need it for tool calls later
if data.tool_call_start and "text" in generation:
current_generation_text += generation["text"]
# check if we are running a tool model, and that we are at stop
if data.tool_call_start and "stop_str" in generation:
generations = await generate_tool_calls(
data,
[generation],
request,
current_generations=current_generation_text,
)
generation = generations[0] # We only have one generation in this case
# Handle options if a tool model is present
if tool_start:
if "stop_str" in generation:
generations = await generate_tool_calls(
prompt,
embeddings,
data,
[generation],
request,
current_generation_text=current_generation_text,
)
# Only one generation present in this case
generation = generations[0]
elif "text" in generation:
current_generation_text += generation["text"]
# Stream collector will push an exception to the queue if it fails
if isinstance(generation, Exception):
@@ -433,6 +419,7 @@ async def generate_chat_completion(
model_path: pathlib.Path,
):
gen_tasks: List[asyncio.Task] = []
tool_start = model.container.prompt_template.metadata.tool_start
try:
logger.info(f"Received chat completion request {request.state.id}")
@@ -453,9 +440,11 @@ async def generate_chat_completion(
generations = await asyncio.gather(*gen_tasks)
# Let's not waste our time if we arn't running a tool model
if data.tool_call_start:
generations = await generate_tool_calls(data, generations, request)
# Check all the generations and see if a tool call is required
if tool_start:
generations = await generate_tool_calls(
prompt, embeddings, data, generations, request
)
response = _create_response(request.state.id, generations, model_path.name)
@@ -474,54 +463,55 @@ async def generate_chat_completion(
async def generate_tool_calls(
prompt: str,
embeddings: MultimodalEmbeddingWrapper,
data: ChatCompletionRequest,
generations: List[str],
request: Request,
current_generations: str = None,
current_generation_text: str = None,
):
gen_tasks: List[asyncio.Task] = []
tool_start = model.container.prompt_template.metadata.tool_start
# Tracks which generations asked for a tool call
tool_idx: List[int] = []
# Copy to make sure the parent JSON schema doesn't get modified
# FIXME: May not be necessary depending on how the codebase evolves
tool_data = data.model_copy(deep=True)
tool_data.json_schema = tool_data.tool_call_schema
tool_data.json_schema = TOOL_CALL_SCHEMA
for idx, gen in enumerate(generations):
if gen["stop_str"] in tool_data.tool_call_start:
logger.info(
f"Detected tool call in chat completion request {request.state.id}"
)
if gen["stop_str"] != tool_start:
continue
if "text" in gen:
# non streaming, all generations will have the text they generated
pre_tool_prompt, embeddings = await apply_chat_template(
data, gen["text"]
)
elif current_generations is not None:
# streaming, we wont have text in the generation,
# we'll have to use the current_generations
pre_tool_prompt, embeddings = await apply_chat_template(
data, current_generations
)
logger.info(f"Detected tool call in chat completion request {request.state.id}")
request_id = _parse_gen_request_id(data.n, request.state.id, idx)
# Append the existing generation text if present
precursor_text = current_generation_text or gen.get("text")
if precursor_text:
prompt = prompt + precursor_text
gen_tasks.append(
asyncio.create_task(
model.container.generate(
request_id,
pre_tool_prompt,
tool_data,
mm_embeddings=embeddings,
)
gen_request_id = _parse_gen_request_id(data.n, request.state.id, idx)
tool_request_id = f"{gen_request_id}-tool"
gen_tasks.append(
asyncio.create_task(
model.container.generate(
tool_request_id,
prompt,
tool_data,
mm_embeddings=embeddings,
)
)
tool_idx.append(idx)
)
tool_calls = await asyncio.gather(*gen_tasks)
for outer_idx in range(0, len(tool_idx)):
gen_idx = tool_idx[outer_idx]
generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
tool_idx.append(idx)
if len(tool_idx) > 0:
tool_calls = await asyncio.gather(*gen_tasks)
# Map tool calls to their appropriate generation
for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True):
generations[gen_idx]["tool_calls"] = tool_call["text"]
return generations


@@ -5,6 +5,29 @@ from typing import List
from endpoints.OAI.types.tools import ToolCall
TOOL_CALL_SCHEMA = {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "array",
"items": {
"type": "object",
"properties": {
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
# Converted to OAI's string in post process
"type": "object"
},
},
"required": ["name", "arguments"],
},
},
"required": ["function"],
},
}
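To make the shape this schema enforces concrete, here is a hedged sketch that validates a sample tool-call array with the third-party `jsonschema` package; this is an illustration only, not a dependency introduced by the change:

```python
# Illustrative check of what TOOL_CALL_SCHEMA accepts: each item needs a
# "function" object with a string "name" and an object "arguments".
from jsonschema import validate

from endpoints.OAI.utils.tools import TOOL_CALL_SCHEMA

sample_tool_calls = [
    {
        "function": {
            "name": "get_current_weather",   # hypothetical tool
            "arguments": {"city": "Paris"},  # object here; stringified later
        }
    }
]

validate(instance=sample_tool_calls, schema=TOOL_CALL_SCHEMA)  # raises on mismatch
```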
class ToolCallProcessor:
@staticmethod
def from_json(tool_calls_str: str) -> List[ToolCall]:


@@ -1,12 +1,14 @@
{# Metadata #}
{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%}
{%- set message_roles = ['system', 'user', 'assistant', 'tool'] -%}
{%- set tool_start = "<|tool_start|>" -%}
{# Variables #}
{%- set message_roles = ['system', 'user', 'assistant', 'tool'] -%}
{%- set tool_end = "<|tool_end|>" -%}
{%- set start_header = "<|start_header_id|>" -%}
{%- set end_header = "<|end_header_id|>\n" -%}
{%- set example_tool_call = '[
{%- set example_tool_call -%}[
{
"id": "tool_id_1342",
"function": {
@@ -23,28 +25,19 @@
},
"type": "function"
}
]' -%}
]
{%- endset -%}
{%- set inital_system_prompt = 'You are an assistant that has access to the following set of tools, to call a tool:
1. Prefix calls with ' + tool_start + ' and end calls with ' + tool_end + '
{%- set inital_system_prompt -%}You are an assistant that has access to the following set of tools, to call a tool:
1. Prefix calls with '{{ tool_start }}' and end calls with '{{ tool_end }}'
2. Ensure you use the correct type for arguments. For example, if the argument is a string, ensure it is enclosed in quotes, otherwise, it should not be.
3. Generate all calls using the following json tool call format. Here is a multi tool call example:
' + tool_start + example_tool_call + tool_end + '
{{ tool_start }}{{ example_tool_call }}{{ tool_end }}
Here are the tools available for you to call:
' + tools_json -%}
{%- set tool_reminder = 'Available Tools:
' + tools_json + '
Tool Call Format Example:
' + tool_start + example_tool_call + '
Prefix & Suffix: Begin tool calls with ' + tool_start + ' and end with ' + tool_end + '.
Argument Types: Use correct data types for arguments (e.g., strings in quotes, numbers without).' -%}
{# Template #}
{{ tools | tojson(indent=2) }}
{%- endset -%}
{%- for message in messages -%}
{%- set role = message['role'] | lower -%}
@@ -54,28 +47,23 @@ Argument Types: Use correct data types for arguments (e.g., strings in quotes, n
{%- set content = message['content'] if message['content'] is defined else '' | trim -%}
{%- if loop.first -%}
{{ bos_token }}{{ start_header }}{{ role }}{{ end_header }}
{{ inital_system_prompt }}
{{- bos_token }}{{ start_header }}{{ role }}{{ end_header }}
{{- inital_system_prompt + "\n\n" }}
{{ content }}{{ eos_token }}
{{- content }}{{ eos_token }}
{%- endif -%}
{%- if not loop.first -%}
{{ start_header }}{{ role }}{{ end_header }}
{{ content }}
{%- if 'tool_calls_json' in message and message['tool_calls_json'] -%}
{{ tool_start }}{{ message['tool_calls_json']}}{{ tool_end }}
{{- start_header }}{{ role }}{{ end_header }}
{{- content }}
{%- if 'tool_calls' in message and message['tool_calls'] -%}
{{- tool_start }}{{ message['tool_calls'] | tojson(indent=2) }}{{ tool_end }}
{%- endif -%}
{{ eos_token }}
{{- eos_token }}
{%- endif -%}
{%- endfor -%}
{%- if tool_precursor -%}
{{ start_header }}system{{ end_header }}
{{ tool_reminder }}{{ eos_token }}
{{ start_header }}assistant{{ end_header }}
{{ tool_precursor }}{{ tool_start }}
{%- else -%}
{{ start_header }}assistant{{ end_header }}
{%- endif -%}
{%- if add_generation_prompt %}
{{- start_header }}assistant{{ end_header }}
{%- endif %}