API: Modify tool calling for wider compat

When revisiting tool calls, the formats have more or less become standard.
For greater compatibility with templates, primarily use the message.tools
parameter and remove the extra custom metadata that is no longer required.

However, unlike other backends, tabbyAPI still uses template metadata
to declare what the tool start string is. This allows for template-level
customization along with giving more power to the user while the server
exists to consume rather than work on a case-by-case basis.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-07-05 14:28:12 -04:00
parent b6a26da50c
commit 879f4cee7e
5 changed files with 89 additions and 103 deletions

View file

@ -4,6 +4,8 @@ import traceback
import aiofiles import aiofiles
import json import json
import pathlib import pathlib
from dataclasses import dataclass, field
from datetime import datetime
from importlib.metadata import version as package_version from importlib.metadata import version as package_version
from typing import List, Optional from typing import List, Optional
from jinja2 import Template, TemplateError from jinja2 import Template, TemplateError
@ -11,7 +13,6 @@ from jinja2.ext import loopcontrols
from jinja2.sandbox import ImmutableSandboxedEnvironment from jinja2.sandbox import ImmutableSandboxedEnvironment
from loguru import logger from loguru import logger
from packaging import version from packaging import version
from datetime import datetime
from common.utils import unwrap from common.utils import unwrap
@ -23,11 +24,12 @@ class TemplateLoadError(Exception):
pass pass
@dataclass
class TemplateMetadata: class TemplateMetadata:
"""Represents the parsed metadata from a template.""" """Represents the parsed metadata from a template."""
stop_strings: List[str] = [] stop_strings: List[str] = field(default_factory=list)
tool_starts: List[str] = [] tool_start: Optional[str] = None
class PromptTemplate: class PromptTemplate:
@ -72,11 +74,7 @@ class PromptTemplate:
if hasattr(template_module, "tool_start"): if hasattr(template_module, "tool_start"):
if isinstance(template_module.tool_start, str): if isinstance(template_module.tool_start, str):
template_metadata.tool_starts.append(template_module.tool_start) template_metadata.tool_start = template_module.tool_start
if hasattr(template_module, "tool_start_token"):
if isinstance(template_module.tool_start_token, int):
template_metadata.tool_starts.append(template_module.tool_start_token)
self.metadata = template_metadata self.metadata = template_metadata
return template_metadata return template_metadata

View file

@ -5,7 +5,7 @@ from typing import Literal, Union, List, Optional, Dict
from uuid import uuid4 from uuid import uuid4
from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema from endpoints.OAI.types.tools import ToolSpec, ToolCall
class ChatCompletionLogprob(BaseModel): class ChatCompletionLogprob(BaseModel):
@ -73,15 +73,6 @@ class ChatCompletionRequest(CommonCompletionRequest):
tools: Optional[List[ToolSpec]] = None tools: Optional[List[ToolSpec]] = None
functions: Optional[List[Dict]] = None functions: Optional[List[Dict]] = None
# Typically collected from Chat Template.
# Don't include this in the OpenAPI docs
# TODO: Use these custom parameters
tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = Field(
default_factory=list
)
tool_call_end: SkipJsonSchema[Optional[str]] = None
tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
# Chat completions requests do not have a BOS token preference. Backend # Chat completions requests do not have a BOS token preference. Backend
# respects the tokenization config for the individual model. # respects the tokenization config for the individual model.
add_bos_token: Optional[bool] = None add_bos_token: Optional[bool] = None

View file

@ -1,30 +1,6 @@
from pydantic import BaseModel from pydantic import BaseModel
from typing import Dict, Literal from typing import Dict, Literal
tool_call_schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {"type": "string"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
# Converted to OAI's string in post process
"type": "object"
},
},
"required": ["name", "arguments"],
},
"type": {"type": "string", "enum": ["function"]},
},
"required": ["id", "function", "type"],
},
}
class Function(BaseModel): class Function(BaseModel):
"""Represents a description of a tool function.""" """Represents a description of a tool function."""

View file

@ -1,7 +1,6 @@
"""Chat completion utilities for OAI server.""" """Chat completion utilities for OAI server."""
import asyncio import asyncio
import json
import pathlib import pathlib
from asyncio import CancelledError from asyncio import CancelledError
from typing import List, Optional from typing import List, Optional
@ -30,7 +29,7 @@ from endpoints.OAI.types.chat_completion import (
) )
from endpoints.OAI.types.common import UsageStats from endpoints.OAI.types.common import UsageStats
from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector
from endpoints.OAI.utils.tools import ToolCallProcessor from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA
def _create_response( def _create_response(
@ -209,12 +208,9 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars:
else: else:
data.stop.extend(template_metadata.stop_strings) data.stop.extend(template_metadata.stop_strings)
# Tool call start strings # if a tool start is present, append it to stopping strings
if template_metadata.tool_starts: if template_metadata.tool_start:
data.tool_call_start.extend(template_metadata.tool_starts) data.stop.append(template_metadata.tool_start)
# Append to stop strings to halt for a tool call generation
data.stop.extend(template_metadata.tool_starts)
async def format_messages_with_template( async def format_messages_with_template(
@ -255,9 +251,7 @@ async def format_messages_with_template(
return prompt, mm_embeddings, template_vars return prompt, mm_embeddings, template_vars
async def apply_chat_template( async def apply_chat_template(data: ChatCompletionRequest):
data: ChatCompletionRequest, tool_precursor: Optional[str] = None
):
""" """
Compile the prompt and get any additional stop strings from the template. Compile the prompt and get any additional stop strings from the template.
Template stop strings can be overriden by sampler overrides if force is true. Template stop strings can be overriden by sampler overrides if force is true.
@ -271,10 +265,7 @@ async def apply_chat_template(
{ {
"add_generation_prompt": data.add_generation_prompt, "add_generation_prompt": data.add_generation_prompt,
"tools": tools, "tools": tools,
"tools_json": json.dumps(tools, indent=2),
"functions": data.functions, "functions": data.functions,
"functions_json": json.dumps(data.functions, indent=2),
"tool_precursor": tool_precursor,
} }
) )
@ -332,6 +323,7 @@ async def stream_generate_chat_completion(
abort_event = asyncio.Event() abort_event = asyncio.Event()
gen_queue = asyncio.Queue() gen_queue = asyncio.Queue()
gen_tasks: List[asyncio.Task] = [] gen_tasks: List[asyncio.Task] = []
tool_start = model.container.prompt_template.metadata.tool_start
disconnect_task = asyncio.create_task(request_disconnect_loop(request)) disconnect_task = asyncio.create_task(request_disconnect_loop(request))
try: try:
@ -355,7 +347,7 @@ async def stream_generate_chat_completion(
gen_tasks.append(gen_task) gen_tasks.append(gen_task)
# We need to keep track of the text generated so we can resume the tool calls # Text accumulation for tool calls
current_generation_text = "" current_generation_text = ""
# Consumer loop # Consumer loop
@ -367,19 +359,21 @@ async def stream_generate_chat_completion(
) )
generation = await gen_queue.get() generation = await gen_queue.get()
# lets only append the text if we need it for tool calls later
if data.tool_call_start and "text" in generation:
current_generation_text += generation["text"]
# check if we are running a tool model, and that we are at stop # Handle options if a tool model is present
if data.tool_call_start and "stop_str" in generation: if tool_start:
generations = await generate_tool_calls( if "stop_str" in generation:
data, generations = await generate_tool_calls(
[generation], data,
request, [generation],
current_generations=current_generation_text, request,
) current_generation_text=current_generation_text,
generation = generations[0] # We only have one generation in this case )
# Only one generation present in this case
generation = generations[0]
elif "text" in generation:
current_generation_text += generation["text"]
# Stream collector will push an exception to the queue if it fails # Stream collector will push an exception to the queue if it fails
if isinstance(generation, Exception): if isinstance(generation, Exception):
@ -428,6 +422,7 @@ async def generate_chat_completion(
model_path: pathlib.Path, model_path: pathlib.Path,
): ):
gen_tasks: List[asyncio.Task] = [] gen_tasks: List[asyncio.Task] = []
tool_start = model.container.prompt_template.metadata.tool_start
try: try:
logger.info(f"Received chat completion request {request.state.id}") logger.info(f"Received chat completion request {request.state.id}")
@ -448,8 +443,8 @@ async def generate_chat_completion(
generations = await asyncio.gather(*gen_tasks) generations = await asyncio.gather(*gen_tasks)
# Let's not waste our time if we arn't running a tool model # Check all the generations and see if a tool call is required
if data.tool_call_start: if tool_start:
generations = await generate_tool_calls(data, generations, request) generations = await generate_tool_calls(data, generations, request)
response = _create_response(request.state.id, generations, model_path.name) response = _create_response(request.state.id, generations, model_path.name)
@ -472,51 +467,52 @@ async def generate_tool_calls(
data: ChatCompletionRequest, data: ChatCompletionRequest,
generations: List[str], generations: List[str],
request: Request, request: Request,
current_generations: str = None, current_generation_text: str = None,
): ):
gen_tasks: List[asyncio.Task] = [] gen_tasks: List[asyncio.Task] = []
tool_start = model.container.prompt_template.metadata.tool_start
# Tracks which generations asked for a tool call
tool_idx: List[int] = [] tool_idx: List[int] = []
# Copy to make sure the parent JSON schema doesn't get modified # Copy to make sure the parent JSON schema doesn't get modified
# FIXME: May not be necessary depending on how the codebase evolves
tool_data = data.model_copy(deep=True) tool_data = data.model_copy(deep=True)
tool_data.json_schema = tool_data.tool_call_schema tool_data.json_schema = TOOL_CALL_SCHEMA
for idx, gen in enumerate(generations): for idx, gen in enumerate(generations):
if gen["stop_str"] in tool_data.tool_call_start: if gen["stop_str"] != tool_start:
logger.info( continue
f"Detected tool call in chat completion request {request.state.id}"
)
if "text" in gen: logger.info(f"Detected tool call in chat completion request {request.state.id}")
# non streaming, all generations will have the text they generated
pre_tool_prompt, embeddings = await apply_chat_template(
data, gen["text"]
)
elif current_generations is not None:
# streaming, we wont have text in the generation,
# we'll have to use the current_generations
pre_tool_prompt, embeddings = await apply_chat_template(
data, current_generations
)
request_id = _parse_gen_request_id(data.n, request.state.id, idx) # Append the existing generation as part of the response prefix
precursor_text = current_generation_text or gen.get("text")
if precursor_text:
tool_data.response_prefix = precursor_text
gen_tasks.append( pre_tool_prompt, embeddings = await apply_chat_template(tool_data)
asyncio.create_task(
model.container.generate( gen_request_id = _parse_gen_request_id(data.n, request.state.id, idx)
request_id, tool_request_id = f"{gen_request_id}-tool"
pre_tool_prompt,
tool_data, gen_tasks.append(
mm_embeddings=embeddings, asyncio.create_task(
) model.container.generate(
tool_request_id,
pre_tool_prompt,
tool_data,
mm_embeddings=embeddings,
) )
) )
tool_idx.append(idx) )
tool_calls = await asyncio.gather(*gen_tasks) tool_idx.append(idx)
for outer_idx in range(0, len(tool_idx)):
gen_idx = tool_idx[outer_idx] if len(tool_idx) > 0:
generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"] tool_calls = await asyncio.gather(*gen_tasks)
# Map tool calls to their appropriate generation
for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True):
generations[gen_idx]["tool_calls"] = tool_call["text"]
return generations return generations

View file

@ -5,6 +5,31 @@ from typing import List
from endpoints.OAI.types.tools import ToolCall from endpoints.OAI.types.tools import ToolCall
TOOL_CALL_SCHEMA = {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {"type": "string"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
# Converted to OAI's string in post process
"type": "object"
},
},
"required": ["name", "arguments"],
},
"type": {"type": "string", "enum": ["function"]},
},
"required": ["id", "function", "type"],
},
}
class ToolCallProcessor: class ToolCallProcessor:
@staticmethod @staticmethod
def from_json(tool_calls_str: str) -> List[ToolCall]: def from_json(tool_calls_str: str) -> List[ToolCall]: