diff --git a/docs/10.-Tool-Calling.md b/docs/10.-Tool-Calling.md
index e32a55b..801f85e 100644
--- a/docs/10.-Tool-Calling.md
+++ b/docs/10.-Tool-Calling.md
@@ -37,6 +37,11 @@ For example, if you are using a Llama 3.1 Family model you can simply modify you
 
 If loading via `/v1/model/load`, you would also need to specify a tool-supporting `prompt_template`.
 
+## Tool Template Variables
+
+- `tools`: The tool definitions as a list of dicts, which can be iterated and inspected inside the template.
+- `tools_json`: The same tool definitions serialized as a JSON string.
+
 ## Creating a Tool Calling Prompt Template
 
 Here's how to create a TabbyAPI tool calling prompt template:
@@ -142,4 +147,4 @@ When creating your own tool calling `prompt_template`, it's best to reference th
 
 ## Support and Bug Reporting
 
-For bugs, please create a detailed issue with the model, prompt template, and conversation that caused it. Alternatively, join our [Discord](https://discord.gg/sYQxnuD7Fj) and ask for Storm.
\ No newline at end of file
+For bugs, please create a detailed issue with the model, prompt template, and conversation that caused it. Alternatively, join our [Discord](https://discord.gg/sYQxnuD7Fj) and ask for Storm.
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 2326bc2..f66cb52 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -234,6 +234,10 @@ async def format_messages_with_template(
         if message.tool_calls:
             message.tool_calls_json = ToolCallProcessor.to_json(message.tool_calls)
 
+            # Tool calls should be inspectable in the template, so
+            # store them as a list of dicts rather than ToolCall objects.
+            message.tool_calls = ToolCallProcessor.dump(message.tool_calls)
+
     special_tokens_dict = model.container.get_special_tokens(
         add_bos_token, ban_eos_token
     )
@@ -252,11 +256,16 @@ async def apply_chat_template(
     Template stop strings can be overriden by sampler overrides if force is true.
     """
 
+    # Locally store the tools as a list of dicts
+    tools = data.model_dump()["tools"]
+
    try:
         data.template_vars.update(
             {
                 "add_generation_prompt": data.add_generation_prompt,
-                "tools_json": json.dumps(data.model_dump()["tools"], indent=2),
+                "tools": tools,
+                "tools_json": json.dumps(tools, indent=2),
+                "functions": data.functions,
                 "functions_json": json.dumps(data.functions, indent=2),
                 "tool_precursor": tool_precursor,
             }
@@ -460,6 +469,10 @@ async def generate_tool_calls(
 
     for idx, gen in enumerate(generations):
         if gen["stop_str"] in tool_data.tool_call_start:
+            logger.info(
+                f"Detected tool call in chat completion request {request.state.id}"
+            )
+
             if "text" in gen:
                 # non streaming, all generations will have the text they generated
                 pre_tool_prompt, mm_embeddings = await apply_chat_template(
diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py
index 7e399a8..7650e96 100644
--- a/endpoints/OAI/utils/tools.py
+++ b/endpoints/OAI/utils/tools.py
@@ -18,6 +18,28 @@ class ToolCallProcessor:
 
         return [ToolCall(**tool_call) for tool_call in tool_calls]
 
+    @staticmethod
+    def dump(tool_calls: List[ToolCall]) -> List[dict]:
+        """
+        Convert ToolCall objects to a list of dictionaries.
+
+        Args:
+            tool_calls (List[ToolCall]): List of ToolCall objects to convert
+
+        Returns:
+            List[dict]: List of dictionaries representing the tool calls
+        """
+
+        # Don't use list comprehension here
+        # as that will fail rather than warn
+        dumped_tool_calls = []
+        for tool_call_obj in tool_calls:
+            try:
+                dumped_tool_calls.append(tool_call_obj.model_dump())
+            except (json.JSONDecodeError, AttributeError) as e:
+                logger.warning(f"Error processing tool call: {e}")
+        return dumped_tool_calls
+
     @staticmethod
     def to_json(tool_calls: List[ToolCall]) -> str:
         """
@@ -33,14 +55,8 @@ class ToolCallProcessor:
         if not tool_calls:
             return ""
 
-        # Don't use list comprehension here
-        # as that will fail rather than warn
-        dumped_tool_calls = []
-        for tool_call_obj in tool_calls:
-            try:
-                dumped_tool_calls.append(tool_call_obj.model_dump())
-            except (json.JSONDecodeError, AttributeError) as e:
-                logger.warning(f"Error processing tool call: {e}")
+        # Use the dump method to get the list of dictionaries
+        dumped_tool_calls = ToolCallProcessor.dump(tool_calls)
 
         # Serialize the dumped array
         return json.dumps(dumped_tool_calls, indent=2)
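For context on why `tools` is now exposed alongside `tools_json`: TabbyAPI prompt templates are Jinja2 files, and a list of dicts can be iterated and inspected from template code, whereas `tools_json` is only an opaque pre-serialized string. Below is a minimal sketch of the difference, outside of TabbyAPI itself; the `get_weather` tool definition is hypothetical, and the snippet assumes only that `jinja2` is installed.

```python
# Minimal sketch: iterate the new `tools` variable inside a Jinja2
# template. The get_weather tool is a hypothetical example.
import json

from jinja2 import Template

# Tool definitions in the OpenAI "tools" format, shaped like the
# result of data.model_dump()["tools"]
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Look up the current weather for a city",
        },
    }
]

# With `tools` as a list of dicts, the template can inspect each
# tool definition individually
template = Template(
    "Available tools:\n"
    "{% for tool in tools %}"
    "- {{ tool.function.name }}: {{ tool.function.description }}\n"
    "{% endfor %}"
)
print(template.render(tools=tools))
# Available tools:
# - get_weather: Look up the current weather for a city

# The existing `tools_json` variable carries the same data, but only
# as a single pre-serialized string
print(Template("{{ tools_json }}").render(tools_json=json.dumps(tools, indent=2)))
```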