This commit is contained in:
turboderp 2024-07-08 03:49:26 +02:00
parent 4cf79c5ae1
commit 8bbce3455c

View file

@ -2,7 +2,10 @@ import traceback
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
from lmformatenforcer import JsonSchemaParser, RegexParser
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter, build_token_enforcer_tokenizer_data
from lmformatenforcer.integrations.exllamav2 import (
ExLlamaV2TokenEnforcerFilter,
build_token_enforcer_tokenizer_data,
)
from loguru import logger
from typing import List
from functools import lru_cache
@ -56,6 +59,7 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
    """Build the lm-format-enforcer tokenizer data for ``tokenizer``.

    Thin wrapper that forwards ``tokenizer`` to
    ``build_token_enforcer_tokenizer_data`` and returns its result unchanged.
    The result is what the grammar code passes as the second argument to
    ``ExLlamaV2TokenEnforcerFilter``.

    NOTE(review): the neighbouring ``clear_grammar_func_cache`` docstring
    refers to a "tokenizer_data cache", so this function is presumably
    memoized by a decorator that is not visible in this excerpt -- confirm
    against the full file.
    """
    return build_token_enforcer_tokenizer_data(tokenizer)
def clear_grammar_func_cache():
"""Flush tokenizer_data cache to avoid holding references to tokenizers after unloading a model"""
@ -93,7 +97,9 @@ class ExLlamaV2Grammar:
# Allow JSON objects or JSON arrays at the top level
json_prefixes = ["[", "{"]
lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, _get_lmfe_tokenizer_data(tokenizer))
lmfilter = ExLlamaV2TokenEnforcerFilter(
schema_parser, _get_lmfe_tokenizer_data(tokenizer)
)
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
# Append the filters
@ -118,7 +124,9 @@ class ExLlamaV2Grammar:
return
lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, _get_lmfe_tokenizer_data(tokenizer))
lmfilter = ExLlamaV2TokenEnforcerFilter(
pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
)
# Append the filters
self.filters.append(lmfilter)
@ -146,4 +154,4 @@ class ExLlamaV2Grammar:
return
self.filters.append(ebnf_filter)
self.filters.append(ebnf_filter)