API: Fix chat completion formatting flow
Previously, the flow for parsing chat completion messages and rendering from the prompt template was disconnected between endpoints. Now, a common function handles rendering, and each endpoint handles the result appropriately afterwards.

Signed-off-by: kingbri <bdashore3@proton.me>
parent c652a6e030
commit 902045edbb

6 changed files with 92 additions and 115 deletions
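For orientation before the diff: a rough, self-contained sketch of the flow this commit moves to, where one shared formatter renders the prompt and both the chat completion path and the tool-call re-prompt path build on it. Only the two function names mirror the diff below; everything else (the toy template, the plain message dicts) is an illustrative stand-in, not tabbyAPI code.

import asyncio
from typing import List, Optional

# Toy template renderer; the real one is a Jinja2 prompt template on the model container.
def render_template(template_vars: dict) -> str:
    lines = [f"{m['role']}: {m['content']}" for m in template_vars["messages"]]
    if template_vars.get("tool_precursor"):
        lines.append(f"assistant: {template_vars['tool_precursor']}")
    return template_vars.get("bos_token", "") + "\n".join(lines)

async def format_messages_with_template(
    messages: List[dict], existing_template_vars: Optional[dict] = None
):
    """Shared formatter: render once, return the prompt plus the vars it used."""
    template_vars = existing_template_vars or {}
    template_vars.update({"messages": messages, "bos_token": "<s>"})
    return render_template(template_vars), template_vars

async def apply_chat_template(messages: List[dict], tool_precursor: Optional[str] = None):
    """Request-level wrapper; both endpoints call this instead of duplicating logic."""
    prompt, _ = await format_messages_with_template(
        messages, {"tool_precursor": tool_precursor}
    )
    return prompt

messages = [{"role": "user", "content": "What's the weather in Paris?"}]
print(asyncio.run(apply_chat_template(messages)))                                # normal completion
print(asyncio.run(apply_chat_template(messages, tool_precursor="Checking...")))  # tool-call re-prompt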
@@ -177,11 +177,11 @@ def _create_stream_chunk(
     return chunk
 
 
-async def _append_template_metadata(data: ChatCompletionRequest):
+async def _append_template_metadata(data: ChatCompletionRequest, template_vars: dict):
     """Adding metadata is a one-time process."""
 
     template_metadata = await model.container.prompt_template.extract_metadata(
-        data.template_vars
+        template_vars
     )
 
     # Stop strings
@@ -199,7 +199,43 @@ async def _append_template_metadata(data: ChatCompletionRequest):
             data.stop.extend(template_metadata.tool_starts)
 
 
-async def format_prompt_with_template(
+async def format_messages_with_template(
+    messages: List[ChatCompletionMessage],
+    existing_template_vars: Optional[dict] = None,
+    add_bos_token: bool = True,
+    ban_eos_token: bool = False,
+):
+    """Barebones function to format chat completion messages into a prompt."""
+
+    template_vars = unwrap(existing_template_vars, {})
+    mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None
+
+    for message in messages:
+        if isinstance(message.content, list):
+            concatenated_content = ""
+            for content in message.content:
+                if content.type == "text":
+                    concatenated_content += content.text
+                elif content.type == "image_url" and mm_embeddings:
+                    await mm_embeddings.add(content.image_url.url)
+                    concatenated_content += mm_embeddings.text_alias[-1]
+
+            if message.tool_calls:
+                message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
+
+            message.content = concatenated_content
+
+    special_tokens_dict = model.container.get_special_tokens(
+        add_bos_token, ban_eos_token
+    )
+
+    template_vars.update({"messages": messages, **special_tokens_dict})
+
+    prompt = await model.container.prompt_template.render(template_vars)
+    return prompt, mm_embeddings, template_vars
+
+
+async def apply_chat_template(
     data: ChatCompletionRequest, tool_precursor: Optional[str] = None
 ):
     """
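The loop above folds OpenAI-style list content into a single string: text parts are concatenated and image parts are swapped for the alias token the embedding wrapper hands back. A minimal standalone illustration of that flattening, using plain dicts and a fake wrapper in place of tabbyAPI's ChatCompletionMessage and MultimodalEmbeddingWrapper:

import asyncio

class FakeEmbeddingWrapper:
    """Stand-in for MultimodalEmbeddingWrapper; only mints text aliases."""

    def __init__(self):
        self.text_alias = []

    async def add(self, url: str):
        # The real wrapper fetches and embeds the image; here we just record an alias.
        self.text_alias.append(f"<image_{len(self.text_alias) + 1}>")

async def flatten_content(parts: list, mm_embeddings) -> str:
    concatenated_content = ""
    for content in parts:
        if content["type"] == "text":
            concatenated_content += content["text"]
        elif content["type"] == "image_url" and mm_embeddings:
            await mm_embeddings.add(content["image_url"]["url"])
            concatenated_content += mm_embeddings.text_alias[-1]
    return concatenated_content

parts = [
    {"type": "text", "text": "Describe this image: "},
    {"type": "image_url", "image_url": {"url": "https://example.com/cat.png"}},
]
print(asyncio.run(flatten_content(parts, FakeEmbeddingWrapper())))
# -> Describe this image: <image_1>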
@@ -208,40 +244,18 @@ async def format_prompt_with_template(
     """
 
     try:
-        special_tokens_dict = model.container.get_special_tokens(
-            unwrap(data.add_bos_token, True),
-            unwrap(data.ban_eos_token, False),
-        )
-
-        # Convert list to text-based content
-        # Use the first instance of text inside the part list
-        for message in data.messages:
-            if isinstance(message.content, list):
-                message.content = next(
-                    (
-                        content.text
-                        for content in message.content
-                        if content.type == "text"
-                    ),
-                    "",
-                )
-
-            if message.tool_calls:
-                message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
-
-        # Overwrite any protected vars with their values
         data.template_vars.update(
             {
-                "messages": data.messages,
                 "add_generation_prompt": data.add_generation_prompt,
                 "tools_json": json.dumps(data.model_dump()["tools"], indent=2),
                 "functions_json": json.dumps(data.functions, indent=2),
                 "tool_precursor": tool_precursor,
-                **special_tokens_dict,
             }
         )
 
-        prompt = await model.container.prompt_template.render(data.template_vars)
+        prompt, mm_embeddings, template_vars = await format_messages_with_template(
+            data.messages, data.template_vars, data.add_bos_token, data.ban_eos_token
+        )
 
         # Append response prefix if present
         if data.response_prefix:
@@ -255,14 +269,14 @@ async def format_prompt_with_template(
 
         # Removes the starting BOS token if present
         # This is to prevent add_bos_token from adding multiple bos tokens
-        bos_token = special_tokens_dict.get("bos_token")
+        bos_token = template_vars.get("bos_token")
         if bos_token and prompt.startswith(bos_token):
             prompt = prompt.removeprefix(bos_token)
 
         # Add template metadata
-        await _append_template_metadata(data)
+        await _append_template_metadata(data, template_vars)
 
-        return prompt
+        return prompt, mm_embeddings
 
     except KeyError as exc:
         error_message = handle_request_error(
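The BOS token now comes from template_vars (populated by format_messages_with_template from the special tokens dict) instead of a locally built special_tokens_dict; the dedup logic itself is unchanged. A quick sketch of what it guards against, assuming "<s>" as the BOS token:

# If the prompt template already emits the BOS token, stripping it here keeps
# add_bos_token from producing a doubled "<s><s>" when the prompt is tokenized.
template_vars = {"bos_token": "<s>"}
prompt = "<s>user: hi\nassistant:"

bos_token = template_vars.get("bos_token")
if bos_token and prompt.startswith(bos_token):
    prompt = prompt.removeprefix(bos_token)  # Python 3.9+

print(repr(prompt))  # -> 'user: hi\nassistant:'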
@@ -302,9 +316,9 @@ async def stream_generate_chat_completion(
                     n,
                     gen_queue,
                     prompt,
-                    embeddings,
                     request.state.id,
                     abort_event,
+                    embeddings=embeddings,
                     **task_gen_params.model_dump(exclude={"prompt"}),
                 )
             )
@@ -391,8 +405,8 @@ async def generate_chat_completion(
             asyncio.create_task(
                 model.container.generate(
                     prompt,
-                    embeddings,
                     request.state.id,
+                    embeddings=embeddings,
                     **data.model_dump(exclude={"prompt"}),
                 )
             )
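Both generation call sites above switch embeddings from a positional argument to a keyword argument passed after request.state.id. A small sketch of why the keyword form is safer next to the **model_dump splat; the generate signature here is a hypothetical stand-in, not the actual model container API:

import asyncio

async def generate(prompt: str, request_id: str, *, embeddings=None, **params):
    # Hypothetical stand-in: keyword-only embeddings can't be mixed up with other
    # positional arguments, no matter how the parameter list in front of it changes.
    return f"prompt={prompt!r} request_id={request_id!r} embeddings={embeddings} params={sorted(params)}"

async def main():
    params = {"max_tokens": 32, "temperature": 0.7}  # stand-in for data.model_dump(exclude={"prompt"})
    print(await generate("Hello", "req-1", embeddings=["<image_1>"], **params))

asyncio.run(main())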
@@ -439,13 +453,11 @@ async def generate_tool_calls(
         if gen["stop_str"] in tool_data.tool_call_start:
             if "text" in gen:
                 # non streaming, all generations will have the text they generated
-                pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
+                pre_tool_prompt = await apply_chat_template(data, gen["text"])
             elif current_generations is not None:
                 # streaming, we wont have text in the generation,
                 # we'll have to use the current_generations
-                pre_tool_prompt = await format_prompt_with_template(
-                    data, current_generations
-                )
+                pre_tool_prompt = await apply_chat_template(data, current_generations)
 
             gen_tasks.append(
                 asyncio.create_task(
@@ -471,21 +483,3 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
                 tool_call["function"]["arguments"]
             )
     return [ToolCall(**tool_call) for tool_call in tool_calls]
-
-
-# TODO: Combine this with the existing preprocessor in format_prompt_with_template
-async def preprocess_vision_request(messages: List[ChatCompletionMessage]):
-    embeddings = MultimodalEmbeddingWrapper()
-    for message in messages:
-        if isinstance(message.content, list):
-            concatenated_content = ""
-            for content in message.content:
-                if content.type == "text":
-                    concatenated_content += content.text
-                elif content.type == "image_url":
-                    await embeddings.add(content.image_url.url)
-                    concatenated_content += embeddings.text_alias[-1]
-
-            message.content = concatenated_content
-
-    return messages, embeddings