Model: Add EBNF grammar support
Using the Outlines library, add support for supplying EBNF grammar strings, which are passed to the library for parsing. From there, a tokenizer wrapper is created and the resulting filter is passed to generation. This should eventually be replaced with a more flexible in-house solution. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
57b3d69949
commit
f6d749c771
3 changed files with 119 additions and 16 deletions
|
|
@ -1,8 +1,10 @@
|
|||
import traceback
|
||||
from common.logger import init_logger
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
||||
from exllamav2.generator import ExLlamaV2Sampler
|
||||
from exllamav2.generator.filters import ExLlamaV2Filter
|
||||
|
||||
# Temporary, remove once the exllama version is bumped
|
||||
# TODO: Remove after new exllama version is released
|
||||
try:
|
||||
from exllamav2.generator.filters import ExLlamaV2PrefixFilter
|
||||
|
||||
|
|
@ -10,18 +12,54 @@ try:
|
|||
except ImportError:
|
||||
_exllama_filter_available = False
|
||||
|
||||
try:
|
||||
from lmformatenforcer import JsonSchemaParser
|
||||
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
|
||||
|
||||
_lmformatenforcer_available = True
|
||||
except ImportError:
|
||||
_lmformatenforcer_available = False
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class OutlinesTokenizerWrapper:
    """Wrapper for Outlines tokenizer.

    Adapts a tokenizer object that provides ``get_id_to_piece_list()``,
    ``eos_token_id`` and ``extended_id_to_piece`` into an object carrying
    the attributes and methods set up below.
    """

    def __init__(self, tokenizer):
        """Build the vocabulary and EOS/special-token attributes.

        Args:
            tokenizer: The underlying tokenizer to wrap.
        """

        self.tokenizer = tokenizer
        id_to_piece = self.tokenizer.get_id_to_piece_list()

        # Map each piece to its token id. If the same piece appears under
        # several ids, the last (highest) id is the one that is kept.
        self.vocabulary = {piece: idx for idx, piece in enumerate(id_to_piece)}
        self.eos_token_id = self.tokenizer.eos_token_id
        self.eos_token = id_to_piece[self.tokenizer.eos_token_id]

        # Keys of the tokenizer's extended id -> piece mapping
        self.special_tokens = list(self.tokenizer.extended_id_to_piece.keys())

    def convert_token_to_string(self, token):
        """Return the token unchanged."""

        return token

    def decode(self, tokens):
        """Concatenate the pieces of the given token ids into one string.

        Args:
            tokens: Iterable of token ids used to index the piece list.

        Returns:
            The joined pieces, or an empty string when no ids are given.
        """

        id_to_piece = self.tokenizer.get_id_to_piece_list()

        # Join in a single pass rather than growing a string with +=
        return "".join(id_to_piece[t] for t in tokens)
|
||||
|
||||
|
||||
class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
    """Filter class for context-free grammar via outlines.

    Keeps a CFGFSM built from the grammar together with the current FSM
    state, and uses them to report which token ids are allowed next.
    """

    def __init__(self, model, tokenizer, grammar):
        """Create the FSM for a grammar and start from its first state.

        Args:
            model: Forwarded to the base filter constructor.
            tokenizer: Forwarded to the base filter constructor and wrapped
                in an OutlinesTokenizerWrapper for the FSM.
            grammar: Grammar passed as-is to CFGFSM.
        """

        # Imported here rather than at module level, so a missing outlines
        # package raises ImportError when the filter is constructed.
        from outlines.fsm.fsm import CFGFSM

        super().__init__(model, tokenizer)

        outlines_tokenizer = OutlinesTokenizerWrapper(tokenizer)
        self.wrapped_tokenizer = outlines_tokenizer
        self.fsm = CFGFSM(grammar, outlines_tokenizer)
        self.state = self.fsm.first_state

    def begin(self, prefix_str=""):
        """Reset the tracked state to the FSM's first state.

        The prefix_str argument is accepted but not used.
        """

        self.state = self.fsm.first_state

    def feed(self, token):
        """Advance the tracked state with the given token.

        Args:
            token: Object whose ``.item()`` yields the token id.
        """

        token_id = token.item()
        self.state = self.fsm.next_state(self.state, token_id)

    def next(self):
        """Return the allowed token ids for the current state and an empty set."""

        allowed_ids = self.fsm.allowed_token_ids(self.state)
        return allowed_ids, set()
|
||||
|
||||
|
||||
class ExLlamaV2Grammar:
|
||||
"""ExLlamaV2 class for various grammar filters/parsers."""
|
||||
|
||||
|
|
@ -34,28 +72,80 @@ class ExLlamaV2Grammar:
|
|||
):
|
||||
"""Adds an ExllamaV2 filter based on a JSON schema."""
|
||||
|
||||
# Check if the required dependencies can be imported
|
||||
if not _exllama_filter_available:
|
||||
logger.warning(
|
||||
"ExllamaV2PrefixFilter is not available "
|
||||
"in the currently installed ExllamaV2 version."
|
||||
"in the currently installed ExllamaV2 version. "
|
||||
"Skipping JSON schema parsing."
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
if not _lmformatenforcer_available:
|
||||
# Import optional dependencies
|
||||
try:
|
||||
from lmformatenforcer import JsonSchemaParser
|
||||
from lmformatenforcer.integrations.exllamav2 import (
|
||||
ExLlamaV2TokenEnforcerFilter,
|
||||
)
|
||||
except ImportError:
|
||||
logger.error(
|
||||
"lmformatenforcer must be installed to parse a json schema.\n"
|
||||
"Please run the following command: pip install lm-format-enforcer"
|
||||
"Skipping JSON schema parsing because "
|
||||
"lm-format-enforcer is not installed.\n"
|
||||
"Please run the following command: "
|
||||
"pip install lm-format-enforcer"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
# Create the parser
|
||||
schema_parser = JsonSchemaParser(json_schema)
|
||||
try:
|
||||
schema_parser = JsonSchemaParser(json_schema)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
logger.error(
|
||||
"Skipping because the JSON schema couldn't be parsed. "
|
||||
"Please read the above error for more information."
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
|
||||
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{")
|
||||
|
||||
# Append the filters
|
||||
gen_settings.filters += [lmfilter, prefix_filter]
|
||||
gen_settings.filters.extend([lmfilter, prefix_filter])
|
||||
gen_settings.filter_prefer_eos = True
|
||||
|
||||
def add_ebnf_filter(
    self,
    ebnf_string: str,
    gen_settings: ExLlamaV2Sampler.Settings,
    model: ExLlamaV2,
    tokenizer: ExLlamaV2Tokenizer,
):
    """
    Add an EBNF grammar filter.

    Possibly replace outlines with an in-house solution in the future.

    On any failure (unsupported ExllamaV2 version, Outlines not installed,
    or an error while building the filter from the grammar) a message is
    logged and gen_settings is left untouched.

    Args:
        ebnf_string: Grammar string handed to ExLlamaV2EbnfFilter.
        gen_settings: Sampler settings whose filters list is appended to.
        model: Model passed through to the filter.
        tokenizer: Tokenizer passed through to the filter.
    """

    if not _exllama_filter_available:
        logger.warning(
            "filter_prefer_eos is not available "
            "in the currently installed ExllamaV2 version. "
            "Skipping EBNF parsing."
        )

        return

    try:
        ebnf_filter = ExLlamaV2EbnfFilter(model, tokenizer, ebnf_string)
    except ImportError:
        logger.error(
            "Skipping EBNF parsing because Outlines is not installed.\n"
            "Please run the following command: pip install outlines"
        )

        return
    except Exception:
        # Mirror the JSON schema path: a grammar that can't be turned into
        # a filter should skip the filter instead of aborting generation.
        traceback.print_exc()
        logger.error(
            "Skipping because the EBNF grammar couldn't be parsed. "
            "Please read the above error for more information."
        )

        return

    # Append the filter
    gen_settings.filters.append(ebnf_filter)
    gen_settings.filter_prefer_eos = True
|
||||
|
|
|
|||
|
|
@ -761,6 +761,7 @@ class ExllamaV2Container:
|
|||
|
||||
# Initialize grammar handler
|
||||
grammar_handler = ExLlamaV2Grammar()
|
||||
gen_settings.filters = []
|
||||
|
||||
# Add JSON schema filter if it exists
|
||||
json_schema = unwrap(kwargs.get("json_schema"))
|
||||
|
|
@ -769,6 +770,13 @@ class ExllamaV2Container:
|
|||
json_schema, gen_settings, self.model, self.tokenizer
|
||||
)
|
||||
|
||||
# Add EBNF filter if it exists
|
||||
grammar_string = unwrap(kwargs.get("grammar_string"))
|
||||
if grammar_string:
|
||||
grammar_handler.add_ebnf_filter(
|
||||
grammar_string, gen_settings, self.model, self.tokenizer
|
||||
)
|
||||
|
||||
# Ban the EOS token if specified. If not, append to stop conditions
|
||||
# as well.
|
||||
# Set this below logging to avoid polluting the stop strings array
|
||||
|
|
|
|||
|
|
@ -122,6 +122,10 @@ class BaseSamplerRequest(BaseModel):
|
|||
default_factory=lambda: get_default_sampler_value("json_schema"),
|
||||
)
|
||||
|
||||
grammar_string: Optional[str] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("grammar_string"),
|
||||
)
|
||||
|
||||
# Aliased variables
|
||||
typical: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("typical", 1.0),
|
||||
|
|
@ -266,6 +270,7 @@ class BaseSamplerRequest(BaseModel):
|
|||
"cfg_scale": self.cfg_scale,
|
||||
"negative_prompt": self.negative_prompt,
|
||||
"json_schema": self.json_schema,
|
||||
"grammar_string": self.grammar_string,
|
||||
}
|
||||
|
||||
return {**gen_params, **kwargs}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue