OAI: Tokenize chat completion messages

Since chat completion messages are structured data, format the prompt
before passing it to the tokenizer.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-04-15 14:17:16 -04:00
parent ed05f376d9
commit 515b3c2930
3 changed files with 26 additions and 5 deletions

View file

@@ -566,7 +566,9 @@ class ExllamaV2Container:
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
)[0]
def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
def get_special_tokens(
self, add_bos_token: bool = True, ban_eos_token: bool = False
):
return {
"bos_token": self.tokenizer.bos_token if add_bos_token else "",
"eos_token": self.tokenizer.eos_token if not ban_eos_token else "",

View file

@@ -16,6 +16,7 @@ from common.concurrency (
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import (
get_all_templates,
get_prompt_from_template,
get_template_from_file,
)
from common.utils import coalesce, unwrap
@@ -386,8 +387,26 @@ async def unload_loras():
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
async def encode_tokens(data: TokenEncodeRequest):
"""Encodes a string into tokens."""
raw_tokens = model.container.encode_tokens(data.text, **data.get_params())
"""Encodes a string or chat completion messages into tokens."""
if isinstance(data.text, str):
text = data.text
else:
special_tokens_dict = model.container.get_special_tokens(
unwrap(data.add_bos_token, True)
)
template_vars = {
"messages": data.text,
"add_generation_prompt": False,
**special_tokens_dict,
}
text, _ = get_prompt_from_template(
model.container.prompt_template, template_vars
)
raw_tokens = model.container.encode_tokens(text, **data.get_params())
tokens = unwrap(raw_tokens, [])
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))

View file

@@ -1,7 +1,7 @@
"""Tokenization types"""
from pydantic import BaseModel
from typing import List
from typing import Dict, List, Union
class CommonTokenRequest(BaseModel):
@@ -23,7 +23,7 @@ class CommonTokenRequest(BaseModel):
class TokenEncodeRequest(CommonTokenRequest):
"""Represents a tokenization request."""
text: str
text: Union[str, List[Dict[str, str]]]
class TokenEncodeResponse(BaseModel):