Doing this helps reduce the model's burden of generating the tool call ID and type (which is always "function"). Follow Mistral's spec for tool call IDs by using a 9-character alphanumeric string. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
85 lines
2.5 KiB
Python
85 lines
2.5 KiB
Python
import json
|
|
from loguru import logger
|
|
from typing import List
|
|
|
|
from endpoints.OAI.types.tools import ToolCall
|
|
|
|
|
|
# JSON Schema (draft-07) describing the tool-call output the model produces:
# an array of objects, each holding a required "function" with a string
# "name" and an object "arguments". Only "function" is required per entry.
TOOL_CALL_SCHEMA = {
    "$schema": "http://json-schema.org/draft-07/schema#",
    "type": "array",
    "items": {
        "type": "object",
        "properties": {
            "function": {
                "type": "object",
                "properties": {
                    "name": {"type": "string"},
                    # Generated as a JSON object here; it is turned into the
                    # JSON-encoded string OAI expects during post-processing.
                    "arguments": {"type": "object"},
                },
                "required": ["name", "arguments"],
            },
        },
        "required": ["function"],
    },
}
|
|
|
|
|
|
class ToolCallProcessor:
    @staticmethod
    def from_json(tool_calls_str: str) -> List[ToolCall]:
        """Postprocess tool call JSON to a parseable class.

        The model emits each call's ``function.arguments`` as a JSON object,
        while OAI-style tool calls carry ``arguments`` as a JSON-encoded
        string. Each non-string ``arguments`` value is therefore serialized
        with ``json.dumps`` before building the ToolCall objects.

        Args:
            tool_calls_str (str): Raw JSON text produced by the model. Expected
                to be an array of tool call objects; a single top-level object
                is accepted and treated as a one-element array.

        Returns:
            List[ToolCall]: One ToolCall per parsed entry, in input order.

        Raises:
            json.JSONDecodeError: If ``tool_calls_str`` is not valid JSON.
            KeyError: If an entry lacks ``function`` or ``function.arguments``.
        """

        tool_calls = json.loads(tool_calls_str)

        # Tolerate a bare object where an array of one call was expected,
        # instead of iterating over the object's keys and crashing
        if isinstance(tool_calls, dict):
            tool_calls = [tool_calls]

        for tool_call in tool_calls:
            function = tool_call["function"]
            arguments = function["arguments"]

            # Only encode structured arguments. Re-dumping a value that is
            # already a string would double-encode it into a JSON string
            # literal (e.g. '"{\\"a\\": 1}"') instead of object text.
            if not isinstance(arguments, str):
                function["arguments"] = json.dumps(arguments)

        return [ToolCall(**tool_call) for tool_call in tool_calls]

    @staticmethod
    def dump(tool_calls: List[ToolCall]) -> List[dict]:
        """
        Convert ToolCall objects to a list of dictionaries.

        Entries whose ``model_dump()`` raises ``json.JSONDecodeError`` or
        ``AttributeError`` are logged as warnings and skipped, so one bad
        entry does not abort the whole conversion.

        Args:
            tool_calls (List[ToolCall]): List of ToolCall objects to convert

        Returns:
            List[dict]: List of dictionaries representing the tool calls
        """

        # Don't use list comprehension here
        # as that will fail rather than warn
        dumped_tool_calls = []
        for tool_call_obj in tool_calls:
            try:
                dumped_tool_calls.append(tool_call_obj.model_dump())
            except (json.JSONDecodeError, AttributeError) as e:
                # AttributeError covers entries without a model_dump method
                logger.warning(f"Error processing tool call: {e}")

        return dumped_tool_calls

    @staticmethod
    def to_json(tool_calls: List[ToolCall]) -> str:
        """
        Convert ToolCall objects to JSON string representation.

        Args:
            tool_calls (List[ToolCall]): List of ToolCall objects to convert

        Returns:
            str: JSON representation of the tool calls (indented by 2
                spaces), or an empty string when ``tool_calls`` is empty
                or None.
        """

        if not tool_calls:
            return ""

        # Use the dump method to get the list of dictionaries
        dumped_tool_calls = ToolCallProcessor.dump(tool_calls)

        # Serialize the dumped array
        return json.dumps(dumped_tool_calls, indent=2)
|