diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py
index 3ad2f44..9c27491 100644
--- a/backends/exllamav2/grammar.py
+++ b/backends/exllamav2/grammar.py
@@ -1,110 +1,20 @@
 import traceback
 
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
-from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
-from lmformatenforcer import (
-    JsonSchemaParser,
-    RegexParser,
-    TokenEnforcer,
-    CharacterLevelParser,
-)
-from lmformatenforcer.integrations.exllamav2 import (
-    build_token_enforcer_tokenizer_data,
-)
+from exllamav2.generator.filters import ExLlamaV2Filter
 from loguru import logger
 from typing import List
-from functools import lru_cache
-
-
-class OutlinesTokenizerWrapper:
-    """Wrapper for Outlines tokenizer"""
-
-    def __init__(self, tokenizer):
-        self.tokenizer = tokenizer
-        id_to_piece = self.tokenizer.get_id_to_piece_list()
-        self.vocabulary = {piece: idx for idx, piece in enumerate(id_to_piece)}
-        self.eos_token_id = self.tokenizer.eos_token_id
-        self.eos_token = id_to_piece[self.tokenizer.eos_token_id]
-        self.special_tokens = list(self.tokenizer.extended_id_to_piece.keys())
-
-    def convert_token_to_string(self, token):
-        return token
-
-    def decode(self, tokens):
-        s = ""
-        id_to_piece = self.tokenizer.get_id_to_piece_list()
-        for t in tokens:
-            s += id_to_piece[t]
-        return s
-
-
-class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
-    """Filter class for context-free grammar via outlines"""
-
-    def __init__(self, model, tokenizer, grammar):
-        from outlines.fsm.fsm import CFGFSM
-
-        super().__init__(model, tokenizer)
-
-        self.wrapped_tokenizer = OutlinesTokenizerWrapper(tokenizer)
-        self.fsm = CFGFSM(grammar, self.wrapped_tokenizer)
-        self.state = self.fsm.first_state
-
-    def begin(self, prefix_str=""):
-        self.state = self.fsm.first_state
-
-    def feed(self, token):
-        self.state = self.fsm.next_state(self.state, token.item())
-
-    def next(self):
-        return self.fsm.allowed_token_ids(self.state), set()
-
-    def use_background_worker(self):
-        return True
-
-
-@lru_cache(10)
-def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
-    return build_token_enforcer_tokenizer_data(tokenizer)
-
-
-class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
-    """Filter class for LMFE"""
-
-    token_sequence: List[int]
-
-    def __init__(
-        self,
-        model: ExLlamaV2,
-        tokenizer: ExLlamaV2Tokenizer,
-        character_level_parser: CharacterLevelParser,
-    ):
-        super().__init__(model, tokenizer)
-        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
-        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
-        self.token_sequence = []
-
-    def begin(self, prefix_str: str):
-        self.token_sequence = []
-
-    def feed(self, token):
-        self.token_sequence.append(int(token[0][0]))
-
-    def next(self):
-        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
-        if not hasattr(self, "allow_return_type_list"):
-            return set(allowed_tokens), set()
-        else:
-            return sorted(allowed_tokens), []
-
-    def use_background_worker(self):
-        return True
+from formatron.formatter import FormatterBuilder
+from formatron.schemas import json_schema
+from formatron.integrations.exllamav2 import create_formatter_filter
 
 
 def clear_grammar_func_cache():
     """Flush tokenizer_data cache to avoid holding references to tokenizers
     after unloading a model"""
 
-    _get_lmfe_tokenizer_data.cache_clear()
+    # TODO: Unsure if this is needed with formatron
+    pass
 
 
 class ExLlamaV2Grammar:
@@ -117,7 +27,7 @@ class ExLlamaV2Grammar:
 
     def add_json_schema_filter(
         self,
-        json_schema: dict,
+        schema: dict,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
@@ -125,7 +35,16 @@ class ExLlamaV2Grammar:
 
         # Create the parser
        try:
-            schema_parser = JsonSchemaParser(json_schema)
+            # Add fields required by formatron if not present
+            if "$id" not in schema:
+                schema["$id"] = "https://example.com/example.json"
+            if "$schema" not in schema:
+                schema["$schema"] = "http://json-schema.org/draft-07/schema#"
+
+            # Validate schema and create formatter
+            schema = json_schema.create_schema(schema)
+            f = FormatterBuilder()
+            f.append_line(f"{f.json(schema)}")
         except Exception:
             traceback.print_exc()
             logger.error(
@@ -135,14 +54,10 @@ class ExLlamaV2Grammar:
 
             return
 
-        # Allow JSON objects or JSON arrays at the top level
-        json_prefixes = ["[", "{"]
-
-        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
-        prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
+        lmfilter = create_formatter_filter(model, tokenizer, f)
 
         # Append the filters
-        self.filters.extend([lmfilter, prefix_filter])
+        self.filters.append(lmfilter)
 
     def add_regex_filter(
         self,
@@ -154,7 +69,9 @@ class ExLlamaV2Grammar:
 
         # Create the parser
         try:
-            pattern_parser = RegexParser(pattern)
+            # Validate regex and create formatter
+            f = FormatterBuilder()
+            f.append_line(f"{f.regex(pattern)}")
         except Exception:
             traceback.print_exc()
             logger.error(
@@ -164,32 +81,33 @@ class ExLlamaV2Grammar:
 
             return
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
+        lmfilter = create_formatter_filter(model, tokenizer, f)
 
         # Append the filters
         self.filters.append(lmfilter)
 
-    def add_ebnf_filter(
+    def add_kbnf_filter(
         self,
-        ebnf_string: str,
+        kbnf_string: str,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
-        """
-        Add an EBNF grammar filter.
-        Possibly replace outlines with an in-house solution in the future.
-        """
+        """Adds an ExllamaV2 filter based on KBNF grammar."""
 
+        # Create the parser
         try:
-            ebnf_filter = ExLlamaV2EbnfFilter(model, tokenizer, ebnf_string)
-        except ImportError:
+            # Validate KBNF and create formatter
+            f = FormatterBuilder()
+            # TODO: Implement this
+        except Exception:
             logger.error(
-                "Skipping EBNF parsing because Outlines is not installed.\n"
-                "Please run the following command in your environment "
-                "to install extra packages:\n"
-                "pip install -U .[extras]"
+                "Skipping because the KBNF string couldn't be parsed. "
+                "Please read the above error for more information."
             )
 
             return
 
-        self.filters.append(ebnf_filter)
+        lmfilter = create_formatter_filter(model, tokenizer, f)
+
+        # Append the filters
+        self.filters.append(lmfilter)
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index ff11531..50cef42 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -1194,7 +1194,7 @@ class ExllamaV2Container:
         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
         if grammar_string:
-            grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
+            grammar_handler.add_kbnf_filter(grammar_string, self.model, self.tokenizer)
 
         # Set banned strings
         banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
diff --git a/pyproject.toml b/pyproject.toml
index de782b7..efa1d76 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ dependencies = [
     "sse-starlette",
     "packaging",
     "tokenizers",
-    "lm-format-enforcer >= 0.9.6",
+    "formatron",
     "aiofiles",
     "aiohttp",
     "async_lru",
@@ -53,7 +53,6 @@ dependencies = [
 [project.optional-dependencies]
 extras = [
     # Heavy dependencies that aren't for everyday use
-    "outlines",
     "infinity-emb",
     "sentence-transformers",
 ]
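Usage note: below is a minimal, hypothetical sketch (not part of the patch itself) of how the formatron-based path assembles its filters, mirroring the calls introduced in add_json_schema_filter and add_regex_filter above. The example schema, regex pattern, and variable names (raw_schema, date_pattern, f_json, f_regex) are illustrative placeholders, and a loaded ExLlamaV2 model and tokenizer are assumed where noted.

from formatron.formatter import FormatterBuilder
from formatron.schemas import json_schema

# JSON schema path: formatron expects "$id" and "$schema", which the backend
# now injects into the request schema when they are missing.
raw_schema = {
    "$id": "https://example.com/example.json",
    "$schema": "http://json-schema.org/draft-07/schema#",
    "type": "object",
    "properties": {"name": {"type": "string"}},
    "required": ["name"],
}
schema = json_schema.create_schema(raw_schema)

f_json = FormatterBuilder()
f_json.append_line(f"{f_json.json(schema)}")

# Regex path: the pattern is wrapped the same way.
date_pattern = r"\d{4}-\d{2}-\d{2}"
f_regex = FormatterBuilder()
f_regex.append_line(f"{f_regex.regex(date_pattern)}")

# With a loaded model and tokenizer, each builder is turned into an
# ExLlamaV2 filter the same way the patch does:
#   from formatron.integrations.exllamav2 import create_formatter_filter
#   lmfilter = create_formatter_filter(model, tokenizer, f_json)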