diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py
index 2cb6870..f2abf85 100644
--- a/backends/exllamav2/grammar.py
+++ b/backends/exllamav2/grammar.py
@@ -2,7 +2,10 @@ import traceback
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
 from lmformatenforcer import JsonSchemaParser, RegexParser
-from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter, build_token_enforcer_tokenizer_data
+from lmformatenforcer.integrations.exllamav2 import (
+    ExLlamaV2TokenEnforcerFilter,
+    build_token_enforcer_tokenizer_data,
+)
 from loguru import logger
 from typing import List
 from functools import lru_cache
@@ -56,6 +59,7 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
 def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
     return build_token_enforcer_tokenizer_data(tokenizer)
 
+
 def clear_grammar_func_cache():
     """Flush tokenizer_data cache to avoid holding references to tokenizers
     after unloading a model"""
@@ -93,7 +97,9 @@ class ExLlamaV2Grammar:
         # Allow JSON objects or JSON arrays at the top level
         json_prefixes = ["[", "{"]
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, _get_lmfe_tokenizer_data(tokenizer))
+        lmfilter = ExLlamaV2TokenEnforcerFilter(
+            schema_parser, _get_lmfe_tokenizer_data(tokenizer)
+        )
         prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
 
         # Append the filters
@@ -118,7 +124,9 @@ class ExLlamaV2Grammar:
 
             return
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, _get_lmfe_tokenizer_data(tokenizer))
+        lmfilter = ExLlamaV2TokenEnforcerFilter(
+            pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
+        )
 
         # Append the filters
         self.filters.append(lmfilter)
@@ -146,4 +154,4 @@ class ExLlamaV2Grammar:
 
             return
 
-        self.filters.append(ebnf_filter)
\ No newline at end of file
+        self.filters.append(ebnf_filter)
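
Note, not part of the diff: the clear_grammar_func_cache hunk only makes sense because _get_lmfe_tokenizer_data is memoized with functools.lru_cache (imported in the first hunk). A minimal sketch of that pattern, assuming the decorator placement and the cache_clear() call match the surrounding file:

    from functools import lru_cache


    @lru_cache
    def _get_lmfe_tokenizer_data(tokenizer):
        # Memoize the lmformatenforcer tokenizer data, which is expensive
        # to build, keyed on the tokenizer instance
        return build_token_enforcer_tokenizer_data(tokenizer)


    def clear_grammar_func_cache():
        # Drop the memoized entries so the cache no longer holds a reference
        # to an unloaded model's tokenizer, letting it be garbage-collected
        _get_lmfe_tokenizer_data.cache_clear()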