API: Clean up tokenizing endpoint

Split the get_tokens function into separate encode_tokens and decode_tokens
wrapper functions for overall code cleanliness.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-02-05 00:20:10 -05:00 committed by Brian Dashore
parent bb48f77ca1
commit 284f20263f
2 changed files with 19 additions and 23 deletions
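
For reference, the call-site effect of the split, as a minimal sketch (MODEL_CONTAINER and data are the names used in the endpoint code below):

    # Before: one shared function, with a None placeholder for the unused argument
    tokens = MODEL_CONTAINER.get_tokens(data.text, None, **data.get_params())
    text = MODEL_CONTAINER.get_tokens(None, data.tokens, **data.get_params())

    # After: one dedicated wrapper per direction, no placeholder arguments
    tokens = MODEL_CONTAINER.encode_tokens(data.text, **data.get_params())
    text = MODEL_CONTAINER.decode_tokens(data.tokens, **data.get_params())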


@@ -446,24 +446,23 @@ class ExllamaV2Container:
         gc.collect()
         torch.cuda.empty_cache()
 
-    def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
-        """Common function for token operations"""
-        if text:
-            # Assume token encoding
-            return self.tokenizer.encode(
-                text,
-                add_bos=unwrap(kwargs.get("add_bos_token"), True),
-                encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
-            )
-        if ids:
-            # Assume token decoding
-            ids = torch.tensor([ids])
-            return self.tokenizer.decode(
-                ids,
-                decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
-            )[0]
-        return None
+    def encode_tokens(self, text: str, **kwargs):
+        """Wrapper to encode tokens from a text string"""
+        return self.tokenizer.encode(
+            text,
+            add_bos=unwrap(kwargs.get("add_bos_token"), True),
+            encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
+        )[0].tolist()
+
+    def decode_tokens(self, ids: List[int], **kwargs):
+        """Wrapper to decode tokens from a list of IDs"""
+        ids = torch.tensor([ids])
+        return self.tokenizer.decode(
+            ids,
+            decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
+        )[0]
 
     def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
         return {
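
A minimal round-trip sketch of the new wrappers (container here is a hypothetical ExllamaV2Container with a loaded model; the kwargs mirror the unwrap defaults above):

    ids = container.encode_tokens("Hello, world!", add_bos_token=True)
    # encode_tokens applies [0].tolist() internally, so ids is a plain List[int]
    text = container.decode_tokens(ids, decode_special_tokens=True)

Note the behavioral detail: the old get_tokens returned a raw Torch tensor on the encode path and left the .tolist() conversion to the caller; encode_tokens does that conversion itself, which is what lets the endpoint below drop its tensor handling.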


@@ -414,11 +414,8 @@ async def unload_loras():
 )
 async def encode_tokens(data: TokenEncodeRequest):
     """Encodes a string into tokens."""
-    raw_tokens = MODEL_CONTAINER.get_tokens(data.text, None, **data.get_params())
-
-    # Have to use this if check otherwise Torch's tensors error out
-    # with a boolean issue
-    tokens = raw_tokens[0].tolist() if raw_tokens is not None else []
+    raw_tokens = MODEL_CONTAINER.encode_tokens(data.text, **data.get_params())
+    tokens = unwrap(raw_tokens, [])
 
     response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
     return response
@@ -431,7 +428,7 @@ async def encode_tokens(data: TokenEncodeRequest):
 )
 async def decode_tokens(data: TokenDecodeRequest):
     """Decodes tokens into a string."""
-    message = MODEL_CONTAINER.get_tokens(None, data.tokens, **data.get_params())
+    message = MODEL_CONTAINER.decode_tokens(data.tokens, **data.get_params())
 
     response = TokenDecodeResponse(text=unwrap(message, ""))
     return response
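
A quick smoke test of the two endpoints, assuming they are mounted at /v1/token/encode and /v1/token/decode (the routes and any auth headers are not visible in these hunks, so adjust to the deployment; the "text" and "tokens" request fields follow data.text and data.tokens in the diff):

    import requests

    base = "http://localhost:5000/v1/token"  # hypothetical host and port
    enc = requests.post(f"{base}/encode", json={"text": "Hello"}).json()
    print(enc)  # per TokenEncodeResponse: {"tokens": [...], "length": ...}
    dec = requests.post(f"{base}/decode", json={"tokens": enc["tokens"]}).json()
    print(dec)  # per TokenDecodeResponse: {"text": "..."}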