diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index 5cd144d..edd9b34 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -361,6 +361,8 @@ async def stream_generate_chat_completion( if tool_start: if "stop_str" in generation: generations = await generate_tool_calls( + prompt, + embeddings, data, [generation], request, @@ -442,7 +444,9 @@ async def generate_chat_completion( # Check all the generations and see if a tool call is required if tool_start: - generations = await generate_tool_calls(data, generations, request) + generations = await generate_tool_calls( + prompt, embeddings, data, generations, request + ) response = _create_response(request.state.id, generations, model_path.name) @@ -461,6 +465,8 @@ async def generate_chat_completion( async def generate_tool_calls( + prompt: str, + embeddings: MultimodalEmbeddingWrapper, data: ChatCompletionRequest, generations: List[str], request: Request, @@ -482,12 +488,10 @@ async def generate_tool_calls( logger.info(f"Detected tool call in chat completion request {request.state.id}") - # Append the existing generation as part of the response prefix + # Append the existing generation text if present precursor_text = current_generation_text or gen.get("text") if precursor_text: - tool_data.response_prefix = precursor_text - - pre_tool_prompt, embeddings = await apply_chat_template(tool_data) + prompt = prompt + precursor_text gen_request_id = _parse_gen_request_id(data.n, request.state.id, idx) tool_request_id = f"{gen_request_id}-tool" @@ -496,7 +500,7 @@ async def generate_tool_calls( asyncio.create_task( model.container.generate( tool_request_id, - pre_tool_prompt, + prompt, tool_data, mm_embeddings=embeddings, )