API: Modify tool calling for wider compat
When revisiting tool calls, the formats have more or less become standard. For greater compatibility with templates, rely primarily on the message.tools parameter and remove the extra custom metadata that is no longer required. Unlike other backends, however, tabbyAPI still uses template metadata to declare the tool start string. This keeps customization at the template level and gives the user more power, while the server consumes templates generically rather than handling models on a case-by-case basis.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent b6a26da50c
commit 879f4cee7e

5 changed files with 89 additions and 103 deletions
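For context on the mechanism the commit leans on: Jinja exports top-level {% set %} variables as attributes of a rendered template module, which is exactly where the loader below looks for the metadata. A minimal sketch of that mechanism; the template text and the tool start token are hypothetical, not from a real model:

from jinja2.sandbox import ImmutableSandboxedEnvironment

# Hypothetical chat template: a top-level {% set %} becomes a module attribute
template_str = '{% set tool_start = "<tool_call>" %}{{ messages }}'

env = ImmutableSandboxedEnvironment()
template_module = env.from_string(template_str).make_module()

# Mirrors the metadata extraction in PromptTemplate below
if hasattr(template_module, "tool_start"):
    print(template_module.tool_start)  # -> <tool_call>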
common/templating.py

@@ -4,6 +4,8 @@ import traceback
 import aiofiles
 import json
 import pathlib
+from dataclasses import dataclass, field
+from datetime import datetime
 from importlib.metadata import version as package_version
 from typing import List, Optional
 from jinja2 import Template, TemplateError
@@ -11,7 +13,6 @@ from jinja2.ext import loopcontrols
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger
 from packaging import version
-from datetime import datetime
 
 
 from common.utils import unwrap
@@ -23,11 +24,12 @@ class TemplateLoadError(Exception):
     pass
 
 
+@dataclass
 class TemplateMetadata:
     """Represents the parsed metadata from a template."""
 
-    stop_strings: List[str] = []
-    tool_starts: List[str] = []
+    stop_strings: List[str] = field(default_factory=list)
+    tool_start: Optional[str] = None
 
 
 class PromptTemplate:
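A side note on the field(default_factory=list) change: once @dataclass is applied, a bare mutable default like [] raises a ValueError at class definition time, and a factory also gives each instance its own list. A quick standalone illustration:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class TemplateMetadata:
    """Represents the parsed metadata from a template."""

    # stop_strings: List[str] = [] would raise ValueError under @dataclass
    stop_strings: List[str] = field(default_factory=list)
    tool_start: Optional[str] = None

a, b = TemplateMetadata(), TemplateMetadata()
a.stop_strings.append("</s>")
print(b.stop_strings)  # [] -- instances do not share the list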
@@ -72,11 +74,7 @@ class PromptTemplate:
 
         if hasattr(template_module, "tool_start"):
             if isinstance(template_module.tool_start, str):
-                template_metadata.tool_starts.append(template_module.tool_start)
-
-        if hasattr(template_module, "tool_start_token"):
-            if isinstance(template_module.tool_start_token, int):
-                template_metadata.tool_starts.append(template_module.tool_start_token)
+                template_metadata.tool_start = template_module.tool_start
 
         self.metadata = template_metadata
         return template_metadata
endpoints/OAI/types/chat_completion.py

@@ -5,7 +5,7 @@ from typing import Literal, Union, List, Optional, Dict
 from uuid import uuid4
 
 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
-from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
+from endpoints.OAI.types.tools import ToolSpec, ToolCall
 
 
 class ChatCompletionLogprob(BaseModel):
@@ -73,15 +73,6 @@ class ChatCompletionRequest(CommonCompletionRequest):
     tools: Optional[List[ToolSpec]] = None
     functions: Optional[List[Dict]] = None
 
-    # Typically collected from Chat Template.
-    # Don't include this in the OpenAPI docs
-    # TODO: Use these custom parameters
-    tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = Field(
-        default_factory=list
-    )
-    tool_call_end: SkipJsonSchema[Optional[str]] = None
-    tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
-
    # Chat completions requests do not have a BOS token preference. Backend
    # respects the tokenization config for the individual model.
    add_bos_token: Optional[bool] = None
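With those hidden fields gone, a request body carries nothing beyond the standard OpenAI-style tools list. A hedged sketch of such a payload; every name and value here is illustrative:

request_body = {
    "model": "example-model",
    "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
}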
endpoints/OAI/types/tools.py

@@ -1,30 +1,6 @@
 from pydantic import BaseModel
 from typing import Dict, Literal
 
-tool_call_schema = {
-    "$schema": "http://json-schema.org/draft-07/schema#",
-    "type": "array",
-    "items": {
-        "type": "object",
-        "properties": {
-            "id": {"type": "string"},
-            "function": {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string"},
-                    "arguments": {
-                        # Converted to OAI's string in post process
-                        "type": "object"
-                    },
-                },
-                "required": ["name", "arguments"],
-            },
-            "type": {"type": "string", "enum": ["function"]},
-        },
-        "required": ["id", "function", "type"],
-    },
-}
-
 
 class Function(BaseModel):
     """Represents a description of a tool function."""
endpoints/OAI/utils/chat_completion.py

@@ -1,7 +1,6 @@
 """Chat completion utilities for OAI server."""
 
 import asyncio
-import json
 import pathlib
 from asyncio import CancelledError
 from typing import List, Optional
@@ -30,7 +29,7 @@ from endpoints.OAI.types.chat_completion import (
 )
 from endpoints.OAI.types.common import UsageStats
 from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector
-from endpoints.OAI.utils.tools import ToolCallProcessor
+from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA
 
 
 def _create_response(
@@ -209,12 +208,9 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars:
     else:
         data.stop.extend(template_metadata.stop_strings)
 
-    # Tool call start strings
-    if template_metadata.tool_starts:
-        data.tool_call_start.extend(template_metadata.tool_starts)
-
-        # Append to stop strings to halt for a tool call generation
-        data.stop.extend(template_metadata.tool_starts)
+    # if a tool start is present, append it to stopping strings
+    if template_metadata.tool_start:
+        data.stop.append(template_metadata.tool_start)
 
 
 async def format_messages_with_template(
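The upshot is that the declared tool start string now doubles as a stop string, so the engine halts exactly where a tool call would begin. A tiny sketch of the effect; both tokens below are hypothetical:

stop = ["</s>"]             # stop strings collected from the template
tool_start = "<tool_call>"  # hypothetical value declared by the template

if tool_start:
    stop.append(tool_start)

print(stop)  # ['</s>', '<tool_call>'] -- generation pauses for the tool pass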
@@ -255,9 +251,7 @@ async def format_messages_with_template(
     return prompt, mm_embeddings, template_vars
 
 
-async def apply_chat_template(
-    data: ChatCompletionRequest, tool_precursor: Optional[str] = None
-):
+async def apply_chat_template(data: ChatCompletionRequest):
     """
     Compile the prompt and get any additional stop strings from the template.
     Template stop strings can be overriden by sampler overrides if force is true.
@@ -271,10 +265,7 @@
         {
             "add_generation_prompt": data.add_generation_prompt,
             "tools": tools,
-            "tools_json": json.dumps(tools, indent=2),
             "functions": data.functions,
-            "functions_json": json.dumps(data.functions, indent=2),
-            "tool_precursor": tool_precursor,
         }
     )
 
@@ -332,6 +323,7 @@ async def stream_generate_chat_completion(
     abort_event = asyncio.Event()
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
+    tool_start = model.container.prompt_template.metadata.tool_start
     disconnect_task = asyncio.create_task(request_disconnect_loop(request))
 
     try:
@@ -355,7 +347,7 @@
             gen_tasks.append(gen_task)
 
-            # We need to keep track of the text generated so we can resume the tool calls
+            # Text accumulation for tool calls
             current_generation_text = ""
 
             # Consumer loop
 
@@ -367,19 +359,21 @@
                 )
 
                 generation = await gen_queue.get()
-                # lets only append the text if we need it for tool calls later
-                if data.tool_call_start and "text" in generation:
-                    current_generation_text += generation["text"]
 
-                # check if we are running a tool model, and that we are at stop
-                if data.tool_call_start and "stop_str" in generation:
-                    generations = await generate_tool_calls(
-                        data,
-                        [generation],
-                        request,
-                        current_generations=current_generation_text,
-                    )
-                    generation = generations[0]  # We only have one generation in this case
+                # Handle options if a tool model is present
+                if tool_start:
+                    if "stop_str" in generation:
+                        generations = await generate_tool_calls(
+                            data,
+                            [generation],
+                            request,
+                            current_generation_text=current_generation_text,
+                        )
+
+                        # Only one generation present in this case
+                        generation = generations[0]
+                    elif "text" in generation:
+                        current_generation_text += generation["text"]
 
                 # Stream collector will push an exception to the queue if it fails
                 if isinstance(generation, Exception):
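In the stream path the branch order matters: a stop event triggers the tool pass with everything accumulated so far, while ordinary text chunks only extend the accumulator. A toy trace of that loop; the chunk contents are hypothetical:

chunks = [{"text": "Let me check."}, {"text": " "}, {"stop_str": "<tool_call>"}]

acc = ""
for chunk in chunks:
    if "stop_str" in chunk:
        print(f"tool pass with prefix: {acc!r}")  # 'Let me check. '
    elif "text" in chunk:
        acc += chunk["text"]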
@@ -428,6 +422,7 @@
     model_path: pathlib.Path,
 ):
     gen_tasks: List[asyncio.Task] = []
+    tool_start = model.container.prompt_template.metadata.tool_start
 
     try:
         logger.info(f"Received chat completion request {request.state.id}")
@@ -448,8 +443,8 @@
         generations = await asyncio.gather(*gen_tasks)
 
-        # Let's not waste our time if we arn't running a tool model
-        if data.tool_call_start:
+        # Check all the generations and see if a tool call is required
+        if tool_start:
             generations = await generate_tool_calls(data, generations, request)
 
         response = _create_response(request.state.id, generations, model_path.name)
 
@@ -472,51 +467,52 @@ async def generate_tool_calls(
     data: ChatCompletionRequest,
     generations: List[str],
     request: Request,
-    current_generations: str = None,
+    current_generation_text: str = None,
 ):
     gen_tasks: List[asyncio.Task] = []
+    tool_start = model.container.prompt_template.metadata.tool_start
+
+    # Tracks which generations asked for a tool call
     tool_idx: List[int] = []
 
     # Copy to make sure the parent JSON schema doesn't get modified
-    # FIXME: May not be necessary depending on how the codebase evolves
     tool_data = data.model_copy(deep=True)
-    tool_data.json_schema = tool_data.tool_call_schema
+    tool_data.json_schema = TOOL_CALL_SCHEMA
 
     for idx, gen in enumerate(generations):
-        if gen["stop_str"] in tool_data.tool_call_start:
-            logger.info(
-                f"Detected tool call in chat completion request {request.state.id}"
-            )
+        if gen["stop_str"] != tool_start:
+            continue
 
-            if "text" in gen:
-                # non streaming, all generations will have the text they generated
-                pre_tool_prompt, embeddings = await apply_chat_template(
-                    data, gen["text"]
-                )
-            elif current_generations is not None:
-                # streaming, we wont have text in the generation,
-                # we'll have to use the current_generations
-                pre_tool_prompt, embeddings = await apply_chat_template(
-                    data, current_generations
-                )
+        logger.info(f"Detected tool call in chat completion request {request.state.id}")
 
-            request_id = _parse_gen_request_id(data.n, request.state.id, idx)
+        # Append the existing generation as part of the response prefix
+        precursor_text = current_generation_text or gen.get("text")
+        if precursor_text:
+            tool_data.response_prefix = precursor_text
+
+        pre_tool_prompt, embeddings = await apply_chat_template(tool_data)
 
-            gen_tasks.append(
-                asyncio.create_task(
-                    model.container.generate(
-                        request_id,
-                        pre_tool_prompt,
-                        tool_data,
-                        mm_embeddings=embeddings,
-                    )
-                )
-            )
-            tool_idx.append(idx)
+        gen_request_id = _parse_gen_request_id(data.n, request.state.id, idx)
+        tool_request_id = f"{gen_request_id}-tool"
+
+        gen_tasks.append(
+            asyncio.create_task(
+                model.container.generate(
+                    tool_request_id,
+                    pre_tool_prompt,
+                    tool_data,
+                    mm_embeddings=embeddings,
+                )
+            )
+        )
+        tool_idx.append(idx)
 
-    tool_calls = await asyncio.gather(*gen_tasks)
-    for outer_idx in range(0, len(tool_idx)):
-        gen_idx = tool_idx[outer_idx]
-        generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
+    if len(tool_idx) > 0:
+        tool_calls = await asyncio.gather(*gen_tasks)
+
+        # Map tool calls to their appropriate generation
+        for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True):
+            generations[gen_idx]["tool_calls"] = tool_call["text"]
 
     return generations
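For orientation, the whole flow is now two passes: the first generation halts on the template's tool start string, then the request is replayed with the partial text as a response prefix and decoding constrained to the tool call schema. A rough, self-contained sketch with a stubbed engine; none of these names are tabbyAPI's actual API:

import asyncio
import json

# Trimmed stand-in for the real TOOL_CALL_SCHEMA defined in utils/tools.py
TOOL_SCHEMA_STUB = {"type": "array"}

# Stub standing in for the engine's schema-constrained generation
async def constrained_generate(prompt: str, json_schema: dict) -> str:
    return json.dumps([{
        "id": "call_0",
        "function": {"name": "get_weather", "arguments": {"city": "Paris"}},
        "type": "function",
    }])

async def two_pass_tool_call(prompt: str, tool_start: str, partial_text: str):
    # Pass 1 (already happened): generation halted on tool_start, which the
    # template metadata added to the stop strings.
    # Pass 2: replay with the partial text as a response prefix and constrain
    # decoding so the output must satisfy the tool call JSON schema.
    tool_prompt = f"{prompt}{partial_text}{tool_start}"
    raw = await constrained_generate(tool_prompt, TOOL_SCHEMA_STUB)
    return json.loads(raw)

calls = asyncio.run(two_pass_tool_call("<prompt>", "<tool_call>", "Checking... "))
print(calls[0]["function"]["name"])  # get_weather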
endpoints/OAI/utils/tools.py

@@ -5,6 +5,31 @@ from typing import List
 from endpoints.OAI.types.tools import ToolCall
 
 
+TOOL_CALL_SCHEMA = {
+    "$schema": "http://json-schema.org/draft-07/schema#",
+    "type": "array",
+    "items": {
+        "type": "object",
+        "properties": {
+            "id": {"type": "string"},
+            "function": {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "arguments": {
+                        # Converted to OAI's string in post process
+                        "type": "object"
+                    },
+                },
+                "required": ["name", "arguments"],
+            },
+            "type": {"type": "string", "enum": ["function"]},
+        },
+        "required": ["id", "function", "type"],
+    },
+}
+
+
 class ToolCallProcessor:
     @staticmethod
     def from_json(tool_calls_str: str) -> List[ToolCall]:
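To see what the constraint admits, here is a sample payload checked against the schema with the third-party jsonschema package; the id, function name, and arguments are made up:

from jsonschema import validate  # third-party: pip install jsonschema

from endpoints.OAI.utils.tools import TOOL_CALL_SCHEMA

# Illustrative payload in the shape the constrained generation must emit
sample_tool_calls = [
    {
        "id": "call_0001",
        "function": {
            "name": "get_weather",
            # Stays an object here; converted to OAI's string in post process
            "arguments": {"city": "Paris"},
        },
        "type": "function",
    }
]

validate(instance=sample_tool_calls, schema=TOOL_CALL_SCHEMA)  # raises if invalid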