diff --git a/common/templating.py b/common/templating.py
index 3a9347b..a33895e 100644
--- a/common/templating.py
+++ b/common/templating.py
@@ -4,6 +4,8 @@ import traceback
 import aiofiles
 import json
 import pathlib
+from dataclasses import dataclass, field
+from datetime import datetime
 from importlib.metadata import version as package_version
 from typing import List, Optional
 from jinja2 import Template, TemplateError
@@ -11,7 +13,6 @@ from jinja2.ext import loopcontrols
 from jinja2.sandbox import ImmutableSandboxedEnvironment
 from loguru import logger
 from packaging import version
-from datetime import datetime
 
 from common.utils import unwrap
 
@@ -23,11 +24,12 @@ class TemplateLoadError(Exception):
     pass
 
 
+@dataclass
 class TemplateMetadata:
     """Represents the parsed metadata from a template."""
 
-    stop_strings: List[str] = []
-    tool_starts: List[str] = []
+    stop_strings: List[str] = field(default_factory=list)
+    tool_start: Optional[str] = None
 
 
 class PromptTemplate:
@@ -72,11 +74,7 @@ class PromptTemplate:
 
         if hasattr(template_module, "tool_start"):
             if isinstance(template_module.tool_start, str):
-                template_metadata.tool_starts.append(template_module.tool_start)
-
-        if hasattr(template_module, "tool_start_token"):
-            if isinstance(template_module.tool_start_token, int):
-                template_metadata.tool_starts.append(template_module.tool_start_token)
+                template_metadata.tool_start = template_module.tool_start
 
         self.metadata = template_metadata
         return template_metadata
diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py
index 405a609..03e69de 100644
--- a/endpoints/OAI/types/chat_completion.py
+++ b/endpoints/OAI/types/chat_completion.py
@@ -5,7 +5,7 @@ from typing import Literal, Union, List, Optional, Dict
 from uuid import uuid4
 
 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
-from endpoints.OAI.types.tools import ToolSpec, ToolCall, tool_call_schema
+from endpoints.OAI.types.tools import ToolSpec, ToolCall
 
 
 class ChatCompletionLogprob(BaseModel):
@@ -73,15 +73,6 @@ class ChatCompletionRequest(CommonCompletionRequest):
     tools: Optional[List[ToolSpec]] = None
     functions: Optional[List[Dict]] = None
 
-    # Typically collected from Chat Template.
-    # Don't include this in the OpenAPI docs
-    # TODO: Use these custom parameters
-    tool_call_start: SkipJsonSchema[Optional[List[Union[str, int]]]] = Field(
-        default_factory=list
-    )
-    tool_call_end: SkipJsonSchema[Optional[str]] = None
-    tool_call_schema: SkipJsonSchema[Optional[dict]] = tool_call_schema
-
     # Chat completions requests do not have a BOS token preference. Backend
     # respects the tokenization config for the individual model.
     add_bos_token: Optional[bool] = None
diff --git a/endpoints/OAI/types/tools.py b/endpoints/OAI/types/tools.py
index b9c0b33..c9ccd8b 100644
--- a/endpoints/OAI/types/tools.py
+++ b/endpoints/OAI/types/tools.py
@@ -1,30 +1,6 @@
 from pydantic import BaseModel
 from typing import Dict, Literal
 
-tool_call_schema = {
-    "$schema": "http://json-schema.org/draft-07/schema#",
-    "type": "array",
-    "items": {
-        "type": "object",
-        "properties": {
-            "id": {"type": "string"},
-            "function": {
-                "type": "object",
-                "properties": {
-                    "name": {"type": "string"},
-                    "arguments": {
-                        # Converted to OAI's string in post process
-                        "type": "object"
-                    },
-                },
-                "required": ["name", "arguments"],
-            },
-            "type": {"type": "string", "enum": ["function"]},
-        },
-        "required": ["id", "function", "type"],
-    },
-}
-
 
 class Function(BaseModel):
     """Represents a description of a tool function."""
diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py
index bfed45a..10d3aec 100644
--- a/endpoints/OAI/utils/chat_completion.py
+++ b/endpoints/OAI/utils/chat_completion.py
@@ -1,7 +1,6 @@
 """Chat completion utilities for OAI server."""
 
 import asyncio
-import json
 import pathlib
 from asyncio import CancelledError
 from typing import List, Optional
@@ -30,7 +29,7 @@ from endpoints.OAI.types.chat_completion import (
 )
 from endpoints.OAI.types.common import UsageStats
 from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector
-from endpoints.OAI.utils.tools import ToolCallProcessor
+from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA
 
 
 def _create_response(
@@ -209,12 +208,9 @@ async def _append_template_metadata(data: ChatCompletionRequest, template_vars:
         else:
             data.stop.extend(template_metadata.stop_strings)
 
-    # Tool call start strings
-    if template_metadata.tool_starts:
-        data.tool_call_start.extend(template_metadata.tool_starts)
-
-        # Append to stop strings to halt for a tool call generation
-        data.stop.extend(template_metadata.tool_starts)
+    # if a tool start is present, append it to stopping strings
+    if template_metadata.tool_start:
+        data.stop.append(template_metadata.tool_start)
 
 
 async def format_messages_with_template(
@@ -255,9 +251,7 @@ async def format_messages_with_template(
     return prompt, mm_embeddings, template_vars
 
 
-async def apply_chat_template(
-    data: ChatCompletionRequest, tool_precursor: Optional[str] = None
-):
+async def apply_chat_template(data: ChatCompletionRequest):
     """
     Compile the prompt and get any additional stop strings from the template.
     Template stop strings can be overriden by sampler overrides if force is true.
@@ -271,10 +265,7 @@ async def apply_chat_template(
         {
             "add_generation_prompt": data.add_generation_prompt,
             "tools": tools,
-            "tools_json": json.dumps(tools, indent=2),
             "functions": data.functions,
-            "functions_json": json.dumps(data.functions, indent=2),
-            "tool_precursor": tool_precursor,
         }
     )
 
@@ -332,6 +323,7 @@ async def stream_generate_chat_completion(
     abort_event = asyncio.Event()
     gen_queue = asyncio.Queue()
     gen_tasks: List[asyncio.Task] = []
+    tool_start = model.container.prompt_template.metadata.tool_start
    disconnect_task = asyncio.create_task(request_disconnect_loop(request))
 
     try:
@@ -355,7 +347,7 @@
 
             gen_tasks.append(gen_task)
 
-        # We need to keep track of the text generated so we can resume the tool calls
+        # Text accumulation for tool calls
         current_generation_text = ""
 
         # Consumer loop
@@ -367,19 +359,21 @@
                 )
             generation = await gen_queue.get()
 
-            # lets only append the text if we need it for tool calls later
-            if data.tool_call_start and "text" in generation:
-                current_generation_text += generation["text"]
-
-            # check if we are running a tool model, and that we are at stop
-            if data.tool_call_start and "stop_str" in generation:
-                generations = await generate_tool_calls(
-                    data,
-                    [generation],
-                    request,
-                    current_generations=current_generation_text,
-                )
-                generation = generations[0]  # We only have one generation in this case
+            # Handle options if a tool model is present
+            if tool_start:
+                if "stop_str" in generation:
+                    generations = await generate_tool_calls(
+                        data,
+                        [generation],
+                        request,
+                        current_generation_text=current_generation_text,
+                    )
+
+                    # Only one generation present in this case
+                    generation = generations[0]
+                elif "text" in generation:
+                    current_generation_text += generation["text"]
 
             # Stream collector will push an exception to the queue if it fails
             if isinstance(generation, Exception):
@@ -428,6 +422,7 @@ async def generate_chat_completion(
     model_path: pathlib.Path,
 ):
     gen_tasks: List[asyncio.Task] = []
+    tool_start = model.container.prompt_template.metadata.tool_start
 
     try:
         logger.info(f"Received chat completion request {request.state.id}")
@@ -448,8 +443,8 @@
 
         generations = await asyncio.gather(*gen_tasks)
 
-        # Let's not waste our time if we arn't running a tool model
-        if data.tool_call_start:
+        # Check all the generations and see if a tool call is required
+        if tool_start:
             generations = await generate_tool_calls(data, generations, request)
 
         response = _create_response(request.state.id, generations, model_path.name)
@@ -472,51 +467,52 @@ async def generate_tool_calls(
     data: ChatCompletionRequest,
     generations: List[str],
     request: Request,
-    current_generations: str = None,
+    current_generation_text: str = None,
 ):
     gen_tasks: List[asyncio.Task] = []
+    tool_start = model.container.prompt_template.metadata.tool_start
+
+    # Tracks which generations asked for a tool call
     tool_idx: List[int] = []
 
     # Copy to make sure the parent JSON schema doesn't get modified
-    # FIXME: May not be necessary depending on how the codebase evolves
     tool_data = data.model_copy(deep=True)
-    tool_data.json_schema = tool_data.tool_call_schema
+    tool_data.json_schema = TOOL_CALL_SCHEMA
 
     for idx, gen in enumerate(generations):
-        if gen["stop_str"] in tool_data.tool_call_start:
-            logger.info(
-                f"Detected tool call in chat completion request {request.state.id}"
-            )
+        if gen["stop_str"] != tool_start:
+            continue
 
-            if "text" in gen:
-                # non streaming, all generations will have the text they generated
-                pre_tool_prompt, embeddings = await apply_chat_template(
-                    data, gen["text"]
-                )
-            elif current_generations is not None:
-                # streaming, we wont have text in the generation,
-                # we'll have to use the current_generations
-                pre_tool_prompt, embeddings = await apply_chat_template(
-                    data, current_generations
-                )
+        logger.info(f"Detected tool call in chat completion request {request.state.id}")
 
-            request_id = _parse_gen_request_id(data.n, request.state.id, idx)
+        # Append the existing generation as part of the response prefix
+        precursor_text = current_generation_text or gen.get("text")
+        if precursor_text:
+            tool_data.response_prefix = precursor_text
 
-            gen_tasks.append(
-                asyncio.create_task(
-                    model.container.generate(
-                        request_id,
-                        pre_tool_prompt,
-                        tool_data,
-                        mm_embeddings=embeddings,
-                    )
+        pre_tool_prompt, embeddings = await apply_chat_template(tool_data)
+
+        gen_request_id = _parse_gen_request_id(data.n, request.state.id, idx)
+        tool_request_id = f"{gen_request_id}-tool"
+
+        gen_tasks.append(
+            asyncio.create_task(
+                model.container.generate(
+                    tool_request_id,
+                    pre_tool_prompt,
+                    tool_data,
+                    mm_embeddings=embeddings,
                 )
             )
-            tool_idx.append(idx)
+        )
 
-    tool_calls = await asyncio.gather(*gen_tasks)
-    for outer_idx in range(0, len(tool_idx)):
-        gen_idx = tool_idx[outer_idx]
-        generations[gen_idx]["tool_calls"] = tool_calls[outer_idx]["text"]
+        tool_idx.append(idx)
+
+    if len(tool_idx) > 0:
+        tool_calls = await asyncio.gather(*gen_tasks)
+
+        # Map tool calls to their appropriate generation
+        for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True):
+            generations[gen_idx]["tool_calls"] = tool_call["text"]
 
     return generations
diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py
index 7650e96..8473d60 100644
--- a/endpoints/OAI/utils/tools.py
+++ b/endpoints/OAI/utils/tools.py
@@ -5,6 +5,31 @@ from typing import List
 
 from endpoints.OAI.types.tools import ToolCall
 
+TOOL_CALL_SCHEMA = {
+    "$schema": "http://json-schema.org/draft-07/schema#",
+    "type": "array",
+    "items": {
+        "type": "object",
+        "properties": {
+            "id": {"type": "string"},
+            "function": {
+                "type": "object",
+                "properties": {
+                    "name": {"type": "string"},
+                    "arguments": {
+                        # Converted to OAI's string in post process
+                        "type": "object"
+                    },
+                },
+                "required": ["name", "arguments"],
+            },
+            "type": {"type": "string", "enum": ["function"]},
+        },
+        "required": ["id", "function", "type"],
+    },
+}
+
+
 class ToolCallProcessor:
     @staticmethod
     def from_json(tool_calls_str: str) -> List[ToolCall]:
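A note on the TemplateMetadata change in common/templating.py: on the old plain class, stop_strings: List[str] = [] was a single class-level list shared by every instance, so stop strings appended for one template could leak into all others. The @dataclass form with field(default_factory=list) builds a fresh list per instance. A minimal standalone sketch of the difference (illustrative only; SharedDefault, PerInstanceDefault, and the marker string are hypothetical, not code from this diff):

    from dataclasses import dataclass, field
    from typing import List, Optional


    class SharedDefault:
        # Old pattern: one list object hangs off the class and is shared
        stop_strings: List[str] = []


    @dataclass
    class PerInstanceDefault:
        # New pattern: default_factory runs per instance during __init__
        stop_strings: List[str] = field(default_factory=list)
        tool_start: Optional[str] = None


    a, b = SharedDefault(), SharedDefault()
    a.stop_strings.append("<|tool_start|>")
    print(b.stop_strings)  # ['<|tool_start|>'] -- mutation leaks across instances

    c, d = PerInstanceDefault(), PerInstanceDefault()
    c.stop_strings.append("<|tool_start|>")
    print(d.stop_strings)  # [] -- each instance owns its own list

This also matches how the new tool-call flow consumes the field: the single tool_start string is appended to the request's stop strings, and a generation that halts on exactly that stop string triggers a second, TOOL_CALL_SCHEMA-constrained pass whose JSON output ToolCallProcessor.from_json can parse.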