Move tokenizer_data cache to global scope

This commit is contained in:
turboderp 2024-07-08 02:54:49 +02:00
parent 4d0bb1ffc3
commit b7e7df1220

View file

@@ -52,6 +52,11 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
return self.fsm.allowed_token_ids(self.state), set()
@lru_cache(10)
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
return build_token_enforcer_tokenizer_data(tokenizer)
class ExLlamaV2Grammar:
"""ExLlamaV2 class for various grammar filters/parsers."""
@@ -60,10 +65,6 @@ class ExLlamaV2Grammar:
def __init__(self):
self.filters = []
@lru_cache(10)
def _get_lmfe_tokenizer_data(self, tokenizer: ExLlamaV2Tokenizer):
return build_token_enforcer_tokenizer_data(tokenizer)
def add_json_schema_filter(
self,
json_schema: dict,
@@ -87,7 +88,7 @@ class ExLlamaV2Grammar:
# Allow JSON objects or JSON arrays at the top level
json_prefixes = ["[", "{"]
lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, self._get_lmfe_tokenizer_data(tokenizer))
lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, _get_lmfe_tokenizer_data(tokenizer))
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
# Append the filters
@@ -112,7 +113,7 @@ class ExLlamaV2Grammar:
return
lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, self._get_lmfe_tokenizer_data(tokenizer))
lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, _get_lmfe_tokenizer_data(tokenizer))
# Append the filters
self.filters.append(lmfilter)