diff --git a/.gitignore b/.gitignore
index 2761a6b..49aa517 100644
--- a/.gitignore
+++ b/.gitignore
@@ -192,6 +192,7 @@ templates/*
 !templates/place_your_templates_here.txt
 !templates/alpaca.jinja
 !templates/chatml.jinja
+!templates/chatml_with_headers_tool_calling.jinja
 
 # Sampler overrides folder
 sampler_overrides/*
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 98b5636..d9f2af6 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -852,6 +852,7 @@ class ExllamaV2Container:
             "text": "",
             "prompt_tokens": 0,
             "generation_tokens": 0,
+            "tool_calls": None,
             "offset": [],
             "token_probs": {},
             "logprobs": [],
@@ -864,6 +865,7 @@ class ExllamaV2Container:
                 joined_generation["finish_reason"] = finish_reason_gen.get(
                     "finish_reason"
                 )
+                joined_generation["stop_str"] = finish_reason_gen.get("stop_str")
             else:
                 joined_generation["finish_reason"] = "stop"
 
@@ -1068,6 +1070,15 @@ class ExllamaV2Container:
             gen_settings.top_p = 0
             gen_settings.typical = 0
 
+            logger.warning(
+                "".join(
+                    [
+                        "Temperature is set to 0. Overriding temp, ",
+                        "top_k, top_p, and typical to 1.0, 1, 0, and 0.",
+                    ]
+                )
+            )
+
         # Store the gen settings for logging purposes
         gen_settings_log_dict = vars(gen_settings)
 
@@ -1227,9 +1238,17 @@ class ExllamaV2Container:
                     log_response(request_id, full_response)
 
                 eos_reason = result.get("eos_reason")
-                finish_reason = (
-                    "length" if eos_reason == "max_new_tokens" else "stop"
-                )
+
+                stop_str = None
+                if eos_reason == "max_new_tokens":
+                    finish_reason = "length"
+                else:
+                    finish_reason = "stop"
+                    # Grab stop string if stop was the reason
+                    if eos_reason == "stop_token":
+                        stop_str = result.get("eos_triggering_token_str")
+                    elif eos_reason == "stop_string":
+                        stop_str = result.get("eos_triggering_string")
 
                 # Save the final result for metrics logging
                 metrics_result = result
@@ -1239,6 +1258,7 @@ class ExllamaV2Container:
                     "prompt_tokens": generation.get("prompt_tokens"),
                     "generated_tokens": generation.get("generated_tokens"),
                     "finish_reason": finish_reason,
+                    "stop_str": stop_str,
                 }
 
                 yield generation
diff --git a/common/sampling.py b/common/sampling.py
index 72552ce..e0eb158 100644
--- a/common/sampling.py
+++ b/common/sampling.py
@@ -33,7 +33,7 @@ class BaseSamplerRequest(BaseModel):
         examples=[512],
     )
 
-    stop: Optional[Union[str, List[str]]] = Field(
+    stop: Optional[Union[str, List[Union[str, int]]]] = Field(
         default_factory=lambda: get_default_sampler_value("stop", []),
         validation_alias=AliasChoices("stop", "stop_sequence"),
         description="Aliases: stop_sequence",
diff --git a/common/templating.py b/common/templating.py
index 7a59946..47ce7e8 100644
--- a/common/templating.py
+++ b/common/templating.py
@@ -3,7 +3,7 @@ import json
 import pathlib
 from importlib.metadata import version as package_version
 
-from typing import Optional
+from typing import List, Optional
 from jinja2 import Template, TemplateError
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger
@@ -18,6 +18,13 @@ class TemplateLoadError(Exception):
     pass
 
 
+class TemplateMetadata:
+    """Represents the parsed metadata from a template."""
+
+    stop_strings: List[str] = []
+    tool_starts: List[str] = []
+
+
 class PromptTemplate:
     """A template for chat completion prompts."""
 
@@ -27,23 +34,33 @@ class PromptTemplate:
     environment: ImmutableSandboxedEnvironment = ImmutableSandboxedEnvironment(
         trim_blocks=True, lstrip_blocks=True
     )
+    metadata: TemplateMetadata
 
-    def stop_strings(self, template_vars: dict):
-        """Appends extra stop strings if present in a chat template."""
+    def extract_metadata(self):
+        """Returns deserialized template metadata from a chat template."""
 
-        extra_stop_strings = []
-        template_module = self.template.make_module(template_vars)
+        template_metadata = TemplateMetadata()
+
+        template_module = self.template.make_module()
 
         if hasattr(template_module, "stop_strings"):
             if isinstance(template_module.stop_strings, list):
-                extra_stop_strings += template_module.stop_strings
+                template_metadata.stop_strings += template_module.stop_strings
             else:
                 logger.warning(
                     "Skipping append of stopping strings from chat template "
                     "because stop_strings isn't a list."
                 )
 
-        return extra_stop_strings
+        if hasattr(template_module, "tool_start"):
+            if isinstance(template_module.tool_start, str):
+                template_metadata.tool_starts.append(template_module.tool_start)
+
+        if hasattr(template_module, "tool_start_token"):
+            if isinstance(template_module.tool_start_token, int):
+                template_metadata.tool_starts.append(template_module.tool_start_token)
+
+        return template_metadata
 
     def render(self, template_vars: dict):
         """Get a prompt from a template and a list of messages."""
@@ -56,9 +73,8 @@ class PromptTemplate:
         )
 
         rendered_template = self.template.render(**template_vars)
-        template_stop_strings = self.stop_strings(template_vars)
 
-        return rendered_template, template_stop_strings
+        return rendered_template
 
     def compile(self, template_str: str):
         """Compiles and stores a jinja2 template"""
@@ -77,6 +93,7 @@ class PromptTemplate:
         self.name = name
         self.raw_template = raw_template
         self.template = self.compile(raw_template)
+        self.metadata = self.extract_metadata()
 
     @classmethod
     def from_file(self, prompt_template_name: str):
diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index b50e646..8977792 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -1,9 +1,11 @@
 from pydantic import BaseModel, Field
+from pydantic.json_schema import SkipJsonSchema
 from time import time
 from typing import Union, List, Optional, Dict
 from uuid import uuid4
 
 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
+from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
 
 
 class ChatCompletionLogprob(BaseModel):
@@ -19,12 +21,16 @@ class ChatCompletionLogprobs(BaseModel):
 class ChatCompletionMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = None
 
 
 class ChatCompletionRespChoice(BaseModel):
     # Index is 0 since we aren't using multiple choices
    index: int = 0
     finish_reason: Optional[str] = None
+
+    # Lets us understand why generation stopped and whether we need a tool_call
+    stop_str: Optional[str] = None
     message: ChatCompletionMessage
     logprobs: Optional[ChatCompletionLogprobs] = None
 
@@ -42,12 +48,28 @@ class ChatCompletionRequest(CommonCompletionRequest):
     # Messages
     # Take in a string as well even though it's not part of the OAI spec
     # support messages.content as a list of dict
-    messages: Union[str, List[Dict[str, Union[str, List[Dict[str, str]]]]]]
+
+    # WIP: this can probably be tightened, or maybe match the OAI lib type
+    # in openai\types\chat\chat_completion_message_param.py
+    messages: Union[str, List[Dict]]
     prompt_template: Optional[str] = None
     add_generation_prompt: Optional[bool] = True
     template_vars: Optional[dict] = {}
     response_prefix: Optional[str] = None
 
+    # tools follows the OAI schema format; functions is more flexible
+    # both are available in the chat template.
+
+    tools: Optional[List[ToolSpec]] = None
+    functions: Optional[List[Dict]] = None
+
+    # Typically collected from Chat Template.
+    # Don't include this in the OpenAPI docs
+    # TODO: Use these custom parameters
+    tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = None
+    tool_call_end: SkipJsonSchema[Optional[str]] = None
+    tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
+
 
 class ChatCompletionResponse(BaseModel):
     id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}")
diff --git a/endpoints/OAI/types/tools.py b/endpoints/OAI/types/tools.py
new file mode 100644
index 0000000..b9c0b33
--- /dev/null
+++ b/endpoints/OAI/types/tools.py
@@ -0,0 +1,58 @@
+from pydantic import BaseModel
+from typing import Dict, Literal
+
+tool_call_schema = {
+    "$schema": "http://json-schema.org/draft-07/schema#",
+    "type": "array",
+    "items": {
+        "type": "object",
+        "properties": {
+            "id": {"type": "string"},
+            "function": {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "arguments": {
+                        # Converted to OAI's string in post process
+                        "type": "object"
+                    },
+                },
+                "required": ["name", "arguments"],
+            },
+            "type": {"type": "string", "enum": ["function"]},
+        },
+        "required": ["id", "function", "type"],
+    },
+}
+
+
+class Function(BaseModel):
+    """Represents a description of a tool function."""
+
+    name: str
+    description: str
+    parameters: Dict[str, object]
+
+
+class ToolSpec(BaseModel):
+    """Wrapper for an inner tool function."""
+
+    function: Function
+    type: Literal["function"]
+
+
+class Tool(BaseModel):
+    """Represents the called function inside an OAI tool call."""
+
+    name: str
+
+    # Makes more sense to be a dict, but OAI knows best
+    arguments: str
+
+
+class ToolCall(BaseModel):
+    """Represents an OAI tool call."""
+
+    id: str
+    function: Tool
+    type: Literal["function"]
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index 80b5715..e15e820 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -5,6 +5,7 @@ import pathlib
 from asyncio import CancelledError
 from copy import deepcopy
 from typing import List, Optional
+import json
 
 from fastapi import HTTPException, Request
 from jinja2 import TemplateError
@@ -30,6 +31,7 @@ from endpoints.OAI.types.chat_completion import (
 )
 from endpoints.OAI.types.common import UsageStats
 from endpoints.OAI.utils.completion import _stream_collector
+from endpoints.OAI.types.tools import ToolCall
 
 
 def _create_response(
@@ -46,6 +48,10 @@ def _create_response(
         role="assistant", content=unwrap(generation.get("text"), "")
     )
 
+    tool_calls = generation["tool_calls"]
+    if tool_calls:
+        message.tool_calls = postprocess_tool_call(tool_calls)
+
     logprob_response = None
 
     token_probs = unwrap(generation.get("token_probs"), {})
@@ -72,6 +78,7 @@ def _create_response(
     choice = ChatCompletionRespChoice(
         index=index,
         finish_reason=generation.get("finish_reason"),
+        stop_str=generation.get("stop_str"),
         message=message,
         logprobs=logprob_response,
     )
@@ -119,7 +126,16 @@ def _create_stream_chunk(
             finish_reason=generation.get("finish_reason"),
         )
 
+        # Let's check if we have tool calls since we are at the end of the generation
+        if "tool_calls" in generation:
+            tool_calls = generation["tool_calls"]
+            message = ChatCompletionMessage(
+                tool_calls=postprocess_tool_call(tool_calls)
+            )
+            choice.delta = message
+        choices.append(choice)
+
     else:
         message = ChatCompletionMessage(
             role="assistant", content=unwrap(generation.get("text"), "")
@@ -162,7 +178,29 @@ def _create_stream_chunk(
     return chunk
 
-def format_prompt_with_template(data: ChatCompletionRequest):
+def _append_template_metadata(data: ChatCompletionRequest):
+    """Adding metadata is a one-time process."""
+
+    template_metadata = model.container.prompt_template.metadata
+
+    # Stop strings
+    if isinstance(data.stop, str):
+        data.stop = [data.stop] + template_metadata.stop_strings
+    else:
+        data.stop += template_metadata.stop_strings
+
+    # Tool call start strings
+    if template_metadata.tool_starts:
+        if data.tool_call_start is None:
+            data.tool_call_start = template_metadata.tool_starts
+
+        # Append to stop strings to halt for a tool call generation
+        data.stop.extend(template_metadata.tool_starts)
+
+
+def format_prompt_with_template(
+    data: ChatCompletionRequest, tool_precursor: Optional[str] = None
+):
     """
     Compile the prompt and get any additional stop strings from the template.
     Template stop strings can be overriden by sampler overrides if force is true.
     """
@@ -187,18 +225,22 @@ def format_prompt_with_template(data: ChatCompletionRequest):
                 "",
             )
 
+        if "tool_calls" in message:
+            message["tool_calls_json"] = json.dumps(message["tool_calls"], indent=2)
+
     # Overwrite any protected vars with their values
     data.template_vars.update(
         {
             "messages": data.messages,
             "add_generation_prompt": data.add_generation_prompt,
+            "tools_json": json.dumps(data.model_dump()["tools"], indent=2),
+            "functions_json": json.dumps(data.functions, indent=2),
+            "tool_precursor": tool_precursor,
             **special_tokens_dict,
         }
     )
 
-    prompt, template_stop_strings = model.container.prompt_template.render(
-        data.template_vars
-    )
+    prompt = model.container.prompt_template.render(data.template_vars)
 
     # Append response prefix if present
     if data.response_prefix:
@@ -216,11 +258,8 @@ def format_prompt_with_template(data: ChatCompletionRequest):
 
     if bos_token and prompt.startswith(bos_token):
         prompt = prompt.removeprefix(bos_token)
 
-    # Append template stop strings
-    if isinstance(data.stop, str):
-        data.stop = [data.stop] + template_stop_strings
-    else:
-        data.stop += template_stop_strings
+    # Add template metadata
+    _append_template_metadata(data)
 
     return prompt
@@ -271,6 +310,9 @@ async def stream_generate_chat_completion(
 
             gen_tasks.append(gen_task)
 
+        # We need to keep track of the text generated so we can resume the tool calls
+        current_generation_text = ""
+
         # Consumer loop
         while True:
             if disconnect_task.done():
@@ -280,6 +322,19 @@ async def stream_generate_chat_completion(
                 )
 
             generation = await gen_queue.get()
 
+            # Let's only append the text if we need it for tool calls later
+            if data.tool_call_start and "text" in generation:
+                current_generation_text += generation["text"]
+
+            # Check if we are running a tool model and have hit a stop condition
+            if data.tool_call_start and "stop_str" in generation:
+                generations = await generate_tool_calls(
+                    data,
+                    [generation],
+                    request,
+                    current_generations=current_generation_text,
+                )
+                generation = generations[0]  # We only have one generation in this case
             # Stream collector will push an exception to the queue if it fails
             if isinstance(generation, Exception):
@@ -344,6 +399,11 @@ async def generate_chat_completion(
         )
 
         generations = await asyncio.gather(*gen_tasks)
+
+        # Let's not waste our time if we aren't running a tool model
+        if data.tool_call_start:
+            generations = await generate_tool_calls(data, generations, request)
+
         response = _create_response(request.state.id, generations, model_path.name)
 
         logger.info(f"Finished chat completion request {request.state.id}")
@@ -358,3 +418,54 @@ async def generate_chat_completion(
 
         # Server error if there's a generation exception
         raise HTTPException(503, error_message) from exc
+
+
+async def generate_tool_calls(
+    data: ChatCompletionRequest,
+    generations: List[str],
+    request: Request,
+    current_generations: str = None,
+):
+    gen_tasks: List[asyncio.Task] = []
+    tool_idx: List[int] = []
+
+    # Copy to make sure the parent JSON schema doesn't get modified
+    # FIXME: May not be necessary depending on how the codebase evolves
+    tool_data = deepcopy(data)
+    tool_data.json_schema = tool_data.tool_call_schema
+    gen_params = tool_data.to_gen_params()
+
+    for idx, gen in enumerate(generations):
+        if gen["stop_str"] in tool_data.tool_call_start:
+            if "text" in gen:
+                # Non-streaming: all generations will have the text they generated
+                pre_tool_prompt = format_prompt_with_template(data, gen["text"])
+            elif current_generations is not None:
+                # Streaming: we won't have text in the generation,
+                # so we'll have to use current_generations
+                pre_tool_prompt = format_prompt_with_template(data, current_generations)
+
+            gen_tasks.append(
+                asyncio.create_task(
+                    model.container.generate(
+                        pre_tool_prompt, request.state.id, **gen_params
+                    )
+                )
+            )
+            tool_idx.append(idx)
+
+    tool_calls = await asyncio.gather(*gen_tasks)
+    for outer_idx in range(0, len(tool_idx)):
+        gen_idx = tool_idx[outer_idx]
+        generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
+
+    return generations
+
+
+def postprocess_tool_call(call_str: str) -> List[ToolCall]:
+    tool_calls = json.loads(call_str)
+    for tool_call in tool_calls:
+        tool_call["function"]["arguments"] = json.dumps(
+            tool_call["function"]["arguments"]
+        )
+    return [ToolCall(**tool_call) for tool_call in tool_calls]
diff --git a/templates/chatml_with_headers_tool_calling.jinja b/templates/chatml_with_headers_tool_calling.jinja
new file mode 100644
index 0000000..ecd1d2c
--- /dev/null
+++ b/templates/chatml_with_headers_tool_calling.jinja
@@ -0,0 +1,84 @@
+{# Metadata #}
+{% set stop_strings = ["<|im_start|>", "<|im_end|>"] %}
+{% set message_roles = ['system', 'user', 'assistant', 'tool'] %}
+{% set tool_start = "<|tool_start|>" %}
+{% set tool_end = "<|tool_end|>" %}
+{%- set start_header = "<|start_header_id|>" -%}
+{%- set end_header = "<|end_header_id|>\n" -%}
+
+{%- set example_tool_call -%}[
+    {
+        "id": "tool_id_1342",
+        "function": {
+            "arguments": {"arg_name": 3},
+            "name": "tool_name"
+        },
+        "type": "function"
+    },
+    {
+        "id": "example_id_13f42",
+        "function": {
+            "arguments": {"example_arg": 1.0, "another_example_arg": true},
+            "name": "another_tool_name"
+        },
+        "type": "function"
+    }
+]
+{%- endset -%}
+
+{%- set initial_system_prompt -%}You are an assistant that has access to the following set of tools. To call a tool:
+1. Prefix calls with '{{ tool_start }}' and end calls with '{{ tool_end }}'
+2. Ensure you use the correct type for arguments. For example, if the argument is a string, ensure it is enclosed in quotes, otherwise, it should not be.
+3. Generate all calls using the following JSON tool call format. Here is a multi-tool call example:
+
+{{ tool_start }}{{ example_tool_call }}{{ tool_end }}
+
+Here are the tools available for you to call:
+{{ tools_json }}
+{%- endset -%}
+
+{%- set tool_reminder -%}Available Tools:
+{{ tools_json }}
+
+Tool Call Format Example:
+{{ tool_start }}{{ example_tool_call }}
+
+Prefix & Suffix: Begin tool calls with {{ tool_start }} and end with {{ tool_end }}.
+Argument Types: Use correct data types for arguments (e.g., strings in quotes, numbers without).
+
+{%- endset -%}
+
+{# Template #}
+
+{%- for message in messages -%}
+    {%- set role = message['role'] | lower -%}
+    {%- if role not in message_roles -%}
+        {{ raise_exception('Invalid role ' + message['role'] + '. Only ' + message_roles | join(', ') + ' are supported.') }}
+    {%- endif -%}
+
+    {%- set content = message['content'] | default('', true) | trim -%}
+    {%- if loop.first -%}
+{{ bos_token }}{{ start_header }}{{ role }}{{ end_header }}
+{{ initial_system_prompt }}
+
+{{ content }}{{ eos_token }}
+    {%- endif -%}
+
+    {%- if not loop.first -%}
+{{ start_header }}{{ role }}{{ end_header }}
+{{ content }}
+    {%- if 'tool_calls_json' in message and message['tool_calls_json'] -%}
+{{ tool_start }}{{ message['tool_calls_json']}}{{ tool_end }}
+    {%- endif -%}
+{{ eos_token }}
+
+    {%- endif -%}
+{%- endfor -%}
+
+{%- if tool_precursor -%}
+{{ start_header }}system{{ end_header }}
+{{ tool_reminder }}{{ eos_token }}
+{{ start_header }}assistant{{ end_header }}
+{{ tool_precursor }}{{ tool_start }}
+{%- else -%}
+{{ start_header }}assistant{{ end_header }}
+{%- endif -%}
\ No newline at end of file
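
Example usage (illustrative sketch, not part of the diff): one way a client might exercise the new tools request field and read the tool_calls response field once a tool-calling template is loaded. The base URL, port, API key header, model name, and the get_weather tool below are placeholder assumptions; the payload shape follows the standard OAI chat completions format these endpoints implement.

# Illustrative client sketch. Host/port, auth header, model name, and the
# get_weather tool are placeholders, not values defined by this diff.
import json

import requests

payload = {
    "model": "my-model",  # placeholder model name
    "messages": [{"role": "user", "content": "What's the weather in Paris?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # placeholder tool
                "description": "Look up the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
}

resp = requests.post(
    "http://localhost:5000/v1/chat/completions",  # placeholder host/port
    headers={"Authorization": "Bearer sk-placeholder"},  # placeholder key
    json=payload,
    timeout=120,
)
resp.raise_for_status()

choice = resp.json()["choices"][0]
for call in choice["message"].get("tool_calls") or []:
    # postprocess_tool_call serializes arguments back to a JSON string,
    # so decode before use
    args = json.loads(call["function"]["arguments"])
    print(call["function"]["name"], args)

When the model opts to call a tool, the choice's message carries tool_calls with arguments serialized to a JSON string (see postprocess_tool_call above); otherwise content holds a normal reply.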