[WIP] OpenAI Tools Support/Function calling (#154)

* returning stop str if exists from gen

* added chat template for firefunctionv2

* pulling tool vars from template

* adding parsing for tool inputs/outputs

* passing tool data from endpoint to chat template, adding tool_start to the stop list

* loosened typing on the response tool call, leaning more on the user supplying a quality schema if they want a particular format

* non streaming generation prototype

* cleaning template

* Continued work with type, ingestion into template, and chat template for fire func

* Correction: a streaming tool call comes back as a delta obj, not inside chatcomprespchoice, per chat_completion_chunk.py in the OAI lib.

* Ruff formatting

* Moved stop string and tool updates out of prompt creation func

Updated tool pydantic to match OAI

Support for streaming

Updated generate tool calls to use flag within chat_template and insert tool reminder

* Llama 3.1 chat templates

Updated fire func template

* renamed llama3.1 to chatml_with_headers..

* update name of template

* Support for specifying a tool start token rather than a string.

Simplified tool_params

Warning when gen_settings are being overridden because the user set temp to 0

Corrected schema and tools to the correct types for function args (they were str for some reason)

* draft groq tool use model template

* changed headers to vars for readability (but mostly because some models are weird about newlines after headers, so this is an easier way to change globally)

* Clean up comments and code in chat comp

* Post-processed the tool call to meet the OAI spec rather than forcing the model to write JSON in a string in the middle of the call.

* changed the example back to args as JSON rather than a string of JSON
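The post-processing step described above can be sketched with the stdlib alone. This is a hedged illustration, not the project's actual code: the helper name `to_oai_tool_call` and the id format are hypothetical; only the output shape (function arguments serialized to a JSON string) follows the OpenAI spec.

```python
import json
import uuid


def to_oai_tool_call(name: str, args: dict) -> dict:
    """Wrap a parsed model tool call in an OpenAI-shaped dict.

    The model emits `args` as a plain JSON object; the server then
    serializes it to a string per the OpenAI spec, instead of forcing
    the model to write JSON inside a string mid-call.
    """
    return {
        "id": f"call_{uuid.uuid4().hex[:12]}",
        "type": "function",
        "function": {
            "name": name,
            # OpenAI expects arguments as a JSON-encoded string
            "arguments": json.dumps(args),
        },
    }


call = to_oai_tool_call("get_weather", {"city": "Berlin"})
```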

* Standardize chat templates to each other

* cleaning/rewording

* stop elements can also be ints (tokens)

* Cleaning/formatting

* added special tokens for tools and tool_response as specified in description

* Cleaning

* removing aux templates - going to live in llm-promp-templates repo instead

* Tree: Format

Signed-off-by: kingbri <bdashore3@proton.me>

* Chat Completions: Don't include internal tool variables in OpenAPI

Use SkipJsonSchema to suppress inclusion in the OpenAPI JSON. The
location of these variables may need to be changed in the future.

Signed-off-by: kingbri <bdashore3@proton.me>
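A minimal sketch of the SkipJsonSchema approach, assuming pydantic v2. The model and field names here are hypothetical, not the project's actual request class:

```python
from typing import List

from pydantic import BaseModel, Field
from pydantic.json_schema import SkipJsonSchema


class ChatCompletionRequest(BaseModel):
    # Public field: appears in the generated OpenAPI/JSON schema
    messages: List[dict] = Field(default_factory=list)

    # Internal tool variable: excluded from the schema entirely
    tool_precursor: SkipJsonSchema[str] = ""


props = ChatCompletionRequest.model_json_schema()["properties"]
```

The field still exists and validates at runtime; only its schema entry is suppressed, which keeps internal plumbing out of the documented API surface.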

* Templates: Deserialize metadata on template load

Since we're only looking for specific template variables that are
static in the template, it makes more sense to render when the template
is initialized.

Signed-off-by: kingbri <bdashore3@proton.me>
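The load-time extraction relies on jinja2 exposing top-level `{% set %}` assignments as attributes of a template module. A rough sketch of the mechanism (template content here is illustrative only):

```python
from jinja2 import Template

# Top-level {% set %} assignments become attributes of the template
# module, so static metadata can be read once at load time without
# rendering any messages.
template = Template(
    "{% set stop_strings = ['<|eot_id|>'] %}"
    "{% set tool_start = '<tool_call>' %}"
    "Hello {{ name }}"
)

module = template.make_module()
stop_strings = getattr(module, "stop_strings", [])
tool_start = getattr(module, "tool_start", None)
```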

* Tools: Fix comments

Adhere to the format style of comments in the rest of the project.

Signed-off-by: kingbri <bdashore3@proton.me>

---------

Co-authored-by: Ben Gitter <gitterbd@gmail.com>
Signed-off-by: kingbri <bdashore3@proton.me>
Ben Gitter 2024-08-17 00:16:25 -04:00 committed by GitHub
parent 9cc0e70098
commit 70b9fc95de
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 336 additions and 23 deletions


@@ -33,7 +33,7 @@ class BaseSamplerRequest(BaseModel):
         examples=[512],
     )
-    stop: Optional[Union[str, List[str]]] = Field(
+    stop: Optional[Union[str, List[Union[str, int]]]] = Field(
         default_factory=lambda: get_default_sampler_value("stop", []),
         validation_alias=AliasChoices("stop", "stop_sequence"),
         description="Aliases: stop_sequence",
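With the widened `stop` type above, an entry may be either a string or a raw token ID. A hedged sketch of how a server might separate the two; the function name and tuple return shape are assumptions, not the project's actual API:

```python
from typing import List, Tuple, Union


def split_stop_conditions(
    stop: List[Union[str, int]],
) -> Tuple[List[str], List[int]]:
    """Separate stop strings from stop token IDs.

    Ints are treated as raw token IDs, so callers can pass e.g. an
    EOS token directly instead of its decoded text.
    """
    stop_strings = [s for s in stop if isinstance(s, str)]
    stop_tokens = [s for s in stop if isinstance(s, int)]
    return stop_strings, stop_tokens


strings, tokens = split_stop_conditions(["</s>", 128009, "<|eot_id|>"])
```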


@@ -3,7 +3,7 @@
 import json
 import pathlib
 from importlib.metadata import version as package_version
-from typing import Optional
+from typing import List, Optional
 from jinja2 import Template, TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger
@@ -18,6 +18,13 @@ class TemplateLoadError(Exception):

     pass


+class TemplateMetadata:
+    """Represents the parsed metadata from a template."""
+
+    stop_strings: List[str] = []
+    tool_starts: List[str] = []
+
+
 class PromptTemplate:
     """A template for chat completion prompts."""
@@ -27,23 +34,33 @@ class PromptTemplate:
     environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
         trim_blocks=True, lstrip_blocks=True
     )
+    metadata: TemplateMetadata

-    def stop_strings(self, template_vars: dict):
-        """Appends extra stop strings if present in a chat template."""
+    def extract_metadata(self):
+        """Returns deserialized template metadata from a chat template."""

-        extra_stop_strings = []
-        template_module = self.template.make_module(template_vars)
+        template_metadata = TemplateMetadata()
+        template_module = self.template.make_module()

         if hasattr(template_module, "stop_strings"):
             if isinstance(template_module.stop_strings, list):
-                extra_stop_strings += template_module.stop_strings
+                template_metadata.stop_strings += template_module.stop_strings
             else:
                 logger.warning(
                     "Skipping append of stopping strings from chat template "
                     "because stop_strings isn't a list."
                 )

-        return extra_stop_strings
+        if hasattr(template_module, "tool_start"):
+            if isinstance(template_module.tool_start, str):
+                template_metadata.tool_starts.append(template_module.tool_start)
+
+        if hasattr(template_module, "tool_start_token"):
+            if isinstance(template_module.tool_start_token, int):
+                template_metadata.tool_starts.append(template_module.tool_start_token)
+
+        return template_metadata

     def render(self, template_vars: dict):
         """Get a prompt from a template and a list of messages."""
@@ -56,9 +73,8 @@ class PromptTemplate:
             )

         rendered_template = self.template.render(**template_vars)
-        template_stop_strings = self.stop_strings(template_vars)

-        return rendered_template, template_stop_strings
+        return rendered_template

     def compile(self, template_str: str):
         """Compiles and stores a jinja2 template"""
@@ -77,6 +93,7 @@ class PromptTemplate:
         self.name = name
         self.raw_template = raw_template
         self.template = self.compile(raw_template)
+        self.metadata = self.extract_metadata()

     @classmethod
     def from_file(self, prompt_template_name: str):