Grammar: Clean up KBNF implementation

* Also remove empty cache clear function
This commit is contained in:
DocShotgun 2024-11-24 10:44:45 -08:00
parent a9f39bcff3
commit 8f209efb99
2 changed files with 12 additions and 23 deletions

View file

@ -1,22 +1,14 @@
import traceback
import typing
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter
from loguru import logger
from typing import List
from formatron.formatter import FormatterBuilder
from formatron.schemas import json_schema
from formatron.integrations.exllamav2 import create_formatter_filter
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter
from formatron.extractor import NonterminalExtractor
def clear_grammar_func_cache():
"""Flush tokenizer_data cache to avoid holding references to
tokenizers after unloading a model"""
# TODO: Unsure if this is needed with formatron
pass
from formatron.formatter import FormatterBuilder
from formatron.integrations.exllamav2 import create_formatter_filter
from formatron.schemas import json_schema
from loguru import logger
class ExLlamaV2Grammar:
@ -102,7 +94,7 @@ class ExLlamaV2Grammar:
f = FormatterBuilder()
f.append_line(
f"{f.extractor(
lambda nonterminal: CustomExtractor(nonterminal, kbnf_string)
lambda nonterminal: CFGExtractor(nonterminal, kbnf_string)
)}"
)
except Exception:
@ -119,16 +111,16 @@ class ExLlamaV2Grammar:
self.filters.append(lmfilter)
class CustomExtractor(NonterminalExtractor):
class CFGExtractor(NonterminalExtractor):
"""Extractor class for KBNF context-free grammar"""
def __init__(self, nonterminal: str, kbnf_string: str):
super().__init__(nonterminal)
self.kbnf_string = kbnf_string
# Fails without an extract function defined
# No idea what it does or why it's needed, but this seems to work
# TODO: Figure out how to do this properly
# Return the entire input string as the extracted string
def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
return input_str[len(input_str) :], input_str[: len(input_str)]
return "", input_str
@property
def kbnf_definition(self) -> str:

View file

@ -833,9 +833,6 @@ class ExllamaV2Container:
# Wait for other jobs to finish
await self.wait_for_jobs(kwargs.get("skip_wait"))
# Delete references held in the grammar module
clear_grammar_func_cache()
# Clear the image embedding cache
clear_image_embedding_cache()