API: Fix chat completion formatting flow

Previously, the flow for parsing chat completion messages and rendering
them through the prompt template was disconnected between endpoints. Now,
a common function performs the rendering and each endpoint handles the
result appropriately afterwards.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-11-21 17:51:14 -05:00
parent c652a6e030
commit 902045edbb
6 changed files with 92 additions and 115 deletions
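
For orientation, a condensed sketch of the flow this commit introduces, pieced together from the diffs below: the chat completion endpoint and the token-encode endpoint both funnel through one shared renderer. Function names mirror the diff, but the bodies are stand-in stubs and the request objects are simplified to plain dicts, so treat this as a sketch of the shape of the change rather than the literal implementation.

# Condensed sketch (not the literal diff): both call paths delegate to a
# single shared renderer instead of formatting messages independently.
from typing import List, Optional


async def format_messages_with_template(
    messages: List[dict],
    existing_template_vars: Optional[dict] = None,
    add_bos_token: bool = True,
    ban_eos_token: bool = False,
):
    """Flatten message content parts, collect vision embeddings when enabled,
    merge special tokens into the template vars, and render the prompt."""
    template_vars = dict(existing_template_vars or {}, messages=messages)
    prompt = "<rendered prompt>"  # stand-in for prompt_template.render(...)
    mm_embeddings = None  # stand-in for a MultimodalEmbeddingWrapper
    return prompt, mm_embeddings, template_vars


async def apply_chat_template(data: dict):
    # Chat completions: render once, then strip a duplicated BOS token and
    # reuse the returned template_vars for metadata extraction.
    prompt, mm_embeddings, template_vars = await format_messages_with_template(
        data["messages"], data.get("template_vars")
    )
    bos_token = template_vars.get("bos_token")
    if bos_token and prompt.startswith(bos_token):
        prompt = prompt.removeprefix(bos_token)
    return prompt, mm_embeddings


async def encode_tokens(data: dict):
    # Tokenization: reuse the same renderer for chat-message input instead of
    # duplicating the content flattening and vision preprocessing.
    text, mm_embeddings, _ = await format_messages_with_template(
        data["text"], {"add_generation_prompt": False}
    )
    return text, mm_embeddings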

View file

@@ -1,5 +1,4 @@
import asyncio
from common.multimodal import MultimodalEmbeddingWrapper
from fastapi import APIRouter, Depends, HTTPException, Request
from sse_starlette import EventSourceResponse
from sys import maxsize
@@ -16,9 +15,8 @@ from endpoints.OAI.types.chat_completion (
)
from endpoints.OAI.types.embedding import EmbeddingsRequest, EmbeddingsResponse
from endpoints.OAI.utils.chat_completion import (
format_prompt_with_template,
apply_chat_template,
generate_chat_completion,
preprocess_vision_request,
stream_generate_chat_completion,
)
from endpoints.OAI.utils.completion import (
@@ -125,15 +123,7 @@ async def chat_completion_request(
model_path = model.container.model_dir
embeddings = MultimodalEmbeddingWrapper()
if isinstance(data.messages, str):
prompt = data.messages
else:
if model.container.use_vision:
data.messages, embeddings = await preprocess_vision_request(data.messages)
prompt = await format_prompt_with_template(data)
prompt, embeddings = await apply_chat_template(data)
# Set an empty JSON schema if the request wants a JSON response
if data.response_format.type == "json":

View file

@@ -177,11 +177,11 @@ def _create_stream_chunk(
return chunk
async def _append_template_metadata(data: ChatCompletionRequest):
async def _append_template_metadata(data: ChatCompletionRequest, template_vars: dict):
"""Adding metadata is a one-time process."""
template_metadata = await model.container.prompt_template.extract_metadata(
data.template_vars
template_vars
)
# Stop strings
@@ -199,7 +199,43 @@ async def _append_template_metadata(data: ChatCompletionRequest):
data.stop.extend(template_metadata.tool_starts)
async def format_prompt_with_template(
async def format_messages_with_template(
messages: List[ChatCompletionMessage],
existing_template_vars: Optional[dict] = None,
add_bos_token: bool = True,
ban_eos_token: bool = False,
):
"""Barebones function to format chat completion messages into a prompt."""
template_vars = unwrap(existing_template_vars, {})
mm_embeddings = MultimodalEmbeddingWrapper() if model.container.use_vision else None
for message in messages:
if isinstance(message.content, list):
concatenated_content = ""
for content in message.content:
if content.type == "text":
concatenated_content += content.text
elif content.type == "image_url" and mm_embeddings:
await mm_embeddings.add(content.image_url.url)
concatenated_content += mm_embeddings.text_alias[-1]
if message.tool_calls:
message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
message.content = concatenated_content
special_tokens_dict = model.container.get_special_tokens(
add_bos_token, ban_eos_token
)
template_vars.update({"messages": messages, **special_tokens_dict})
prompt = await model.container.prompt_template.render(template_vars)
return prompt, mm_embeddings, template_vars
async def apply_chat_template(
data: ChatCompletionRequest, tool_precursor: Optional[str] = None
):
"""
@@ -208,40 +244,18 @@ async def format_prompt_with_template(
"""
try:
special_tokens_dict = model.container.get_special_tokens(
unwrap(data.add_bos_token, True),
unwrap(data.ban_eos_token, False),
)
# Convert list to text-based content
# Use the first instance of text inside the part list
for message in data.messages:
if isinstance(message.content, list):
message.content = next(
(
content.text
for content in message.content
if content.type == "text"
),
"",
)
if message.tool_calls:
message.tool_calls_json = json.dumps(message.tool_calls, indent=2)
# Overwrite any protected vars with their values
data.template_vars.update(
{
"messages": data.messages,
"add_generation_prompt": data.add_generation_prompt,
"tools_json": json.dumps(data.model_dump()["tools"], indent=2),
"functions_json": json.dumps(data.functions, indent=2),
"tool_precursor": tool_precursor,
**special_tokens_dict,
}
)
prompt = await model.container.prompt_template.render(data.template_vars)
prompt, mm_embeddings, template_vars = await format_messages_with_template(
data.messages, data.template_vars, data.add_bos_token, data.ban_eos_token
)
# Append response prefix if present
if data.response_prefix:
@@ -255,14 +269,14 @@ async def format_prompt_with_template(
# Removes the starting BOS token if present
# This is to prevent add_bos_token from adding multiple bos tokens
bos_token = special_tokens_dict.get("bos_token")
bos_token = template_vars.get("bos_token")
if bos_token and prompt.startswith(bos_token):
prompt = prompt.removeprefix(bos_token)
# Add template metadata
await _append_template_metadata(data)
await _append_template_metadata(data, template_vars)
return prompt
return prompt, mm_embeddings
except KeyError as exc:
error_message = handle_request_error(
@@ -302,9 +316,9 @@ async def stream_generate_chat_completion(
n,
gen_queue,
prompt,
embeddings,
request.state.id,
abort_event,
embeddings=embeddings,
**task_gen_params.model_dump(exclude={"prompt"}),
)
)
@@ -391,8 +405,8 @@ async def generate_chat_completion(
asyncio.create_task(
model.container.generate(
prompt,
embeddings,
request.state.id,
embeddings=embeddings,
**data.model_dump(exclude={"prompt"}),
)
)
@@ -439,13 +453,11 @@ async def generate_tool_calls(
if gen["stop_str"] in tool_data.tool_call_start:
if "text" in gen:
# non-streaming, all generations will have the text they generated
pre_tool_prompt = await format_prompt_with_template(data, gen["text"])
pre_tool_prompt = await apply_chat_template(data, gen["text"])
elif current_generations is not None:
# streaming, we won't have text in the generation,
# we'll have to use the current_generations
pre_tool_prompt = await format_prompt_with_template(
data, current_generations
)
pre_tool_prompt = await apply_chat_template(data, current_generations)
gen_tasks.append(
asyncio.create_task(
@@ -471,21 +483,3 @@ def postprocess_tool_call(call_str: str) -> List[ToolCall]:
tool_call["function"]["arguments"]
)
return [ToolCall(**tool_call) for tool_call in tool_calls]
# TODO: Combine this with the existing preprocessor in format_prompt_with_template
async def preprocess_vision_request(messages: List[ChatCompletionMessage]):
embeddings = MultimodalEmbeddingWrapper()
for message in messages:
if isinstance(message.content, list):
concatenated_content = ""
for content in message.content:
if content.type == "text":
concatenated_content += content.text
elif content.type == "image_url":
await embeddings.add(content.image_url.url)
concatenated_content += embeddings.text_alias[-1]
message.content = concatenated_content
return messages, embeddings

View file

@@ -7,7 +7,6 @@ Also serves as a common module for completions and chat completions.
import asyncio
import pathlib
from asyncio import CancelledError
from common.multimodal import MultimodalEmbeddingWrapper
from fastapi import HTTPException, Request
from typing import List, Union
@@ -88,7 +87,6 @@ async def _stream_collector(
task_idx: int,
gen_queue: asyncio.Queue,
prompt: str,
embeddings: MultimodalEmbeddingWrapper,
request_id: str,
abort_event: asyncio.Event,
**kwargs,
@@ -97,7 +95,7 @@ async def _stream_collector(
try:
new_generation = model.container.generate_gen(
prompt, embeddings, request_id, abort_event, **kwargs
prompt, request_id, abort_event, **kwargs
)
async for generation in new_generation:
generation["index"] = task_idx

View file

@@ -1,6 +1,7 @@
import asyncio
import pathlib
from sys import maxsize
from typing import Optional
from common.multimodal import MultimodalEmbeddingWrapper
from fastapi import APIRouter, Depends, HTTPException, Request, Response
from sse_starlette import EventSourceResponse
@@ -14,6 +15,7 @@ from common.tabby_config import config
from common.templating import PromptTemplate, get_all_templates
from common.utils import unwrap
from common.health import HealthManager
from endpoints.OAI.utils.chat_completion import format_messages_with_template
from endpoints.core.types.auth import AuthPermissionResponse
from endpoints.core.types.download import DownloadRequest, DownloadResponse
from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
@@ -359,61 +361,48 @@ async def unload_embedding_model():
)
async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
"""Encodes a string or chat completion messages into tokens."""
embeddings = MultimodalEmbeddingWrapper()
mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None
if isinstance(data.text, str):
text = data.text
elif isinstance(data.text, list) and "oai" in config.network.api_servers:
# TODO: Support additional chat completion args for encode
# i.e. add_generation_prompt, template selection, tool args, template kwargs
if model.container.prompt_template is None:
elif isinstance(data.text, list):
if "oai" not in config.network.api_servers:
error_message = handle_request_error(
"Tokenization of chat completion requests is disabled "
"because a prompt template is not set.",
"Enable the OAI server to handle chat completion messages.",
exc_info=False,
).error.message
raise HTTPException(422, error_message)
from endpoints.OAI.utils.chat_completion import preprocess_vision_request
if not model.container.prompt_template:
error_message = handle_request_error(
"Cannot tokenize chat completion message because "
+ "a prompt template is not set.",
exc_info=False,
).error.message
if model.container.use_vision:
data.text, embeddings = await preprocess_vision_request(data.text)
# Keeping behavior consistent with format_prompt_with_template
# Deal with list in messages.content
# Just replace the content list with the very first text message
for message in data.text:
if isinstance(message["content"], list):
message["content"] = next(
(
content["text"]
for content in message["content"]
if content["type"] == "text"
),
"",
)
special_tokens_dict = model.container.get_special_tokens(
unwrap(data.add_bos_token, True)
)
raise HTTPException(422, error_message)
template_vars = {
"messages": data.text,
"add_generation_prompt": False,
**special_tokens_dict,
}
text = await model.container.prompt_template.render(template_vars)
# Don't need template vars again
text, mm_embeddings, _ = await format_messages_with_template(
data.text, template_vars, data.add_bos_token
)
else:
error_message = handle_request_error(
"OAI API server must be enabled to handle chat completion message inputs.",
"Unable to tokenize the provided text. Check your formatting?",
exc_info=False,
).error.message
raise HTTPException(422, error_message)
raw_tokens = model.container.encode_tokens(text, embeddings, **data.get_params())
raw_tokens = model.container.encode_tokens(
text, embeddings=mm_embeddings, **data.get_params()
)
tokens = unwrap(raw_tokens, [])
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))

View file

@@ -1,7 +1,9 @@
"""Tokenization types"""
from pydantic import BaseModel
from typing import Dict, List, Union
from typing import List, Union
from endpoints.OAI.types.chat_completion import ChatCompletionMessage
class CommonTokenRequest(BaseModel):
@@ -23,7 +25,7 @@ class CommonTokenRequest(BaseModel):
class TokenEncodeRequest(CommonTokenRequest):
"""Represents a tokenization request."""
text: Union[str, List[Dict]]
text: Union[str, List[ChatCompletionMessage]]
class TokenEncodeResponse(BaseModel):
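
With the text field widened from List[Dict] to List[ChatCompletionMessage], a tokenization request can carry chat messages directly and be validated the same way the OAI endpoints validate them. A hypothetical usage sketch follows; the module path for TokenEncodeRequest is assumed from context rather than shown in this diff.

# Hypothetical usage sketch: build a token-encode request from chat messages.
# The import path for TokenEncodeRequest is an assumption, not from the diff.
from endpoints.core.types.token import TokenEncodeRequest
from endpoints.OAI.types.chat_completion import ChatCompletionMessage

request = TokenEncodeRequest(
    text=[
        ChatCompletionMessage(role="user", content="Hello!"),
        ChatCompletionMessage(role="assistant", content="Hi! How can I help?"),
    ],
    add_bos_token=True,
)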