API: Modify tool calling for wider compat

Tool call formats have more or less become standard across backends. For
greater compatibility with chat templates, rely primarily on the standard
tools parameter of the request and remove the extra custom metadata that
is no longer required.
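As an illustration, a request can now carry tool definitions through the
standard tools parameter alone. A minimal sketch using the OpenAI Python
client; the base URL, API key, and model name are placeholders for a
local instance, not values from this commit:

    from openai import OpenAI

    # Placeholder endpoint and credentials for a local instance
    client = OpenAI(base_url="http://localhost:5000/v1", api_key="placeholder")

    response = client.chat.completions.create(
        model="placeholder-model",
        messages=[{"role": "user", "content": "What's the weather in Paris?"}],
        # Standard OpenAI-style tool spec; no custom tabbyAPI fields needed
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "description": "Get the current weather for a city",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
            }
        ],
    )
    print(response.choices[0].message.tool_calls)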

However, unlike other backends, tabbyAPI still uses template metadata to
declare the tool start string. This keeps customization at the template
level and gives the user more power, while the server simply consumes the
template rather than special-casing each model.
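A rough sketch of that mechanism, not tabbyAPI's actual loader: module-level
{% set %} variables in a Jinja chat template become attributes of the
rendered template module, which the server can then inspect. The template
string and the tool_start value below are illustrative assumptions:

    from jinja2 import Environment

    TEMPLATE = (
        '{% set tool_start = "<tool_call>" %}'
        "{% for message in messages %}{{ message.content }}{% endfor %}"
    )

    template_module = Environment().from_string(TEMPLATE).make_module({"messages": []})

    # Mirrors the hasattr/isinstance checks in PromptTemplate below
    if hasattr(template_module, "tool_start"):
        if isinstance(template_module.tool_start, str):
            print(template_module.tool_start)  # -> <tool_call>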

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-07-05 14:28:12 -04:00
parent b6a26da50c
commit 879f4cee7e
5 changed files with 89 additions and 103 deletions

common/templating.py

@@ -4,6 +4,8 @@ import traceback
 import aiofiles
 import json
 import pathlib
+from dataclasses import dataclass, field
+from datetime import datetime
 from importlib.metadata import version as package_version
 from typing import List, Optional
 from jinja2 import Template, TemplateError
@@ -11,7 +13,6 @@ from jinja2.ext import loopcontrols
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger
 from packaging import version
-from datetime import datetime
 from common.utils import unwrap
@@ -23,11 +24,12 @@ class TemplateLoadError(Exception):
     pass


+@dataclass
 class TemplateMetadata:
     """Represents the parsed metadata from a template."""

-    stop_strings: List[str] = []
-    tool_starts: List[str] = []
+    stop_strings: List[str] = field(default_factory=list)
+    tool_start: Optional[str] = None


 class PromptTemplate:
@@ -72,11 +74,7 @@ class PromptTemplate:
         if hasattr(template_module, "tool_start"):
             if isinstance(template_module.tool_start, str):
-                template_metadata.tool_starts.append(template_module.tool_start)
-
-        if hasattr(template_module, "tool_start_token"):
-            if isinstance(template_module.tool_start_token, int):
-                template_metadata.tool_starts.append(template_module.tool_start_token)
+                template_metadata.tool_start = template_module.tool_start

         self.metadata = template_metadata
         return template_metadata
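Side note on the dataclass change above: the dataclasses module rejects bare
mutable defaults, so the list default has to go through field(default_factory=list).
A quick self-contained illustration:

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class TemplateMetadata:
        # A bare `stop_strings: List[str] = []` would raise ValueError
        # ("mutable default ... use default_factory") at class creation
        stop_strings: List[str] = field(default_factory=list)
        tool_start: Optional[str] = None

    meta = TemplateMetadata()
    meta.stop_strings.append("</s>")  # each instance gets its own list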

endpoints/OAI/types/chat_completion.py

@@ -5,7 +5,7 @@ from typing import Literal, Union, List, Optional, Dict
 from uuid import uuid4

 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
-from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
+from endpoints.OAI.types.tools import ToolSpec, ToolCall


 class ChatCompletionLogprob(BaseModel):
@@ -73,15 +73,6 @@ class ChatCompletionRequest(CommonCompletionRequest):
     tools: Optional[List[ToolSpec]] = None
     functions: Optional[List[Dict]] = None

-    # Typically collected from Chat Template.
-    # Don't include this in the OpenAPI docs
-    # TODO: Use these custom parameters
-    tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = Field(
-        default_factory=list
-    )
-    tool_call_end: SkipJsonSchema[Optional[str]] = None
-    tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
-
     # Chat completions requests do not have a BOS token preference. Backend
     # respects the tokenization config for the individual model.
     add_bos_token: Optional[bool] = None

endpoints/OAI/types/tools.py

@@ -1,30 +1,6 @@
 from pydantic import BaseModel
 from typing import Dict, Literal

-tool_call_schema = {
-    "$schema": "http://json-schema.org/draft-07/schema#",
-    "type": "array",
-    "items": {
-        "type": "object",
-        "properties": {
-            "id": {"type": "string"},
-            "function": {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string"},
-                    "arguments": {
-                        # Converted to OAI's string in post process
-                        "type": "object"
-                    },
-                },
-                "required": ["name", "arguments"],
-            },
-            "type": {"type": "string", "enum": ["function"]},
-        },
-        "required": ["id", "function", "type"],
-    },
-}
-

 class Function(BaseModel):
     """Represents a description of a tool function."""

endpoints/OAI/utils/chat_completion.py

@@ -1,7 +1,6 @@
 """Chat completion utilities for OAI server."""

 import asyncio
-import json
 import pathlib
 from asyncio import CancelledError
 from typing import List, Optional
@@ -30,7 +29,7 @@ from endpoints.OAI.types.chat_completion import (
 )
 from endpoints.OAI.types.common import UsageStats
 from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector
-from endpoints.OAI.utils.tools import ToolCallProcessor
+from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA


 def _create_response(
@@ -209,12 +208,9 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars:
     else:
         data.stop.extend(template_metadata.stop_strings)

-    # Tool call start strings
-    if template_metadata.tool_starts:
-        data.tool_call_start.extend(template_metadata.tool_starts)
-
-        # Append to stop strings to halt for a tool call generation
-        data.stop.extend(template_metadata.tool_starts)
+    # if a tool start is present, append it to stopping strings
+    if template_metadata.tool_start:
+        data.stop.append(template_metadata.tool_start)


 async def format_messages_with_template(
@@ -255,9 +251,7 @@ async def format_messages_with_template(
     return prompt, mm_embeddings, template_vars


-async def apply_chat_template(
-    data: ChatCompletionRequest, tool_precursor: Optional[str] = None
-):
+async def apply_chat_template(data: ChatCompletionRequest):
     """
     Compile the prompt and get any additional stop strings from the template.
     Template stop strings can be overridden by sampler overrides if force is true.
@@ -271,10 +265,7 @@ async def apply_chat_template(
         {
             "add_generation_prompt": data.add_generation_prompt,
             "tools": tools,
-            "tools_json": json.dumps(tools, indent=2),
             "functions": data.functions,
-            "functions_json": json.dumps(data.functions, indent=2),
-            "tool_precursor": tool_precursor,
         }
     )
@@ -332,6 +323,7 @@ async def stream_generate_chat_completion(
     abort_event = asyncio.Event()
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
+    tool_start = model.container.prompt_template.metadata.tool_start

     disconnect_task = asyncio.create_task(request_disconnect_loop(request))

     try:
@@ -355,7 +347,7 @@
             gen_tasks.append(gen_task)

-        # We need to keep track of the text generated so we can resume the tool calls
+        # Text accumulation for tool calls
        current_generation_text = ""

         # Consumer loop
@@ -367,19 +359,21 @@
             )

             generation = await gen_queue.get()

-            # lets only append the text if we need it for tool calls later
-            if data.tool_call_start and "text" in generation:
-                current_generation_text += generation["text"]
-
-            # check if we are running a tool model, and that we are at stop
-            if data.tool_call_start and "stop_str" in generation:
-                generations = await generate_tool_calls(
-                    data,
-                    [generation],
-                    request,
-                    current_generations=current_generation_text,
-                )
-                generation = generations[0]  # We only have one generation in this case
+            # Handle options if a tool model is present
+            if tool_start:
+                if "stop_str" in generation:
+                    generations = await generate_tool_calls(
+                        data,
+                        [generation],
+                        request,
+                        current_generation_text=current_generation_text,
+                    )
+
+                    # Only one generation present in this case
+                    generation = generations[0]
+                elif "text" in generation:
+                    current_generation_text += generation["text"]

             # Stream collector will push an exception to the queue if it fails
             if isinstance(generation, Exception):
@@ -428,6 +422,7 @@ async def generate_chat_completion(
     model_path: pathlib.Path,
 ):
     gen_tasks: List[asyncio.Task] = []
+    tool_start = model.container.prompt_template.metadata.tool_start

     try:
         logger.info(f"Received chat completion request {request.state.id}")
@@ -448,8 +443,8 @@
         generations = await asyncio.gather(*gen_tasks)

-        # Let's not waste our time if we arn't running a tool model
-        if data.tool_call_start:
+        # Check all the generations and see if a tool call is required
+        if tool_start:
             generations = await generate_tool_calls(data, generations, request)

         response = _create_response(request.state.id, generations, model_path.name)
@@ -472,51 +467,52 @@ async def generate_tool_calls(
     data: ChatCompletionRequest,
     generations: List[str],
     request: Request,
-    current_generations: str = None,
+    current_generation_text: str = None,
 ):
     gen_tasks: List[asyncio.Task] = []
+    tool_start = model.container.prompt_template.metadata.tool_start

     # Tracks which generations asked for a tool call
     tool_idx: List[int] = []

     # Copy to make sure the parent JSON schema doesn't get modified
     # FIXME: May not be necessary depending on how the codebase evolves
     tool_data = data.model_copy(deep=True)
-    tool_data.json_schema = tool_data.tool_call_schema
+    tool_data.json_schema = TOOL_CALL_SCHEMA

     for idx, gen in enumerate(generations):
-        if gen["stop_str"] in tool_data.tool_call_start:
-            logger.info(
-                f"Detected tool call in chat completion request {request.state.id}"
-            )
+        if gen["stop_str"] != tool_start:
+            continue

-            if "text" in gen:
-                # non streaming, all generations will have the text they generated
-                pre_tool_prompt, embeddings = await apply_chat_template(
-                    data, gen["text"]
-                )
-            elif current_generations is not None:
-                # streaming, we wont have text in the generation,
-                # we'll have to use the current_generations
-                pre_tool_prompt, embeddings = await apply_chat_template(
-                    data, current_generations
-                )
+        logger.info(f"Detected tool call in chat completion request {request.state.id}")

-            request_id = _parse_gen_request_id(data.n, request.state.id, idx)
+        # Append the existing generation as part of the response prefix
+        precursor_text = current_generation_text or gen.get("text")
+        if precursor_text:
+            tool_data.response_prefix = precursor_text

-            gen_tasks.append(
-                asyncio.create_task(
-                    model.container.generate(
-                        request_id,
-                        pre_tool_prompt,
-                        tool_data,
-                        mm_embeddings=embeddings,
-                    )
-                )
-            )
-            tool_idx.append(idx)
+        pre_tool_prompt, embeddings = await apply_chat_template(tool_data)

-    tool_calls = await asyncio.gather(*gen_tasks)
-    for outer_idx in range(0, len(tool_idx)):
-        gen_idx = tool_idx[outer_idx]
-        generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
+        gen_request_id = _parse_gen_request_id(data.n, request.state.id, idx)
+        tool_request_id = f"{gen_request_id}-tool"
+
+        gen_tasks.append(
+            asyncio.create_task(
+                model.container.generate(
+                    tool_request_id,
+                    pre_tool_prompt,
+                    tool_data,
+                    mm_embeddings=embeddings,
+                )
+            )
+        )
+
+        tool_idx.append(idx)
+
+    if len(tool_idx) > 0:
+        tool_calls = await asyncio.gather(*gen_tasks)
+
+        # Map tool calls to their appropriate generation
+        for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True):
+            generations[gen_idx]["tool_calls"] = tool_call["text"]

     return generations
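Condensed, the dispatch rule above is: a tool-call pass only runs for
generations that stopped on the template's tool_start string. A
self-contained sketch with illustrative stand-in values, not the server's
actual data:

    from typing import List, Optional

    def needs_tool_call(gen: dict, tool_start: Optional[str]) -> bool:
        # Mirrors `if gen["stop_str"] != tool_start: continue` above
        return tool_start is not None and gen.get("stop_str") == tool_start

    generations: List[dict] = [
        {"text": "Let me look that up.", "stop_str": "<tool_call>"},
        {"text": "The answer is 4.", "stop_str": "<|im_end|>"},
    ]

    tool_idx = [
        idx for idx, gen in enumerate(generations)
        if needs_tool_call(gen, "<tool_call>")
    ]
    print(tool_idx)  # -> [0]: only the first generation needs a tool-call pass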

endpoints/OAI/utils/tools.py

@@ -5,6 +5,31 @@ from typing import List
 from endpoints.OAI.types.tools import ToolCall

+TOOL_CALL_SCHEMA = {
+    "$schema": "http://json-schema.org/draft-07/schema#",
+    "type": "array",
+    "items": {
+        "type": "object",
+        "properties": {
+            "id": {"type": "string"},
+            "function": {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "arguments": {
+                        # Converted to OAI's string in post process
+                        "type": "object"
+                    },
+                },
+                "required": ["name", "arguments"],
+            },
+            "type": {"type": "string", "enum": ["function"]},
+        },
+        "required": ["id", "function", "type"],
+    },
+}
+

 class ToolCallProcessor:
     @staticmethod
     def from_json(tool_calls_str: str) -> List[ToolCall]:
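A hedged usage sketch for the schema above: a generation constrained by
TOOL_CALL_SCHEMA should emit a JSON array that from_json can turn into
ToolCall models. The raw string below is illustrative output rather than a
captured model response, and this assumes from_json performs the
object-to-string argument conversion noted in the schema comment:

    from endpoints.OAI.utils.tools import ToolCallProcessor

    # Illustrative, schema-shaped output; not a real model response
    raw = '''[
      {
        "id": "call-0",
        "type": "function",
        "function": {"name": "get_weather", "arguments": {"city": "Paris"}}
      }
    ]'''

    tool_calls = ToolCallProcessor.from_json(raw)
    print(tool_calls[0].function.name)  # -> get_weather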