OAI: Tokenize chat completion messages
Since chat completion messages are structured data rather than a plain string, format the prompt via the chat template before passing it to the tokenizer. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
ed05f376d9
commit
515b3c2930
3 changed files with 26 additions and 5 deletions
|
|
@@ -566,7 +566,9 @@ class ExllamaV2Container:
|
|||
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
|
||||
)[0]
|
||||
|
||||
def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
|
||||
def get_special_tokens(
|
||||
self, add_bos_token: bool = True, ban_eos_token: bool = False
|
||||
):
|
||||
return {
|
||||
"bos_token": self.tokenizer.bos_token if add_bos_token else "",
|
||||
"eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
|
||||
|
|
|
|||
|
|
@@ -16,6 +16,7 @@ from common.concurrency import (
|
|||
from common.networking import handle_request_error, run_with_request_disconnect
|
||||
from common.templating import (
|
||||
get_all_templates,
|
||||
get_prompt_from_template,
|
||||
get_template_from_file,
|
||||
)
|
||||
from common.utils import coalesce, unwrap
|
||||
|
|
@@ -386,8 +387,26 @@ async def unload_loras():
|
|||
dependencies=[Depends(check_api_key), Depends(check_model_container)],
|
||||
)
|
||||
async def encode_tokens(data: TokenEncodeRequest):
|
||||
"""Encodes a string into tokens."""
|
||||
raw_tokens = model.container.encode_tokens(data.text, **data.get_params())
|
||||
"""Encodes a string or chat completion messages into tokens."""
|
||||
|
||||
if isinstance(data.text, str):
|
||||
text = data.text
|
||||
else:
|
||||
special_tokens_dict = model.container.get_special_tokens(
|
||||
unwrap(data.add_bos_token, True)
|
||||
)
|
||||
|
||||
template_vars = {
|
||||
"messages": data.text,
|
||||
"add_generation_prompt": False,
|
||||
**special_tokens_dict,
|
||||
}
|
||||
|
||||
text, _ = get_prompt_from_template(
|
||||
model.container.prompt_template, template_vars
|
||||
)
|
||||
|
||||
raw_tokens = model.container.encode_tokens(text, **data.get_params())
|
||||
tokens = unwrap(raw_tokens, [])
|
||||
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,7 +1,7 @@
|
|||
"""Tokenization types"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
from typing import Dict, List, Union
|
||||
|
||||
|
||||
class CommonTokenRequest(BaseModel):
|
||||
|
|
@@ -23,7 +23,7 @@ class CommonTokenRequest(BaseModel):
|
|||
class TokenEncodeRequest(CommonTokenRequest):
|
||||
"""Represents a tokenization request."""
|
||||
|
||||
text: str
|
||||
text: Union[str, List[Dict[str, str]]]
|
||||
|
||||
|
||||
class TokenEncodeResponse(BaseModel):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue