Grammar: Initial Formatron regex and JSON schema implementation

* Replace LMFE's regex and JSON schema filters with Formatron's
* Remove Outlines EBNF filter in preparation for Formatron KBNF filter
* TODO: Implement Formatron KBNF filter
DocShotgun 2024-11-23 10:27:37 -08:00
parent aa4ccd03d4
commit 0836a9317f
3 changed files with 39 additions and 122 deletions
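For context, the new filters all follow the same Formatron flow: build a FormatterBuilder, attach a JSON-schema or regex extractor, then wrap it in an ExLlamaV2 filter. A minimal sketch of that flow, using only the Formatron calls that appear in the diff below (the example schema and the `model`/`tokenizer` objects are assumptions, not part of this commit):

    from formatron.formatter import FormatterBuilder
    from formatron.schemas import json_schema
    from formatron.integrations.exllamav2 import create_formatter_filter

    # Example schema (made up); Formatron expects "$id" and "$schema" fields
    schema_dict = {
        "$id": "https://example.com/example.json",
        "$schema": "http://json-schema.org/draft-07/schema#",
        "type": "object",
        "properties": {"answer": {"type": "string"}},
    }

    schema = json_schema.create_schema(schema_dict)  # validate the raw dict
    f = FormatterBuilder()
    f.append_line(f"{f.json(schema)}")  # constrain generation to the schema
    # f.append_line(f"{f.regex(pattern)}") would do the same for a regex pattern

    # model/tokenizer are assumed to be loaded ExLlamaV2 / ExLlamaV2Tokenizer objects
    lmfilter = create_formatter_filter(model, tokenizer, f)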


@@ -1,110 +1,20 @@
import traceback
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
from lmformatenforcer import (
    JsonSchemaParser,
    RegexParser,
    TokenEnforcer,
    CharacterLevelParser,
)
from lmformatenforcer.integrations.exllamav2 import (
    build_token_enforcer_tokenizer_data,
)
from exllamav2.generator.filters import ExLlamaV2Filter
from loguru import logger
from typing import List
from functools import lru_cache
class OutlinesTokenizerWrapper:
    """Wrapper for Outlines tokenizer"""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer
        id_to_piece = self.tokenizer.get_id_to_piece_list()
        self.vocabulary = {piece: idx for idx, piece in enumerate(id_to_piece)}
        self.eos_token_id = self.tokenizer.eos_token_id
        self.eos_token = id_to_piece[self.tokenizer.eos_token_id]
        self.special_tokens = list(self.tokenizer.extended_id_to_piece.keys())

    def convert_token_to_string(self, token):
        return token

    def decode(self, tokens):
        s = ""
        id_to_piece = self.tokenizer.get_id_to_piece_list()
        for t in tokens:
            s += id_to_piece[t]
        return s
class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
    """Filter class for context-free grammar via outlines"""

    def __init__(self, model, tokenizer, grammar):
        from outlines.fsm.fsm import CFGFSM

        super().__init__(model, tokenizer)

        self.wrapped_tokenizer = OutlinesTokenizerWrapper(tokenizer)
        self.fsm = CFGFSM(grammar, self.wrapped_tokenizer)
        self.state = self.fsm.first_state

    def begin(self, prefix_str=""):
        self.state = self.fsm.first_state

    def feed(self, token):
        self.state = self.fsm.next_state(self.state, token.item())

    def next(self):
        return self.fsm.allowed_token_ids(self.state), set()

    def use_background_worker(self):
        return True
@lru_cache(10)
def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
    return build_token_enforcer_tokenizer_data(tokenizer)
class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
    """Filter class for LMFE"""

    token_sequence: List[int]

    def __init__(
        self,
        model: ExLlamaV2,
        tokenizer: ExLlamaV2Tokenizer,
        character_level_parser: CharacterLevelParser,
    ):
        super().__init__(model, tokenizer)
        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
        self.token_sequence = []

    def begin(self, prefix_str: str):
        self.token_sequence = []

    def feed(self, token):
        self.token_sequence.append(int(token[0][0]))

    def next(self):
        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
        if not hasattr(self, "allow_return_type_list"):
            return set(allowed_tokens), set()
        else:
            return sorted(allowed_tokens), []

    def use_background_worker(self):
        return True
from formatron.formatter import FormatterBuilder
from formatron.schemas import json_schema
from formatron.integrations.exllamav2 import create_formatter_filter
def clear_grammar_func_cache():
    """Flush tokenizer_data cache to avoid holding references to
    tokenizers after unloading a model"""

    _get_lmfe_tokenizer_data.cache_clear()
    # TODO: Unsure if this is needed with formatron
    pass
class ExLlamaV2Grammar:
@@ -117,7 +27,7 @@ class ExLlamaV2Grammar:
    def add_json_schema_filter(
        self,
        json_schema: dict,
        schema: dict,
        model: ExLlamaV2,
        tokenizer: ExLlamaV2Tokenizer,
    ):
@@ -125,7 +35,16 @@
        # Create the parser
        try:
            schema_parser = JsonSchemaParser(json_schema)
            # Add fields required by formatron if not present
            if "$id" not in schema:
                schema["$id"] = "https://example.com/example.json"
            if "$schema" not in schema:
                schema["$schema"] = "http://json-schema.org/draft-07/schema#"

            # Validate schema and create formatter
            schema = json_schema.create_schema(schema)
            f = FormatterBuilder()
            f.append_line(f"{f.json(schema)}")
        except Exception:
            traceback.print_exc()
            logger.error(
@@ -135,14 +54,10 @@ class ExLlamaV2Grammar:
            return

        # Allow JSON objects or JSON arrays at the top level
        json_prefixes = ["[", "{"]

        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
        prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
        lmfilter = create_formatter_filter(model, tokenizer, f)

        # Append the filters
        self.filters.extend([lmfilter, prefix_filter])
        self.filters.append(lmfilter)
    def add_regex_filter(
        self,
@@ -154,7 +69,9 @@
        # Create the parser
        try:
            pattern_parser = RegexParser(pattern)
            # Validate regex and create formatter
            f = FormatterBuilder()
            f.append_line(f"{f.regex(pattern)}")
        except Exception:
            traceback.print_exc()
            logger.error(
@@ -164,32 +81,33 @@
            return

        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
        lmfilter = create_formatter_filter(model, tokenizer, f)

        # Append the filters
        self.filters.append(lmfilter)

    def add_ebnf_filter(
    def add_kbnf_filter(
        self,
        ebnf_string: str,
        kbnf_string: str,
        model: ExLlamaV2,
        tokenizer: ExLlamaV2Tokenizer,
    ):
        """
        Add an EBNF grammar filter.
        Possibly replace outlines with an in-house solution in the future.
        """
        """Adds an ExllamaV2 filter based on KBNF grammar."""

        # Create the parser
        try:
            ebnf_filter = ExLlamaV2EbnfFilter(model, tokenizer, ebnf_string)
        except ImportError:
            # Validate KBNF and create formatter
            f = FormatterBuilder()
            # TODO: Implement this
        except Exception:
            logger.error(
                "Skipping EBNF parsing because Outlines is not installed.\n"
                "Please run the following command in your environment "
                "to install extra packages:\n"
                "pip install -U .[extras]"
                "Skipping because the KBNF string couldn't be parsed. "
                "Please read the above error for more information."
            )

            return

        self.filters.append(ebnf_filter)
        lmfilter = create_formatter_filter(model, tokenizer, f)

        # Append the filters
        self.filters.append(lmfilter)
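
Caller-side, the rewritten handler is used the same way as before. A rough usage sketch (the no-argument constructor, the example schema and pattern, and the surrounding request plumbing are assumptions; only the method names and the filters list come from the code above):

    # Hypothetical caller-side usage of the Formatron-backed grammar handler
    grammar_handler = ExLlamaV2Grammar()

    # Constrain output to a JSON object matching a schema
    grammar_handler.add_json_schema_filter(
        {"type": "object", "properties": {"name": {"type": "string"}}},
        model,
        tokenizer,
    )

    # Or constrain output to a regex pattern
    grammar_handler.add_regex_filter(r"(yes|no)", model, tokenizer)

    # grammar_handler.filters now holds the Formatron-backed ExLlamaV2Filter
    # instances that get passed to the exllamav2 generator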


@@ -1194,7 +1194,7 @@ class ExllamaV2Container:
        # Add EBNF filter if it exists
        grammar_string = unwrap(kwargs.get("grammar_string"))
        if grammar_string:
            grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
            grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)

        # Set banned strings
        banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])


@@ -26,7 +26,7 @@ dependencies = [
    "sse-starlette",
    "packaging",
    "tokenizers",
    "lm-format-enforcer >= 0.9.6",
    "formatron",
    "aiofiles",
    "aiohttp",
    "async_lru",
@@ -53,7 +53,6 @@ dependencies = [
[project.optional-dependencies]
extras = [
    # Heavy dependencies that aren't for everyday use
    "outlines",
    "infinity-emb",
    "sentence-transformers",
]