API: Fix chat completion formatting flow
Previously, the flow for parsing chat completion messages and rendering them with the prompt template was disconnected between endpoints. Now, a common function performs the rendering, and each endpoint handles the result appropriately afterwards.

Signed-off-by: kingbri <bdashore3@proton.me>
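The shared helper referenced above is `format_messages_with_template`, imported in the router diff below. Its implementation is not part of this commit view, so the following is only a rough sketch of the shape implied by its call site (`data.text, template_vars, data.add_bos_token` in; `text, mm_embeddings, _` out); the internals and import paths here are assumptions, not the actual code:

```python
# Rough sketch only: the real format_messages_with_template lives in
# endpoints/OAI/utils/chat_completion and also covers vision preprocessing,
# tool arguments, and template kwargs. Everything below is inferred from the
# call site in this diff, not copied from the actual helper.
from typing import List, Optional, Tuple

from common import model  # assumed import path for the module-level model manager
from common.multimodal import MultimodalEmbeddingWrapper


async def format_messages_with_template(
    messages: List[dict],
    existing_template_vars: Optional[dict] = None,
    add_bos_token: Optional[bool] = True,
) -> Tuple[str, Optional[MultimodalEmbeddingWrapper], dict]:
    """Render chat completion messages with the loaded prompt template."""

    template_vars = dict(existing_template_vars or {})
    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None

    # Vision requests would be preprocessed here to fill mm_embeddings (omitted)

    # Same special-token handling the encode endpoint used to do inline
    special_tokens_dict = model.container.get_special_tokens(
        True if add_bos_token is None else add_bos_token
    )
    template_vars.update({"messages": messages, **special_tokens_dict})

    rendered_text = await model.container.prompt_template.render(template_vars)
    return rendered_text, mm_embeddings, template_vars
```

With something like this in place, both the chat completion endpoint and the token encode endpoint call one renderer instead of duplicating the message-flattening and template logic.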
parent c652a6e030
commit 902045edbb
6 changed files with 92 additions and 115 deletions
@@ -1,6 +1,7 @@
 import asyncio
 import pathlib
 from sys import maxsize
+from typing import Optional
 from common.multimodal import MultimodalEmbeddingWrapper
 from fastapi import APIRouter, Depends, HTTPException, Request, Response
 from sse_starlette import EventSourceResponse
@@ -14,6 +15,7 @@ from common.tabby_config import config
 from common.templating import PromptTemplate, get_all_templates
 from common.utils import unwrap
 from common.health import HealthManager
+from endpoints.OAI.utils.chat_completion import format_messages_with_template
 from endpoints.core.types.auth import AuthPermissionResponse
 from endpoints.core.types.download import DownloadRequest, DownloadResponse
 from endpoints.core.types.lora import LoraList, LoraLoadRequest, LoraLoadResponse
@@ -359,61 +361,48 @@ async def unload_embedding_model():
 )
 async def encode_tokens(data: TokenEncodeRequest) -> TokenEncodeResponse:
     """Encodes a string or chat completion messages into tokens."""
-    embeddings = MultimodalEmbeddingWrapper()
+
+    mm_embeddings: Optional[MultimodalEmbeddingWrapper] = None
 
     if isinstance(data.text, str):
         text = data.text
-    elif isinstance(data.text, list) and "oai" in config.network.api_servers:
-        if model.container.prompt_template is None:
+    elif isinstance(data.text, list):
+        # TODO: Support additional chat completion args for encode
+        # i.e. add_generation_prompt, template selection, tool args, template kwargs
+        if "oai" not in config.network.api_servers:
             error_message = handle_request_error(
-                "Tokenization of chat completion requests is disabled "
-                "because a prompt template is not set.",
+                "Enable the OAI server to handle chat completion messages.",
                 exc_info=False,
             ).error.message
 
             raise HTTPException(422, error_message)
 
-        from endpoints.OAI.utils.chat_completion import preprocess_vision_request
-
-        if model.container.use_vision:
-            data.text, embeddings = await preprocess_vision_request(data.text)
-
-        # Keeping behavior consistent with format_prompt_with_template
-        # Deal with list in messages.content
-        # Just replace the content list with the very first text message
-        for message in data.text:
-            if isinstance(message["content"], list):
-                message["content"] = next(
-                    (
-                        content["text"]
-                        for content in message["content"]
-                        if content["type"] == "text"
-                    ),
-                    "",
-                )
-
-        special_tokens_dict = model.container.get_special_tokens(
-            unwrap(data.add_bos_token, True)
-        )
+        if not model.container.prompt_template:
+            error_message = handle_request_error(
+                "Cannot tokenize chat completion message because "
+                + "a prompt template is not set.",
+                exc_info=False,
+            ).error.message
+
+            raise HTTPException(422, error_message)
 
         template_vars = {
-            "messages": data.text,
             "add_generation_prompt": False,
-            **special_tokens_dict,
         }
 
-        text = await model.container.prompt_template.render(template_vars)
+        # Don't need template vars again
+        text, mm_embeddings, _ = await format_messages_with_template(
+            data.text, template_vars, data.add_bos_token
+        )
     else:
         error_message = handle_request_error(
-            "OAI API server must be enabled to handle chat completion message inputs.",
+            "Unable to tokenize the provided text. Check your formatting?",
             exc_info=False,
         ).error.message
 
         raise HTTPException(422, error_message)
 
-    raw_tokens = model.container.encode_tokens(text, embeddings, **data.get_params())
+    raw_tokens = model.container.encode_tokens(
+        text, embeddings=mm_embeddings, **data.get_params()
+    )
     tokens = unwrap(raw_tokens, [])
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
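For API consumers, the practical effect is that the token encode endpoint accepts either a plain string or a list of chat completion messages, and the latter now goes through the same template rendering as chat completions. A usage sketch follows; the base URL, port, endpoint path, and auth header are assumptions based on typical tabbyAPI defaults rather than anything shown in this diff:

```python
# Usage sketch. Base URL, port, endpoint path, and auth header are assumptions
# based on typical tabbyAPI defaults; they are not shown in this diff.
import requests

BASE = "http://127.0.0.1:5000"
HEADERS = {"x-api-key": "<your-api-key>"}

# Plain string input tokenizes exactly as before
resp = requests.post(
    f"{BASE}/v1/token/encode",
    headers=HEADERS,
    json={"text": "Hello there!", "add_bos_token": True},
)
print(resp.json())  # -> {"tokens": [...], "length": N}

# Chat completion messages are now rendered through the prompt template
# by the shared helper before being encoded
resp = requests.post(
    f"{BASE}/v1/token/encode",
    headers=HEADERS,
    json={
        "text": [
            {"role": "user", "content": "Hello there!"},
            {"role": "assistant", "content": "Hi! How can I help?"},
        ],
        "add_bos_token": True,
    },
)
print(resp.json())
```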
@@ -1,7 +1,9 @@
 """Tokenization types"""
 
 from pydantic import BaseModel
-from typing import Dict, List, Union
+from typing import List, Union
+
+from endpoints.OAI.types.chat_completion import ChatCompletionMessage
 
 
 class CommonTokenRequest(BaseModel):
@@ -23,7 +25,7 @@ class CommonTokenRequest(BaseModel):
 class TokenEncodeRequest(CommonTokenRequest):
     """Represents a tokenization request."""
 
-    text: Union[str, List[Dict]]
+    text: Union[str, List[ChatCompletionMessage]]
 
 
 class TokenEncodeResponse(BaseModel):
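As a quick illustration of the typing change, here is a minimal sketch of how the union now validates. It is simplified (the real TokenEncodeRequest extends CommonTokenRequest), and the ChatCompletionMessage fields below are assumed for demonstration; the real model is defined in endpoints/OAI/types/chat_completion:

```python
# Minimal sketch of the new request typing. The ChatCompletionMessage fields
# here are assumed for demonstration; the real model is defined in
# endpoints/OAI/types/chat_completion, and the real TokenEncodeRequest
# extends CommonTokenRequest rather than BaseModel.
from typing import List, Optional, Union

from pydantic import BaseModel


class ChatCompletionMessage(BaseModel):
    role: str
    content: Optional[Union[str, List[dict]]] = None


class TokenEncodeRequest(BaseModel):
    text: Union[str, List[ChatCompletionMessage]]


# Both shapes validate against the same field:
TokenEncodeRequest(text="Hello there!")
TokenEncodeRequest(text=[{"role": "user", "content": "Hello there!"}])
```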