API: Clean up tokenizing endpoint
Split the `get_tokens` function into separate wrapper `encode_tokens` and `decode_tokens` functions for overall code cleanliness. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
bb48f77ca1
commit
284f20263f
2 changed files with 19 additions and 23 deletions
|
|
@@ -446,24 +446,23 @@ class ExllamaV2Container:
|
|||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def get_tokens(self, text: Optional[str], ids: Optional[List[int]], **kwargs):
|
||||
"""Common function for token operations"""
|
||||
if text:
|
||||
# Assume token encoding
|
||||
return self.tokenizer.encode(
|
||||
text,
|
||||
add_bos=unwrap(kwargs.get("add_bos_token"), True),
|
||||
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
|
||||
)
|
||||
if ids:
|
||||
# Assume token decoding
|
||||
ids = torch.tensor([ids])
|
||||
return self.tokenizer.decode(
|
||||
ids,
|
||||
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
|
||||
)[0]
|
||||
def encode_tokens(self, text: str, **kwargs):
|
||||
"""Wrapper to encode tokens from a text string"""
|
||||
|
||||
return None
|
||||
return self.tokenizer.encode(
|
||||
text,
|
||||
add_bos=unwrap(kwargs.get("add_bos_token"), True),
|
||||
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
|
||||
)[0].tolist()
|
||||
|
||||
def decode_tokens(self, ids: List[int], **kwargs):
|
||||
"""Wrapper to decode tokens from a list of IDs"""
|
||||
|
||||
ids = torch.tensor([ids])
|
||||
return self.tokenizer.decode(
|
||||
ids,
|
||||
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
|
||||
)[0]
|
||||
|
||||
def get_special_tokens(self, add_bos_token: bool, ban_eos_token: bool):
|
||||
return {
|
||||
|
|
|
|||
9
main.py
9
main.py
|
|
@@ -414,11 +414,8 @@ async def unload_loras():
|
|||
)
|
||||
async def encode_tokens(data: TokenEncodeRequest):
|
||||
"""Encodes a string into tokens."""
|
||||
raw_tokens = MODEL_CONTAINER.get_tokens(data.text, None, **data.get_params())
|
||||
|
||||
# Have to use this if check otherwise Torch's tensors error out
|
||||
# with a boolean issue
|
||||
tokens = raw_tokens[0].tolist() if raw_tokens is not None else []
|
||||
raw_tokens = MODEL_CONTAINER.encode_tokens(data.text, **data.get_params())
|
||||
tokens = unwrap(raw_tokens, [])
|
||||
response = TokenEncodeResponse(tokens=tokens, length=len(tokens))
|
||||
|
||||
return response
|
||||
|
|
@@ -431,7 +428,7 @@ async def encode_tokens(data: TokenEncodeRequest):
|
|||
)
|
||||
async def decode_tokens(data: TokenDecodeRequest):
|
||||
"""Decodes tokens into a string."""
|
||||
message = MODEL_CONTAINER.get_tokens(None, data.tokens, **data.get_params())
|
||||
message = MODEL_CONTAINER.decode_tokens(data.tokens, **data.get_params())
|
||||
response = TokenDecodeResponse(text=unwrap(message, ""))
|
||||
|
||||
return response
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue