Cache creation tokenizer_data in LMFE

turboderp 2024-07-08 00:51:59 +02:00
parent bb8b02a60a
commit 4d0bb1ffc3

@@ -2,9 +2,10 @@ import traceback
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
 from lmformatenforcer import JsonSchemaParser, RegexParser
-from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
+from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter, build_token_enforcer_tokenizer_data
 from loguru import logger
 from typing import List
+from functools import lru_cache
 
 
 class OutlinesTokenizerWrapper:
@@ -59,6 +60,10 @@ class ExLlamaV2Grammar:
     def __init__(self):
         self.filters = []
 
+    @lru_cache(10)
+    def _get_lmfe_tokenizer_data(self, tokenizer: ExLlamaV2Tokenizer):
+        return build_token_enforcer_tokenizer_data(tokenizer)
+
     def add_json_schema_filter(
         self,
         json_schema: dict,
@@ -82,7 +87,7 @@ class ExLlamaV2Grammar:
         # Allow JSON objects or JSON arrays at the top level
         json_prefixes = ["[", "{"]
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
+        lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, self._get_lmfe_tokenizer_data(tokenizer))
         prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
 
         # Append the filters
@@ -107,7 +112,7 @@ class ExLlamaV2Grammar:
             return
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer)
+        lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, self._get_lmfe_tokenizer_data(tokenizer))
 
         # Append the filters
         self.filters.append(lmfilter)
@@ -135,4 +140,4 @@ class ExLlamaV2Grammar:
             return
 
-        self.filters.append(ebnf_filter)
+        self.filters.append(ebnf_filter)
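
For context, a minimal sketch of the caching pattern this commit introduces, assuming an already-loaded ExLlamaV2Tokenizer. The module-level get_lmfe_tokenizer_data and make_json_filter names below are illustrative only; the commit itself attaches the cached helper to ExLlamaV2Grammar as _get_lmfe_tokenizer_data.

from functools import lru_cache

from exllamav2 import ExLlamaV2Tokenizer
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.exllamav2 import (
    ExLlamaV2TokenEnforcerFilter,
    build_token_enforcer_tokenizer_data,
)


@lru_cache(10)
def get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
    # Building the token enforcer's tokenizer data is relatively expensive;
    # lru_cache keys on the tokenizer object, so the work happens once per
    # loaded tokenizer instead of once per generation request.
    return build_token_enforcer_tokenizer_data(tokenizer)


def make_json_filter(schema: dict, tokenizer: ExLlamaV2Tokenizer):
    # Illustrative helper: the first call for a given tokenizer builds the
    # tokenizer data, later calls reuse the cached copy.
    parser = JsonSchemaParser(schema)
    return ExLlamaV2TokenEnforcerFilter(parser, get_lmfe_tokenizer_data(tokenizer))

Keying the cache on the tokenizer means repeated JSON-schema or regex filter creation against the same loaded model reuses one set of precomputed data, while the maxsize of 10 presumably bounds memory use if several tokenizers are loaded over the process lifetime.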