[WIP] OpenAI Tools Support/Function calling (#154)

* returning stop str if exists from gen

* added chat template for firefunctionv2

* pulling tool vars from template

* adding parsing for tool inputs/outputs

* passing tool data from endpoint to chat template, adding tool_start to the stop list

* loosened typing on the response tool call, leaning more on the user supplying a quality schema if they want a particular format

* non streaming generation prototype

* cleaning template

* Continued work with type, ingestion into template, and chat template for fire func

* Correction - streaming toolcall comes back as delta obj not inside chatcomprespchoice per chat_completion_chunk.py inside OAI lib.

* Ruff Formating

* Moved stop string and tool updates out of prompt creation func

Updated tool pydantic to match OAI

Support for streaming

Updated generate tool calls to use flag within chat_template and insert tool reminder

* Llama 3.1 chat templates

Updated fire func template

* renamed llama3.1 to chatml_with_headers..

* update name of template

* Support for calling a tool start token rather than the string.

Simplified tool_params

Warning when gen_settings are being overidden becuase user set temp to 0

Corrected schema and tools to correct types for function args. Str for some reason

* draft groq tool use model template

* changed headers to vars for readablity (but mostly because some models are weird about newlines after headers, so this is an easier way to change globally)

* Clean up comments and code in chat comp

* Post processed tool call to meet OAI spec rather than forcing model to write json in a string in the middle of the call.

* changes example back to args as json rather than string of json

* Standardize chat templates to each other

* cleaning/rewording

* stop elements can also be ints (tokens)

* Cleaning/formatting

* added special tokens for tools and tool_response as specified in description

* Cleaning

* removing aux templates - going to live in llm-promp-templates repo instead

* Tree: Format

Signed-off-by: kingbri <bdashore3@proton.me>

* Chat Completions: Don't include internal tool variables in OpenAPI

Use SkipJsonSchema to supress inclusion with the OpenAPI JSON. The
location of these variables may need to be changed in the future.

Signed-off-by: kingbri <bdashore3@proton.me>

* Templates: Deserialize metadata on template load

Since we're only looking for specific template variables that are
static in the template, it makes more sense to render when the template
is initialized.

Signed-off-by: kingbri <bdashore3@proton.me>

* Tools: Fix comments

Adhere to the format style of comments in the rest of the project.

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Co-authored-by: Ben Gitter <gitterbd@gmail.com>
Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
Ben Gitter 2024-08-17 00:16:25 -04:00 committed by GitHub
parent 9cc0e70098
commit 70b9fc95de
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 336 additions and 23 deletions

1
.gitignore vendored
View file

@ -192,6 +192,7 @@ templates/*
!templates/place_your_templates_here.txt
!templates/alpaca.jinja
!templates/chatml.jinja
!templates/chatml_with_headers_tool_calling.jinja
# Sampler overrides folder
sampler_overrides/*

View file

@ -852,6 +852,7 @@ class ExllamaV2Container:
"text": "",
"prompt_tokens": 0,
"generation_tokens": 0,
"tool_calls": None,
"offset": [],
"token_probs": {},
"logprobs": [],
@ -864,6 +865,7 @@ class ExllamaV2Container:
joined_generation["finish_reason"] = finish_reason_gen.get(
"finish_reason"
)
joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
else:
joined_generation["finish_reason"] = "stop"
@ -1068,6 +1070,15 @@ class ExllamaV2Container:
gen_settings.top_p = 0
gen_settings.typical = 0
logger.warning(
"".join(
[
"Temperature is set to 0. Overriding temp, ",
"top_k, top_p, and typical to 1.0, 1, 0, and 0.",
]
)
)
# Store the gen settings for logging purposes
gen_settings_log_dict = vars(gen_settings)
@ -1227,9 +1238,17 @@ class ExllamaV2Container:
log_response(request_id, full_response)
eos_reason = result.get("eos_reason")
finish_reason = (
"length" if eos_reason == "max_new_tokens" else "stop"
)
stop_str = None
if eos_reason == "max_new_tokens":
finish_reason = "length"
else:
finish_reason = "stop"
# Grab stop string if stop was the reason
if eos_reason == "stop_token":
stop_str = result.get("eos_triggering_token_str")
elif eos_reason == "stop_string":
stop_str = result.get("eos_triggering_string")
# Save the final result for metrics logging
metrics_result = result
@ -1239,6 +1258,7 @@ class ExllamaV2Container:
"prompt_tokens": generation.get("prompt_tokens"),
"generated_tokens": generation.get("generated_tokens"),
"finish_reason": finish_reason,
"stop_str": stop_str,
}
yield generation

View file

@ -33,7 +33,7 @@ class BaseSamplerRequest(BaseModel):
examples=[512],
)
stop: Optional[Union[str, List[str]]] = Field(
stop: Optional[Union[str, List[Union[str, int]]]] = Field(
default_factory=lambda: get_default_sampler_value("stop", []),
validation_alias=AliasChoices("stop", "stop_sequence"),
description="Aliases: stop_sequence",

View file

@ -3,7 +3,7 @@
import json
import pathlib
from importlib.metadata import version as package_version
from typing import Optional
from typing import List, Optional
from jinja2 import Template, TemplateError
from jinja2.sandbox import ImmutableSandboxedEnvironment
from loguru import logger
@ -18,6 +18,13 @@ class TemplateLoadError(Exception):
pass
class TemplateMetadata:
"""Represents the parsed metadata from a template."""
stop_strings: List[str] = []
tool_starts: List[str] = []
class PromptTemplate:
"""A template for chat completion prompts."""
@ -27,23 +34,33 @@ class PromptTemplate:
environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
trim_blocks=True, lstrip_blocks=True
)
metadata: TemplateMetadata
def stop_strings(self, template_vars: dict):
"""Appends extra stop strings if present in a chat template."""
def extract_metadata(self):
"""Returns deserialized template metadata from a chat template."""
extra_stop_strings = []
template_module = self.template.make_module(template_vars)
template_metadata = TemplateMetadata()
template_module = self.template.make_module()
if hasattr(template_module, "stop_strings"):
if isinstance(template_module.stop_strings, list):
extra_stop_strings += template_module.stop_strings
template_metadata.stop_strings += template_module.stop_strings
else:
logger.warning(
"Skipping append of stopping strings from chat template "
"because stop_strings isn't a list."
)
return extra_stop_strings
if hasattr(template_module, "tool_start"):
if isinstance(template_module.tool_start, str):
template_metadata.tool_starts.append(template_module.tool_start)
if hasattr(template_module, "tool_start_token"):
if isinstance(template_module.tool_start_token, int):
template_metadata.tool_starts.append(template_module.tool_start_token)
return template_metadata
def render(self, template_vars: dict):
"""Get a prompt from a template and a list of messages."""
@ -56,9 +73,8 @@ class PromptTemplate:
)
rendered_template = self.template.render(**template_vars)
template_stop_strings = self.stop_strings(template_vars)
return rendered_template, template_stop_strings
return rendered_template
def compile(self, template_str: str):
"""Compiles and stores a jinja2 template"""
@ -77,6 +93,7 @@ class PromptTemplate:
self.name = name
self.raw_template = raw_template
self.template = self.compile(raw_template)
self.metadata = self.extract_metadata()
@classmethod
def from_file(self, prompt_template_name: str):

View file

@ -1,9 +1,11 @@
from pydantic import BaseModel, Field
from pydantic.json_schema import SkipJsonSchema
from time import time
from typing import Union, List, Optional, Dict
from uuid import uuid4
from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
class ChatCompletionLogprob(BaseModel):
@ -19,12 +21,16 @@ class ChatCompletionLogprobs(BaseModel):
class ChatCompletionMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = None
class ChatCompletionRespChoice(BaseModel):
# Index is 0 since we aren't using multiple choices
index: int = 0
finish_reason: Optional[str] = None
# let's us understand why it stopped and if we need to generate a tool_call
stop_str: Optional[str] = None
message: ChatCompletionMessage
logprobs: Optional[ChatCompletionLogprobs] = None
@ -42,12 +48,28 @@ class ChatCompletionRequest(CommonCompletionRequest):
# Messages
# Take in a string as well even though it's not part of the OAI spec
# support messages.content as a list of dict
messages: Union[str, List[Dict[str, Union[str, List[Dict[str, str]]]]]]
# WIP this can probably be tightened, or maybe match the OAI lib type
# in openai\types\chat\chat_completion_message_param.py
messages: Union[str, List[Dict]]
prompt_template: Optional[str] = None
add_generation_prompt: Optional[bool] = True
template_vars: Optional[dict] = {}
response_prefix: Optional[str] = None
# tools is follows the format OAI schema, functions is more flexible
# both are available in the chat template.
tools: Optional[List[ToolSpec]] = None
functions: Optional[List[Dict]] = None
# Typically collected from Chat Template.
# Don't include this in the OpenAPI docs
# TODO: Use these custom parameters
tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = None
tool_call_end: SkipJsonSchema[Optional[str]] = None
tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
class ChatCompletionResponse(BaseModel):
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")

View file

@ -0,0 +1,58 @@
from pydantic import BaseModel
from typing import Dict, Literal
tool_call_schema = {
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "array",
"items": {
"type": "object",
"properties": {
"id": {"type": "string"},
"function": {
"type": "object",
"properties": {
"name": {"type": "string"},
"arguments": {
# Converted to OAI's string in post process
"type": "object"
},
},
"required": ["name", "arguments"],
},
"type": {"type": "string", "enum": ["function"]},
},
"required": ["id", "function", "type"],
},
}
class Function(BaseModel):
"""Represents a description of a tool function."""
name: str
description: str
parameters: Dict[str, object]
class ToolSpec(BaseModel):
"""Wrapper for an inner tool function."""
function: Function
type: Literal["function"]
class Tool(BaseModel):
"""Represents an OAI tool description."""
name: str
# Makes more sense to be a dict, but OAI knows best
arguments: str
class ToolCall(BaseModel):
"""Represents an OAI tool description."""
id: str
function: Tool
type: Literal["function"]

View file

@ -5,6 +5,7 @@ import pathlib
from asyncio import CancelledError
from copy import deepcopy
from typing import List, Optional
import json
from fastapi import HTTPException, Request
from jinja2 import TemplateError
@ -30,6 +31,7 @@ from endpoints.OAI.types.chat_completion import (
)
from endpoints.OAI.types.common import UsageStats
from endpoints.OAI.utils.completion import _stream_collector
from endpoints.OAI.types.tools import ToolCall
def _create_response(
@ -46,6 +48,10 @@ def _create_response(
role="assistant", content=unwrap(generation.get("text"), "")
)
tool_calls = generation["tool_calls"]
if tool_calls:
message.tool_calls = postprocess_tool_call(tool_calls)
logprob_response = None
token_probs = unwrap(generation.get("token_probs"), {})
@ -72,6 +78,7 @@ def _create_response(
choice = ChatCompletionRespChoice(
index=index,
finish_reason=generation.get("finish_reason"),
stop_str=generation.get("stop_str"),
message=message,
logprobs=logprob_response,
)
@ -119,7 +126,16 @@ def _create_stream_chunk(
finish_reason=generation.get("finish_reason"),
)
# lets check if we have tool calls since we are at the end of the generation
if "tool_calls" in generation:
tool_calls = generation["tool_calls"]
message = ChatCompletionMessage(
tool_calls=postprocess_tool_call(tool_calls)
)
choice.delta = message
choices.append(choice)
else:
message = ChatCompletionMessage(
role="assistant", content=unwrap(generation.get("text"), "")
@ -162,7 +178,29 @@ def _create_stream_chunk(
return chunk
def format_prompt_with_template(data: ChatCompletionRequest):
def _append_template_metadata(data: ChatCompletionRequest):
"""Adding metadata is a one-time process."""
template_metadata = model.container.prompt_template.metadata
# Stop strings
if isinstance(data.stop, str):
data.stop = [data.stop] + template_metadata.stop_strings
else:
data.stop += template_metadata.stop_strings
# Tool call start strings
if template_metadata.tool_starts:
if data.tool_call_start is None:
data.tool_call_start = template_metadata.tool_starts
# Append to stop strings to halt for a tool call generation
data.stop.extend(template_metadata.tool_starts)
def format_prompt_with_template(
data: ChatCompletionRequest, tool_precursor: Optional[str] = None
):
"""
Compile the prompt and get any additional stop strings from the template.
Template stop strings can be overriden by sampler overrides if force is true.
@ -187,18 +225,22 @@ def format_prompt_with_template(data: ChatCompletionRequest):
"",
)
if "tool_calls" in message:
message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2)
# Overwrite any protected vars with their values
data.template_vars.update(
{
"messages": data.messages,
"add_generation_prompt": data.add_generation_prompt,
"tools_json": json.dumps(data.model_dump()["tools"], indent=2),
"functions_json": json.dumps(data.functions, indent=2),
"tool_precursor": tool_precursor,
**special_tokens_dict,
}
)
prompt, template_stop_strings = model.container.prompt_template.render(
data.template_vars
)
prompt = model.container.prompt_template.render(data.template_vars)
# Append response prefix if present
if data.response_prefix:
@ -216,11 +258,8 @@ def format_prompt_with_template(data: ChatCompletionRequest):
if bos_token and prompt.startswith(bos_token):
prompt = prompt.removeprefix(bos_token)
# Append template stop strings
if isinstance(data.stop, str):
data.stop = [data.stop] + template_stop_strings
else:
data.stop += template_stop_strings
# Add template metadata
_append_template_metadata(data)
return prompt
@ -271,6 +310,9 @@ async def stream_generate_chat_completion(
gen_tasks.append(gen_task)
# We need to keep track of the text generated so we can resume the tool calls
current_generation_text = ""
# Consumer loop
while True:
if disconnect_task.done():
@ -280,6 +322,19 @@ async def stream_generate_chat_completion(
)
generation = await gen_queue.get()
# lets only append the text if we need it for tool calls later
if data.tool_call_start and "text" in generation:
current_generation_text += generation["text"]
# check if we are running a tool model, and that we are at stop
if data.tool_call_start and "stop_str" in generation:
generations = await generate_tool_calls(
data,
[generation],
request,
current_generations=current_generation_text,
)
generation = generations[0] # We only have one generation in this case
# Stream collector will push an exception to the queue if it fails
if isinstance(generation, Exception):
@ -344,6 +399,11 @@ async def generate_chat_completion(
)
generations = await asyncio.gather(*gen_tasks)
# Let's not waste our time if we arn't running a tool model
if data.tool_call_start:
generations = await generate_tool_calls(data, generations, request)
response = _create_response(request.state.id, generations, model_path.name)
logger.info(f"Finished chat completion request {request.state.id}")
@ -358,3 +418,54 @@ async def generate_chat_completion(
# Server error if there's a generation exception
raise HTTPException(503, error_message) from exc
async def generate_tool_calls(
data: ChatCompletionRequest,
generations: List[str],
request: Request,
current_generations: str = None,
):
gen_tasks: List[asyncio.Task] = []
tool_idx: List[int] = []
# Copy to make sure the parent JSON schema doesn't get modified
# FIXME: May not be necessary depending on how the codebase evolves
tool_data = deepcopy(data)
tool_data.json_schema = tool_data.tool_call_schema
gen_params = tool_data.to_gen_params()
for idx, gen in enumerate(generations):
if gen["stop_str"] in tool_data.tool_call_start:
if "text" in gen:
# non streaming, all generations will have the text they generated
pre_tool_prompt = format_prompt_with_template(data, gen["text"])
elif current_generations is not None:
# streaming, we wont have text in the generation,
# we'll have to use the current_generations
pre_tool_prompt = format_prompt_with_template(data, current_generations)
gen_tasks.append(
asyncio.create_task(
model.container.generate(
pre_tool_prompt, request.state.id, **gen_params
)
)
)
tool_idx.append(idx)
tool_calls = await asyncio.gather(*gen_tasks)
for outer_idx in range(0, len(tool_idx)):
gen_idx = tool_idx[outer_idx]
generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
return generations
def postprocess_tool_call(call_str: str) -> List[ToolCall]:
tool_calls = json.loads(call_str)
for tool_call in tool_calls:
tool_call["function"]["arguments"] = json.dumps(
tool_call["function"]["arguments"]
)
return [ToolCall(**tool_call) for tool_call in tool_calls]

View file

@ -0,0 +1,84 @@
{# Metadata #}
{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %}
{% set message_roles = ['system', 'user', 'assistant', 'tool'] %}
{% set tool_start = "<|tool_start|>" %}
{% set tool_end = "<|tool_end|>" %}
{%- set start_header = "<|start_header_id|>" -%}
{%- set end_header = "<|end_header_id|>\n" -%}
{%- set example_tool_call -%}[
{
"id": "tool_id_1342",
"function": {
"arguments": "arg_name": 3,
"name": "tool_name"
},
"type": "function"
},
{
"id": "example_id_13f42",
"function": {
"arguments": "example_arg": 1.0, "another_example_arg": true,
"name": "another_tool_name"
},
"type": "function"
}
]
{%- endset -%}
{%- set inital_system_prompt -%}You are an assistant that has access to the following set of tools, to call a tool:
1. Prefix calls with '{{ tool_start }}' and end calls with '{{ tool_end }}'
2. Ensure you use the correct type for arguments. For example, if the argument is a string, ensure it is enclosed in quotes, otherwise, it should not be.
3. Generate all calls using the following json tool call format. Here is a multi tool call example:
{{ tool_start }}{{ example_tool_call }}{{ tool_end }}
Here are the tools available for you to call:
{{ tools_json }}
{%- endset -%}
{%- set tool_reminder -%}Available Tools:
{{ tools_json }}
Tool Call Format Example:
{{ tool_start }}{{ example_tool_call }}
Prefix & Suffix: Begin tool calls with {{ tool_start }} and end with {{ tool_end }}.
Argument Types: Use correct data types for arguments (e.g., strings in quotes, numbers without).
{%- endset -%}
{# Template #}
{%- for message in messages -%}
{%- set role = message['role'] | lower -%}
{%- if role not in message_roles -%}
{{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles | join(', ') + ' are supported.') }}
{%- endif -%}
{%- set content = message['content'] | default('', true) | trim -%}
{%- if loop.first -%}
{{ bos_token }}{{ start_header }}{{ role }}{{ end_header }}
{{ inital_system_prompt }}
{{ content }}{{ eos_token }}
{%- endif -%}
{%- if not loop.first -%}
{{ start_header }}{{ role }}{{ end_header }}
{{ content }}
{%- if 'tool_calls_json' in message and message['tool_calls_json'] -%}
{{ tool_start }}{{ message['tool_calls_json']}}{{ tool_end }}
{%- endif -%}
{{ eos_token }}
{%- endif -%}
{%- endfor -%}
{%- if tool_precursor -%}
{{ start_header }}system{{ end_header }}
{{ tool_reminder }}{{ eos_token }}
{{ start_header }}assistant{{ end_header }}
{{ tool_precursor }}{{ tool_start }}
{%- else -%}
{{ start_header }}assistant{{ end_header }}
{%- endif -%}