diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py
index 68c9e8e..2cb6870 100644
--- a/backends/exllamav2/grammar.py
+++ b/backends/exllamav2/grammar.py
@@ -56,6 +56,11 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
 def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
     return build_token_enforcer_tokenizer_data(tokenizer)
 
+def clear_grammar_func_cache():
+    """Flush tokenizer_data cache to avoid holding references to tokenizers after unloading a model"""
+
+    _get_lmfe_tokenizer_data.clear_cache()
+
 
 class ExLlamaV2Grammar:
     """ExLlamaV2 class for various grammar filters/parsers."""
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 5a425c7..5f4e86b 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -26,7 +26,10 @@ from itertools import zip_longest
 from loguru import logger
 from typing import List, Optional, Union
 
-from backends.exllamav2.grammar import ExLlamaV2Grammar
+from backends.exllamav2.grammar import (
+    ExLlamaV2Grammar,
+    clear_grammar_func_cache,
+)
 from backends.exllamav2.utils import (
     exllama_disabled_flash_attn,
     hardware_supports_flash_attn,
@@ -704,6 +707,10 @@ class ExllamaV2Container:
         # Wait for other jobs to finish
         await self.wait_for_jobs(kwargs.get("skip_wait"))
 
+        # Delete references held in the grammar module
+        clear_grammar_func_cache()
+
+        # Unload LoRAs
        if self.generator and self.generator.generator.current_loras:
             for lora in self.generator.generator.current_loras:
                 lora.unload()
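
Note (not part of the patch): the sketch below illustrates the failure mode this change addresses, assuming a functools.lru_cache-style decorator on the module-level helper; the actual decorator used in grammar.py is outside the hunks shown, and the patch calls its equivalent clear_cache() rather than functools' cache_clear(). FakeTokenizer and all function names here are illustrative, not the tabbyAPI or lm-format-enforcer API. The point is that a module-level cache keeps the tokenizer reachable even after the model object is dropped, so the cache must be flushed on unload before the memory can be reclaimed.

```python
import weakref
from functools import lru_cache


class FakeTokenizer:
    """Stand-in for the real tokenizer object (e.g. ExLlamaV2Tokenizer)."""


@lru_cache
def get_tokenizer_data(tokenizer):
    # The cache entry stores `tokenizer` as its key (and here also in the
    # value), so the object stays reachable for as long as the entry exists.
    return {"tokenizer": tokenizer}


def clear_tokenizer_data_cache():
    # functools.lru_cache exposes cache_clear(); the patch calls the
    # analogous clear_cache() on whatever decorator grammar.py uses.
    get_tokenizer_data.cache_clear()


if __name__ == "__main__":
    tok = FakeTokenizer()
    ref = weakref.ref(tok)

    get_tokenizer_data(tok)
    del tok
    print(ref() is not None)  # True: the cache still pins the tokenizer

    clear_tokenizer_data_cache()
    print(ref() is None)  # True: the tokenizer can now be collected
```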