API: Modify tool calling for wider compat
Tool call formats have more or less become standard across backends. For greater compatibility with existing templates, primarily use the message.tools parameter and remove the extra custom metadata that is no longer required. Unlike other backends, however, tabbyAPI still uses template metadata to declare the tool call start string. This keeps customization at the template level and gives more power to the user, while the server simply consumes what the template declares rather than special-casing each model.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent b6a26da50c
commit 879f4cee7e

5 changed files with 89 additions and 103 deletions
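For context, a minimal sketch (not part of this commit) of the template-metadata mechanism the message describes: a Jinja2 chat template declares its tool call start string as a top-level variable, and the server reads it back through make_module(), which exposes top-level assignments as module attributes. The template string below is illustrative only.

    from jinja2.sandbox import ImmutableSandboxedEnvironment

    env = ImmutableSandboxedEnvironment()
    template = env.from_string(
        '{% set tool_start = "<tool_call>" %}{{ messages }}'
    )
    module = template.make_module({"messages": ""})

    # Top-level {% set %} values become attributes on the rendered module
    if hasattr(module, "tool_start") and isinstance(module.tool_start, str):
        print(module.tool_start)  # -> <tool_call>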
@@ -4,6 +4,8 @@ import traceback
 import aiofiles
+import json
 import pathlib
+from dataclasses import dataclass, field
 from datetime import datetime
 from importlib.metadata import version as package_version
 from typing import List, Optional
 from jinja2 import Template, TemplateError
@@ -11,7 +13,6 @@ from jinja2.ext import loopcontrols
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger
 from packaging import version
-from datetime import datetime
 
 
 from common.utils import unwrap
@@ -23,11 +24,12 @@ class TemplateLoadError(Exception):
     pass
 
 
+@dataclass
 class TemplateMetadata:
     """Represents the parsed metadata from a template."""
 
-    stop_strings: List[str] = []
-    tool_starts: List[str] = []
+    stop_strings: List[str] = field(default_factory=list)
+    tool_start: Optional[str] = None
 
 
 class PromptTemplate:
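The switch to field(default_factory=list) matters beyond style: a bare mutable default on a dataclass field is rejected by Python outright, and as a plain class attribute it would be shared across instances. A standalone sketch of the fixed pattern:

    from dataclasses import dataclass, field
    from typing import List, Optional

    @dataclass
    class TemplateMetadata:
        stop_strings: List[str] = field(default_factory=list)
        tool_start: Optional[str] = None

    a = TemplateMetadata()
    b = TemplateMetadata()
    a.stop_strings.append("</s>")
    assert b.stop_strings == []  # each instance gets its own list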
@@ -72,11 +74,7 @@ class PromptTemplate:
 
         if hasattr(template_module, "tool_start"):
             if isinstance(template_module.tool_start, str):
-                template_metadata.tool_starts.append(template_module.tool_start)
-
-        if hasattr(template_module, "tool_start_token"):
-            if isinstance(template_module.tool_start_token, int):
-                template_metadata.tool_starts.append(template_module.tool_start_token)
+                template_metadata.tool_start = template_module.tool_start
 
         self.metadata = template_metadata
         return template_metadata
@@ -5,7 +5,7 @@ from typing import Literal, Union, List, Optional, Dict
 from uuid import uuid4
 
 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
-from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
+from endpoints.OAI.types.tools import ToolSpec, ToolCall
 
 
 class ChatCompletionLogprob(BaseModel):
@@ -73,15 +73,6 @@ class ChatCompletionRequest(CommonCompletionRequest):
     tools: Optional[List[ToolSpec]] = None
     functions: Optional[List[Dict]] = None
 
-    # Typically collected from Chat Template.
-    # Don't include this in the OpenAPI docs
-    # TODO: Use these custom parameters
-    tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = Field(
-        default_factory=list
-    )
-    tool_call_end: SkipJsonSchema[Optional[str]] = None
-    tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
-
     # Chat completions requests do not have a BOS token preference. Backend
     # respects the tokenization config for the individual model.
     add_bos_token: Optional[bool] = None
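With the per-request tool_call_start/tool_call_end/tool_call_schema fields gone, a client only supplies the standard tools list. A hypothetical request against the server (the URL, port, and tool definition are assumptions based on common OAI-style defaults, not part of this diff; auth headers omitted):

    import json
    import urllib.request

    payload = {
        "model": "my-model",  # hypothetical model name
        "messages": [{"role": "user", "content": "Weather in Paris?"}],
        "tools": [
            {
                "type": "function",
                "function": {
                    "name": "get_weather",  # hypothetical tool
                    "description": "Look up current weather for a city",
                    "parameters": {
                        "type": "object",
                        "properties": {"city": {"type": "string"}},
                        "required": ["city"],
                    },
                },
            }
        ],
    }
    req = urllib.request.Request(
        "http://localhost:5000/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    # response = urllib.request.urlopen(req)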
@@ -1,30 +1,6 @@
 from pydantic import BaseModel
 from typing import Dict, Literal
 
-tool_call_schema = {
-    "$schema": "http://json-schema.org/draft-07/schema#",
-    "type": "array",
-    "items": {
-        "type": "object",
-        "properties": {
-            "id": {"type": "string"},
-            "function": {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string"},
-                    "arguments": {
-                        # Converted to OAI's string in post process
-                        "type": "object"
-                    },
-                },
-                "required": ["name", "arguments"],
-            },
-            "type": {"type": "string", "enum": ["function"]},
-        },
-        "required": ["id", "function", "type"],
-    },
-}
-
 
 class Function(BaseModel):
     """Represents a description of a tool function."""
@@ -1,7 +1,6 @@
 """Chat completion utilities for OAI server."""
-
 import asyncio
 import json
 import pathlib
 from asyncio import CancelledError
 from typing import List, Optional
@@ -30,7 +29,7 @@ from endpoints.OAI.types.chat_completion import (
 )
 from endpoints.OAI.types.common import UsageStats
 from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector
-from endpoints.OAI.utils.tools import ToolCallProcessor
+from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA
 
 
 def _create_response(
@@ -209,12 +208,9 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars:
     else:
         data.stop.extend(template_metadata.stop_strings)
 
-    # Tool call start strings
-    if template_metadata.tool_starts:
-        data.tool_call_start.extend(template_metadata.tool_starts)
-
-        # Append to stop strings to halt for a tool call generation
-        data.stop.extend(template_metadata.tool_starts)
+    # if a tool start is present, append it to stopping strings
+    if template_metadata.tool_start:
+        data.stop.append(template_metadata.tool_start)
 
 
 async def format_messages_with_template(
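The effect, reduced to a standalone sketch (values illustrative): the template's stop strings extend the request's stop list, and a declared tool start string becomes one more stop string, so generation halts the moment the model begins emitting a tool call.

    stop = ["</s>"]                  # from the request
    template_stops = ["<|im_end|>"]  # template_metadata.stop_strings
    tool_start = "<tool_call>"       # template_metadata.tool_start

    stop.extend(template_stops)
    if tool_start:
        stop.append(tool_start)

    print(stop)  # ['</s>', '<|im_end|>', '<tool_call>']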
@@ -255,9 +251,7 @@ async def format_messages_with_template(
     return prompt, mm_embeddings, template_vars
 
 
-async def apply_chat_template(
-    data: ChatCompletionRequest, tool_precursor: Optional[str] = None
-):
+async def apply_chat_template(data: ChatCompletionRequest):
     """
     Compile the prompt and get any additional stop strings from the template.
     Template stop strings can be overriden by sampler overrides if force is true.
@@ -271,10 +265,7 @@ async def apply_chat_template(
         {
             "add_generation_prompt": data.add_generation_prompt,
             "tools": tools,
             "tools_json": json.dumps(tools, indent=2),
-            "functions": data.functions,
-            "functions_json": json.dumps(data.functions, indent=2),
-            "tool_precursor": tool_precursor,
         }
     )
 
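With tool_precursor gone from the template variables, the already-generated text instead rides along on a deep copy of the request as a response prefix (see the generate_tool_calls hunk below). A reduced sketch with a stand-in pydantic model:

    from typing import Optional
    from pydantic import BaseModel

    class Req(BaseModel):  # stand-in for ChatCompletionRequest
        response_prefix: Optional[str] = None

    data = Req()
    tool_data = data.model_copy(deep=True)
    tool_data.response_prefix = "The weather is "  # text generated so far
    assert data.response_prefix is None  # the parent request is untouched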
@@ -332,6 +323,7 @@ async def stream_generate_chat_completion(
     abort_event = asyncio.Event()
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
+    tool_start = model.container.prompt_template.metadata.tool_start
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
 
     try:
@@ -355,7 +347,7 @@ async def stream_generate_chat_completion(
 
         gen_tasks.append(gen_task)
 
-        # We need to keep track of the text generated so we can resume the tool calls
+        # Text accumulation for tool calls
         current_generation_text = ""
 
         # Consumer loop
@@ -367,19 +359,21 @@ async def stream_generate_chat_completion(
                 )
 
             generation = await gen_queue.get()
-            # lets only append the text if we need it for tool calls later
-            if data.tool_call_start and "text" in generation:
-                current_generation_text += generation["text"]
 
-            # check if we are running a tool model, and that we are at stop
-            if data.tool_call_start and "stop_str" in generation:
+            # Handle options if a tool model is present
+            if tool_start:
+                if "stop_str" in generation:
                     generations = await generate_tool_calls(
                         data,
                         [generation],
                         request,
-                        current_generations=current_generation_text,
+                        current_generation_text=current_generation_text,
                     )
-                    generation = generations[0]  # We only have one generation in this case
+
+                    # Only one generation present in this case
+                    generation = generations[0]
+                elif "text" in generation:
+                    current_generation_text += generation["text"]
 
             # Stream collector will push an exception to the queue if it fails
             if isinstance(generation, Exception):
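The consumer pattern above, as a self-contained sketch (queue items and the sentinel are illustrative, not the repo's API): text chunks are only accumulated when the loaded template declares a tool start string, and hitting that string is where the accumulated text would be handed to a second, schema-constrained pass.

    import asyncio

    async def consume(queue: asyncio.Queue, tool_start: str | None) -> str:
        text = ""
        while True:
            generation = await queue.get()
            if generation is None:  # sentinel: stream finished
                return text
            if tool_start:
                if generation.get("stop_str") == tool_start:
                    pass  # hand `text` off to the tool call pass here
                elif "text" in generation:
                    text += generation["text"]

    async def main():
        q = asyncio.Queue()
        for item in ({"text": "Hel"}, {"text": "lo"}, None):
            q.put_nowait(item)
        print(await consume(q, tool_start="<tool_call>"))  # -> Hello

    asyncio.run(main())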
@@ -428,6 +422,7 @@ async def generate_chat_completion(
     model_path: pathlib.Path,
 ):
     gen_tasks: List[asyncio.Task] = []
+    tool_start = model.container.prompt_template.metadata.tool_start
 
     try:
         logger.info(f"Received chat completion request {request.state.id}")
@@ -448,8 +443,8 @@ async def generate_chat_completion(
 
         generations = await asyncio.gather(*gen_tasks)
 
-        # Let's not waste our time if we arn't running a tool model
-        if data.tool_call_start:
+        # Check all the generations and see if a tool call is required
+        if tool_start:
             generations = await generate_tool_calls(data, generations, request)
 
         response = _create_response(request.state.id, generations, model_path.name)
@@ -472,51 +467,52 @@ async def generate_tool_calls(
     data: ChatCompletionRequest,
     generations: List[str],
     request: Request,
-    current_generations: str = None,
+    current_generation_text: str = None,
 ):
     gen_tasks: List[asyncio.Task] = []
+    tool_start = model.container.prompt_template.metadata.tool_start
 
     # Tracks which generations asked for a tool call
     tool_idx: List[int] = []
 
     # Copy to make sure the parent JSON schema doesn't get modified
     # FIXME: May not be necessary depending on how the codebase evolves
     tool_data = data.model_copy(deep=True)
-    tool_data.json_schema = tool_data.tool_call_schema
+    tool_data.json_schema = TOOL_CALL_SCHEMA
 
     for idx, gen in enumerate(generations):
-        if gen["stop_str"] in tool_data.tool_call_start:
-            logger.info(
-                f"Detected tool call in chat completion request {request.state.id}"
-            )
+        if gen["stop_str"] != tool_start:
+            continue
 
-            if "text" in gen:
-                # non streaming, all generations will have the text they generated
-                pre_tool_prompt, embeddings = await apply_chat_template(
-                    data, gen["text"]
-                )
-            elif current_generations is not None:
-                # streaming, we wont have text in the generation,
-                # we'll have to use the current_generations
-                pre_tool_prompt, embeddings = await apply_chat_template(
-                    data, current_generations
-                )
+        logger.info(f"Detected tool call in chat completion request {request.state.id}")
 
-            request_id = _parse_gen_request_id(data.n, request.state.id, idx)
+        # Append the existing generation as part of the response prefix
+        precursor_text = current_generation_text or gen.get("text")
+        if precursor_text:
+            tool_data.response_prefix = precursor_text
+
+        pre_tool_prompt, embeddings = await apply_chat_template(tool_data)
+
+        gen_request_id = _parse_gen_request_id(data.n, request.state.id, idx)
+        tool_request_id = f"{gen_request_id}-tool"
 
         gen_tasks.append(
             asyncio.create_task(
                 model.container.generate(
-                    request_id,
+                    tool_request_id,
                     pre_tool_prompt,
                     tool_data,
                     mm_embeddings=embeddings,
                 )
             )
         )
 
         tool_idx.append(idx)
 
     if len(tool_idx) > 0:
         tool_calls = await asyncio.gather(*gen_tasks)
-        for outer_idx in range(0, len(tool_idx)):
-            gen_idx = tool_idx[outer_idx]
-            generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
+
+        # Map tool calls to their appropriate generation
+        for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True):
+            generations[gen_idx]["tool_calls"] = tool_call["text"]
 
     return generations
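The final mapping step in isolation: zip(..., strict=True) (Python 3.10+) writes each queued tool generation back onto the generation that triggered it, and raises instead of silently truncating if the two lists ever drift apart in length.

    generations = [{"text": "a"}, {"text": "b"}, {"text": "c"}]
    tool_idx = [0, 2]  # generations that stopped on the tool start string
    tool_calls = [{"text": '[{"id": "x"}]'}, {"text": '[{"id": "y"}]'}]

    for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True):
        generations[gen_idx]["tool_calls"] = tool_call["text"]

    assert generations[0]["tool_calls"] == '[{"id": "x"}]'
    assert "tool_calls" not in generations[1]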
@@ -5,6 +5,31 @@ from typing import List
 from endpoints.OAI.types.tools import ToolCall
 
 
+TOOL_CALL_SCHEMA = {
+    "$schema": "http://json-schema.org/draft-07/schema#",
+    "type": "array",
+    "items": {
+        "type": "object",
+        "properties": {
+            "id": {"type": "string"},
+            "function": {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "arguments": {
+                        # Converted to OAI's string in post process
+                        "type": "object"
+                    },
+                },
+                "required": ["name", "arguments"],
+            },
+            "type": {"type": "string", "enum": ["function"]},
+        },
+        "required": ["id", "function", "type"],
+    },
+}
+
+
 class ToolCallProcessor:
     @staticmethod
     def from_json(tool_calls_str: str) -> List[ToolCall]:
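A sketch (not the repo's code) of the post-processing that the schema comment refers to: the constrained pass generates "arguments" as a JSON object, while the OAI wire format expects it re-serialized as a string.

    import json

    raw = (
        '[{"id": "call_1", "type": "function",'
        ' "function": {"name": "get_weather", "arguments": {"city": "Paris"}}}]'
    )
    calls = json.loads(raw)
    for call in calls:
        args = call["function"]["arguments"]
        if isinstance(args, dict):
            call["function"]["arguments"] = json.dumps(args)

    print(calls[0]["function"]["arguments"])  # {"city": "Paris"} as a string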