From 8f209efb99c0aa06f23098e0a47ffc3216fc1a64 Mon Sep 17 00:00:00 2001
From: DocShotgun <126566557+DocShotgun@users.noreply.github.com>
Date: Sun, 24 Nov 2024 10:44:45 -0800
Subject: [PATCH] Grammar: Clean up KBNF implementation

* Also remove empty cache clear function
---
 backends/exllamav2/grammar.py | 32 ++++++++++++--------------------
 backends/exllamav2/model.py   |  3 ---
 2 files changed, 12 insertions(+), 23 deletions(-)

diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py
index 9621e3f..9a6b520 100644
--- a/backends/exllamav2/grammar.py
+++ b/backends/exllamav2/grammar.py
@@ -1,22 +1,14 @@
 import traceback
 import typing
-from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
-from exllamav2.generator.filters import ExLlamaV2Filter
-from loguru import logger
 from typing import List
 
-from formatron.formatter import FormatterBuilder
-from formatron.schemas import json_schema
-from formatron.integrations.exllamav2 import create_formatter_filter
+from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
+from exllamav2.generator.filters import ExLlamaV2Filter
 from formatron.extractor import NonterminalExtractor
-
-
-def clear_grammar_func_cache():
-    """Flush tokenizer_data cache to avoid holding references to
-    tokenizers after unloading a model"""
-
-    # TODO: Unsure if this is needed with formatron
-    pass
+from formatron.formatter import FormatterBuilder
+from formatron.integrations.exllamav2 import create_formatter_filter
+from formatron.schemas import json_schema
+from loguru import logger
 
 
 class ExLlamaV2Grammar:
@@ -102,7 +94,7 @@ class ExLlamaV2Grammar:
             f = FormatterBuilder()
             f.append_line(
                 f"{f.extractor(
-                    lambda nonterminal: CustomExtractor(nonterminal, kbnf_string)
+                    lambda nonterminal: CFGExtractor(nonterminal, kbnf_string)
                 )}"
             )
         except Exception:
@@ -119,16 +111,16 @@ class ExLlamaV2Grammar:
         self.filters.append(lmfilter)
 
 
-class CustomExtractor(NonterminalExtractor):
+class CFGExtractor(NonterminalExtractor):
+    """Extractor class for KBNF context-free grammar"""
+
     def __init__(self, nonterminal: str, kbnf_string: str):
         super().__init__(nonterminal)
         self.kbnf_string = kbnf_string
 
-    # Fails without an extract function defined
-    # No idea what it does or why it's needed, but this seems to work
-    # TODO: Figure out how to do this properly
+    # Return the entire input string as the extracted string
     def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
-        return input_str[len(input_str) :], input_str[: len(input_str)]
+        return "", input_str
 
     @property
     def kbnf_definition(self) -> str:
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 50cef42..64ed5b9 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -833,9 +833,6 @@ class ExllamaV2Container:
         # Wait for other jobs to finish
         await self.wait_for_jobs(kwargs.get("skip_wait"))
 
-        # Delete references held in the grammar module
-        clear_grammar_func_cache()
-
         # Clear the image embedding cache
         clear_image_embedding_cache()
 
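
For reference, the simplified return value in CFGExtractor.extract is exactly equivalent to the old slicing expression it replaces. The short sketch below is plain Python with no project imports; the input string is a hypothetical stand-in for generated model output, used only to illustrate the equivalence:

    # Minimal sketch: the old slices always produced ("", whole string),
    # which the cleaned-up version now states directly.
    input_str = "any generated output"  # hypothetical model output for illustration
    old_result = (input_str[len(input_str):], input_str[:len(input_str)])
    new_result = ("", input_str)
    assert old_result == new_result  # no remainder left over; the entire string is extracted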