Exl3: Add token encode, decode, and special token fetch

Implements the base class methods on ExllamaV3Container that were previously `pass` stubs.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-05-02 21:32:53 -04:00
parent 0c1d794390
commit b4ff2f23cf

View file

@ -18,6 +18,7 @@ from common.multimodal import MultimodalEmbeddingWrapper
from common.sampling import BaseSamplerRequest
from common.templating import PromptTemplate
from common.transformers_utils import GenerationConfig
from common.utils import unwrap
from endpoints.core.types.model import ModelCard
from exllamav3 import Config, Model, Cache, Tokenizer
@ -175,7 +176,11 @@ class ExllamaV3Container(BaseModelContainer):
A list of integer token IDs.
"""
pass
return self.tokenizer.encode(
text,
add_bos=unwrap(kwargs.get("add_bos_token"), True),
encode_special_tokens=unwrap(kwargs.get("encode_special_tokens"), True),
).flatten().tolist()
def decode_tokens(self, ids: List[int], **kwargs) -> str:
"""
@ -189,9 +194,15 @@ class ExllamaV3Container(BaseModelContainer):
The decoded text string.
"""
pass
ids = torch.tensor([ids])
return self.tokenizer.decode(
ids,
decode_special_tokens=unwrap(kwargs.get("decode_special_tokens"), True),
)[0]
def get_special_tokens(self, **kwargs) -> Dict[str, Any]:
def get_special_tokens(
self, add_bos_token: bool = True, ban_eos_token: bool = False
):
"""
Gets special tokens used by the model/tokenizer.
@ -203,7 +214,12 @@ class ExllamaV3Container(BaseModelContainer):
to their string or ID representation.
"""
pass
return {
"bos_token": self.tokenizer.bos_token if add_bos_token else "",
"eos_token": self.tokenizer.eos_token if not ban_eos_token else "",
"pad_token": self.tokenizer.pad_token,
"unk_token": self.tokenizer.unk_token,
}
def model_info(self) -> ModelCard:
"""