Grammar: Initial Formatron regex and JSON schema implementation
* Replace LMFE's regex and JSON schema filters with Formatron's
* Remove Outlines EBNF filter in preparation for Formatron KBNF filter
* TODO: Implement Formatron KBNF filter
parent aa4ccd03d4
commit 0836a9317f
3 changed files with 39 additions and 122 deletions
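Every rewritten filter below follows the same Formatron flow: build a FormatterBuilder, render the constraint (a regex or a JSON schema) into it, then wrap it into an ExLlamaV2 filter with create_formatter_filter. A minimal sketch of that flow, assuming an already-loaded ExLlamaV2 model and tokenizer; the helper name build_digit_filter and the digit pattern are illustrative, not part of this commit:

from formatron.formatter import FormatterBuilder
from formatron.integrations.exllamav2 import create_formatter_filter


def build_digit_filter(model, tokenizer):
    """Return an ExLlamaV2 filter that constrains output to one line of digits."""
    f = FormatterBuilder()
    f.append_line(f"{f.regex(r'[0-9]+')}")
    return create_formatter_filter(model, tokenizer, f)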
@@ -1,110 +1,20 @@
 import traceback
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
-from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
-from lmformatenforcer import (
-    JsonSchemaParser,
-    RegexParser,
-    TokenEnforcer,
-    CharacterLevelParser,
-)
-from lmformatenforcer.integrations.exllamav2 import (
-    build_token_enforcer_tokenizer_data,
-)
+from exllamav2.generator.filters import ExLlamaV2Filter
 from loguru import logger
 from typing import List
-from functools import lru_cache
-
-
-class OutlinesTokenizerWrapper:
-    """Wrapper for Outlines tokenizer"""
-
-    def __init__(self, tokenizer):
-        self.tokenizer = tokenizer
-        id_to_piece = self.tokenizer.get_id_to_piece_list()
-        self.vocabulary = {piece: idx for idx, piece in enumerate(id_to_piece)}
-        self.eos_token_id = self.tokenizer.eos_token_id
-        self.eos_token = id_to_piece[self.tokenizer.eos_token_id]
-        self.special_tokens = list(self.tokenizer.extended_id_to_piece.keys())
-
-    def convert_token_to_string(self, token):
-        return token
-
-    def decode(self, tokens):
-        s = ""
-        id_to_piece = self.tokenizer.get_id_to_piece_list()
-        for t in tokens:
-            s += id_to_piece[t]
-        return s
-
-
-class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
-    """Filter class for context-free grammar via outlines"""
-
-    def __init__(self, model, tokenizer, grammar):
-        from outlines.fsm.fsm import CFGFSM
-
-        super().__init__(model, tokenizer)
-
-        self.wrapped_tokenizer = OutlinesTokenizerWrapper(tokenizer)
-        self.fsm = CFGFSM(grammar, self.wrapped_tokenizer)
-        self.state = self.fsm.first_state
-
-    def begin(self, prefix_str=""):
-        self.state = self.fsm.first_state
-
-    def feed(self, token):
-        self.state = self.fsm.next_state(self.state, token.item())
-
-    def next(self):
-        return self.fsm.allowed_token_ids(self.state), set()
-
-    def use_background_worker(self):
-        return True
-
-
-@lru_cache(10)
-def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
-    return build_token_enforcer_tokenizer_data(tokenizer)
-
-
-class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
-    """Filter class for LMFE"""
-
-    token_sequence: List[int]
-
-    def __init__(
-        self,
-        model: ExLlamaV2,
-        tokenizer: ExLlamaV2Tokenizer,
-        character_level_parser: CharacterLevelParser,
-    ):
-        super().__init__(model, tokenizer)
-        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
-        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
-        self.token_sequence = []
-
-    def begin(self, prefix_str: str):
-        self.token_sequence = []
-
-    def feed(self, token):
-        self.token_sequence.append(int(token[0][0]))
-
-    def next(self):
-        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
-        if not hasattr(self, "allow_return_type_list"):
-            return set(allowed_tokens), set()
-        else:
-            return sorted(allowed_tokens), []
-
-    def use_background_worker(self):
-        return True
+from formatron.formatter import FormatterBuilder
+from formatron.schemas import json_schema
+from formatron.integrations.exllamav2 import create_formatter_filter


 def clear_grammar_func_cache():
     """Flush tokenizer_data cache to avoid holding references to
     tokenizers after unloading a model"""

-    _get_lmfe_tokenizer_data.cache_clear()
+    # TODO: Unsure if this is needed with formatron
+    pass


 class ExLlamaV2Grammar:
@@ -117,7 +27,7 @@ class ExLlamaV2Grammar:

     def add_json_schema_filter(
         self,
-        json_schema: dict,
+        schema: dict,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
@@ -125,7 +35,16 @@

         # Create the parser
         try:
-            schema_parser = JsonSchemaParser(json_schema)
+            # Add fields required by formatron if not present
+            if "$id" not in schema:
+                schema["$id"] = "https://example.com/example.json"
+            if "$schema" not in schema:
+                schema["$schema"] = "http://json-schema.org/draft-07/schema#"
+
+            # Validate schema and create formatter
+            schema = json_schema.create_schema(schema)
+            f = FormatterBuilder()
+            f.append_line(f"{f.json(schema)}")
         except Exception:
             traceback.print_exc()
             logger.error(
@@ -135,14 +54,10 @@

             return

-        # Allow JSON objects or JSON arrays at the top level
-        json_prefixes = ["[", "{"]
-
-        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
-        prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
+        lmfilter = create_formatter_filter(model, tokenizer, f)

         # Append the filters
-        self.filters.extend([lmfilter, prefix_filter])
+        self.filters.append(lmfilter)

     def add_regex_filter(
         self,
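A hedged usage sketch for the rewritten add_json_schema_filter: the caller passes a plain schema dict, and the method injects the $id and $schema fields Formatron requires before validating it with json_schema.create_schema and building the filter. The schema below is illustrative, and grammar_handler (an ExLlamaV2Grammar instance), model, and tokenizer are assumed to already exist:

# Illustrative schema; add_json_schema_filter fills in the $id / $schema defaults
schema = {
    "type": "object",
    "properties": {"answer": {"type": "string"}},
    "required": ["answer"],
}

grammar_handler.add_json_schema_filter(schema, model, tokenizer)
# On success, grammar_handler.filters holds one Formatron-backed filter
# instead of the previous LMFE filter plus JSON prefix filter pair.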
@@ -154,7 +69,9 @@

         # Create the parser
         try:
-            pattern_parser = RegexParser(pattern)
+            # Validate regex and create formatter
+            f = FormatterBuilder()
+            f.append_line(f"{f.regex(pattern)}")
         except Exception:
             traceback.print_exc()
             logger.error(
@@ -164,32 +81,33 @@

             return

-        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
+        lmfilter = create_formatter_filter(model, tokenizer, f)

         # Append the filters
         self.filters.append(lmfilter)

-    def add_ebnf_filter(
+    def add_kbnf_filter(
         self,
-        ebnf_string: str,
+        kbnf_string: str,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
-        """
-        Add an EBNF grammar filter.
-        Possibly replace outlines with an in-house solution in the future.
-        """
+        """Adds an ExllamaV2 filter based on KBNF grammar."""

         # Create the parser
         try:
-            ebnf_filter = ExLlamaV2EbnfFilter(model, tokenizer, ebnf_string)
-        except ImportError:
+            # Validate KBNF and create formatter
+            f = FormatterBuilder()
+            # TODO: Implement this
+        except Exception:
             logger.error(
-                "Skipping EBNF parsing because Outlines is not installed.\n"
-                "Please run the following command in your environment "
-                "to install extra packages:\n"
-                "pip install -U .[extras]"
+                "Skipping because the KBNF string couldn't be parsed. "
+                "Please read the above error for more information."
             )

             return

-        self.filters.append(ebnf_filter)
+        lmfilter = create_formatter_filter(model, tokenizer, f)
+
+        # Append the filters
+        self.filters.append(lmfilter)
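A similar hedged sketch for the regex path. The full add_regex_filter signature is truncated in the hunk above, so the (pattern, model, tokenizer) argument order is assumed by analogy with add_json_schema_filter; grammar_handler, model, and tokenizer are again assumed to exist:

# Constrain output to an ISO-style date such as 2024-08-15 (illustrative pattern)
grammar_handler.add_regex_filter(r"\d{4}-\d{2}-\d{2}", model, tokenizer)

Note that add_kbnf_filter stays a stub in this commit: the FormatterBuilder is created, but nothing is rendered into it until the KBNF TODO is implemented.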
@@ -1194,7 +1194,7 @@ class ExllamaV2Container:
         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
         if grammar_string:
-            grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
+            grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)

         # Set banned strings
         banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
@@ -26,7 +26,7 @@ dependencies = [
     "sse-starlette",
     "packaging",
     "tokenizers",
-    "lm-format-enforcer >= 0.9.6",
+    "formatron",
     "aiofiles",
     "aiohttp",
     "async_lru",

@@ -53,7 +53,6 @@ dependencies = [
 [project.optional-dependencies]
 extras = [
     # Heavy dependencies that aren't for everyday use
-    "outlines",
     "infinity-emb",
     "sentence-transformers",
 ]