API: Don't do a second re-render when tool calling
Re-rendering the template is an expensive operation when it's possible to just concatenate the prompt and current generation text together. Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent
3dfa965019
commit
5b1db3ad83
1 changed files with 10 additions and 6 deletions
|
|
@ -361,6 +361,8 @@ async def stream_generate_chat_completion(
|
|||
if tool_start:
|
||||
if "stop_str" in generation:
|
||||
generations = await generate_tool_calls(
|
||||
prompt,
|
||||
embeddings,
|
||||
data,
|
||||
[generation],
|
||||
request,
|
||||
|
|
@ -442,7 +444,9 @@ async def generate_chat_completion(
|
|||
|
||||
# Check all the generations and see if a tool call is required
|
||||
if tool_start:
|
||||
generations = await generate_tool_calls(data, generations, request)
|
||||
generations = await generate_tool_calls(
|
||||
prompt, embeddings, data, generations, request
|
||||
)
|
||||
|
||||
response = _create_response(request.state.id, generations, model_path.name)
|
||||
|
||||
|
|
@ -461,6 +465,8 @@ async def generate_chat_completion(
|
|||
|
||||
|
||||
async def generate_tool_calls(
|
||||
prompt: str,
|
||||
embeddings: MultimodalEmbeddingWrapper,
|
||||
data: ChatCompletionRequest,
|
||||
generations: List[str],
|
||||
request: Request,
|
||||
|
|
@ -482,12 +488,10 @@ async def generate_tool_calls(
|
|||
|
||||
logger.info(f"Detected tool call in chat completion request {request.state.id}")
|
||||
|
||||
# Append the existing generation as part of the response prefix
|
||||
# Append the existing generation text if present
|
||||
precursor_text = current_generation_text or gen.get("text")
|
||||
if precursor_text:
|
||||
tool_data.response_prefix = precursor_text
|
||||
|
||||
pre_tool_prompt, embeddings = await apply_chat_template(tool_data)
|
||||
prompt = prompt + precursor_text
|
||||
|
||||
gen_request_id = _parse_gen_request_id(data.n, request.state.id, idx)
|
||||
tool_request_id = f"{gen_request_id}-tool"
|
||||
|
|
@ -496,7 +500,7 @@ async def generate_tool_calls(
|
|||
asyncio.create_task(
|
||||
model.container.generate(
|
||||
tool_request_id,
|
||||
pre_tool_prompt,
|
||||
prompt,
|
||||
tool_data,
|
||||
mm_embeddings=embeddings,
|
||||
)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue