API + Model: Add support for regex pattern constraints

Adds the ability to constrain generation via a regex pattern using lm-format-enforcer.
This commit is contained in:
DocShotgun 2024-05-12 19:10:43 -07:00
parent 57525219d0
commit abe411c6fb
3 changed files with 55 additions and 0 deletions

View file

@ -97,6 +97,49 @@ class ExLlamaV2Grammar:
gen_settings.filters.extend([lmfilter, prefix_filter])
gen_settings.filter_prefer_eos = True
def add_regex_filter(
    self,
    pattern: str,
    gen_settings: ExLlamaV2Sampler.Settings,
    tokenizer: ExLlamaV2Tokenizer,
):
    """
    Attaches a regex-constraining filter to the given sampler settings.

    A parser is built from ``pattern`` and wrapped in an lm-format-enforcer
    token filter, which is then added to ``gen_settings.filters``.
    If lm-format-enforcer cannot be imported, or the parser cannot be
    built from ``pattern``, an error is logged and ``gen_settings`` is
    left unmodified.
    """

    # lm-format-enforcer is an optional dependency, so import it on demand
    try:
        from lmformatenforcer import RegexParser
        from lmformatenforcer.integrations.exllamav2 import (
            ExLlamaV2TokenEnforcerFilter,
        )
    except ImportError:
        logger.error(
            "Skipping regex parsing because lm-format-enforcer is not installed.\n"
            "Please run the following command in your environment to reinstall "
            "dependencies:\n"
            "pip install -U ."
        )
        return

    # Build the parser; any failure here skips the filter entirely
    try:
        regex_parser = RegexParser(pattern)
    except Exception:
        traceback.print_exc()
        logger.error(
            "Skipping because the regex pattern couldn't be parsed. Please read "
            "the above error for more information."
        )
        return
    else:
        # Constructed outside the guarded call so its errors still propagate
        enforcer_filter = ExLlamaV2TokenEnforcerFilter(regex_parser, tokenizer)

    # Register the filter on the sampler settings
    gen_settings.filters.append(enforcer_filter)
    gen_settings.filter_prefer_eos = True
def add_ebnf_filter(
self,
ebnf_string: str,

View file

@ -850,6 +850,13 @@ class ExllamaV2Container:
json_schema, gen_settings, self.model, self.tokenizer
)
# Add regex filter if it exists
regex_pattern = unwrap(kwargs.get("regex_pattern"))
if regex_pattern:
grammar_handler.add_regex_filter(
regex_pattern, gen_settings, self.tokenizer
)
# Add EBNF filter if it exists
grammar_string = unwrap(kwargs.get("grammar_string"))
if grammar_string:

View file

@ -138,6 +138,10 @@ class BaseSamplerRequest(BaseModel):
default_factory=lambda: get_default_sampler_value("json_schema"),
)
regex_pattern: Optional[str] = Field(
default_factory=lambda: get_default_sampler_value("regex_pattern"),
)
grammar_string: Optional[str] = Field(
default_factory=lambda: get_default_sampler_value("grammar_string"),
)
@ -312,6 +316,7 @@ class BaseSamplerRequest(BaseModel):
"cfg_scale": self.cfg_scale,
"negative_prompt": self.negative_prompt,
"json_schema": self.json_schema,
"regex_pattern": self.regex_pattern,
"grammar_string": self.grammar_string,
"speculative_ngram": self.speculative_ngram,
}