Grammar: Preliminary Formatron KBNF support

This commit is contained in:
DocShotgun 2024-11-23 12:05:41 -08:00
parent 0836a9317f
commit a9f39bcff3

View file

@ -1,4 +1,5 @@
import traceback
import typing
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter
from loguru import logger
@ -7,6 +8,7 @@ from typing import List
from formatron.formatter import FormatterBuilder
from formatron.schemas import json_schema
from formatron.integrations.exllamav2 import create_formatter_filter
from formatron.extractor import NonterminalExtractor
def clear_grammar_func_cache():
@ -98,7 +100,11 @@ class ExLlamaV2Grammar:
try:
# Validate KBNF and create formatter
f = FormatterBuilder()
# TODO: Implement this
f.append_line(
f"{f.extractor(
lambda nonterminal: CustomExtractor(nonterminal, kbnf_string)
)}"
)
except Exception:
logger.error(
"Skipping because the KBNF string couldn't be parsed. "
@ -111,3 +117,19 @@ class ExLlamaV2Grammar:
# Append the filters
self.filters.append(lmfilter)
class CustomExtractor(NonterminalExtractor):
    """Formatron extractor backed by a user-supplied KBNF grammar string.

    The grammar text is expected to name its entry rule ``start``; that rule
    is renamed to this extractor's nonterminal when the definition is read.
    """

    def __init__(self, nonterminal: str, kbnf_string: str):
        """Keep the raw grammar text for later rewriting.

        Args:
            nonterminal: Nonterminal name forwarded to ``NonterminalExtractor``.
            kbnf_string: KBNF grammar text whose entry rule is ``start``.
        """
        super().__init__(nonterminal)
        self.kbnf_string = kbnf_string

    # The extractor fails without an extract function defined.
    # TODO: Figure out how to do this properly
    def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
        """Return an empty string paired with the whole of ``input_str``.

        NOTE(review): presumably the tuple is (remaining input, extracted
        value) as formatron extractors expect, i.e. everything is consumed
        and captured verbatim -- confirm against the formatron docs.
        """
        return "", input_str

    @property
    def kbnf_definition(self) -> str:
        """Return the grammar with its ``start`` rule renamed to the nonterminal.

        Only standalone ``start`` identifiers are renamed. A plain substring
        replace would also rewrite the text of quoted literals (for example
        the terminal ``"start"`` or ``"restart"``), silently changing what the
        grammar generates.

        NOTE(review): assumes KBNF literals are delimited by ``'`` or ``"``
        with backslash escapes -- confirm against the KBNF syntax reference.
        """
        # Local import keeps this block self-contained within the file.
        import re

        # Alternative 1 captures a quoted literal so it can be copied through
        # untouched; alternative 2 is a whole-word ``start`` identifier.
        token = re.compile(
            r"""("(?:\\.|[^"\\])*"|'(?:\\.|[^'\\])*')|\bstart\b""",
            re.DOTALL,
        )
        nonterminal = self.nonterminal

        def _rename(match):
            literal = match.group(1)
            # A callable replacement avoids re.sub interpreting backslashes
            # or group references inside the nonterminal name.
            return nonterminal if literal is None else literal

        return token.sub(_rename, self.kbnf_string)