API + Model: Add support for regex pattern constraints
Adds the ability to constrain generation via a regex pattern, using lm-format-enforcer.
parent 57525219d0
commit abe411c6fb
3 changed files with 55 additions and 0 deletions
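From a client's point of view, the change means the sampler payload now accepts a regex_pattern string (see the BaseSamplerRequest hunk below). A minimal sketch of how a request might look, assuming a local server exposing an OpenAI-style completions endpoint and the requests package; the host, endpoint path, and the other payload fields are assumptions, only regex_pattern comes from this commit:

    import requests

    # Hypothetical host/endpoint; only "regex_pattern" is defined by this commit.
    payload = {
        "prompt": "The meeting is scheduled for",
        "max_tokens": 12,
        # Constrain generation to an ISO-8601 date such as 2024-05-17
        "regex_pattern": r"\d{4}-\d{2}-\d{2}",
    }
    response = requests.post("http://localhost:5000/v1/completions", json=payload)
    print(response.json())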
@@ -97,6 +97,49 @@ class ExLlamaV2Grammar:
         gen_settings.filters.extend([lmfilter, prefix_filter])
         gen_settings.filter_prefer_eos = True

+    def add_regex_filter(
+        self,
+        pattern: str,
+        gen_settings: ExLlamaV2Sampler.Settings,
+        tokenizer: ExLlamaV2Tokenizer,
+    ):
+        """Adds an ExllamaV2 filter based on regular expressions."""
+
+        # Import optional dependencies
+        try:
+            from lmformatenforcer import RegexParser
+            from lmformatenforcer.integrations.exllamav2 import (
+                ExLlamaV2TokenEnforcerFilter,
+            )
+        except ImportError:
+            logger.error(
+                "Skipping regex parsing because "
+                "lm-format-enforcer is not installed.\n"
+                "Please run the following command in your environment "
+                "to reinstall dependencies:\n"
+                "pip install -U ."
+            )
+
+            return
+
+        # Create the parser
+        try:
+            pattern_parser = RegexParser(pattern)
+        except Exception:
+            traceback.print_exc()
+            logger.error(
+                "Skipping because the regex pattern couldn't be parsed. "
+                "Please read the above error for more information."
+            )
+
+            return
+
+        lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer)
+
+        # Append the filters
+        gen_settings.filters.extend([lmfilter])
+        gen_settings.filter_prefer_eos = True
+
     def add_ebnf_filter(
         self,
         ebnf_string: str,
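The new method can also be exercised directly against a loaded model, mirroring the call site in the next hunk. A minimal sketch; constructing ExLlamaV2Grammar with no arguments and taking the tokenizer from an already-initialized ExLlamaV2 model are assumptions, not shown in this diff:

    # Sketch only: gen_settings comes from exllamav2's sampler, tokenizer from a
    # loaded model; ExLlamaV2Grammar() with no arguments is assumed.
    grammar_handler = ExLlamaV2Grammar()
    gen_settings = ExLlamaV2Sampler.Settings()

    # Restrict output to a bare yes/no answer
    grammar_handler.add_regex_filter(r"(yes|no)", gen_settings, tokenizer)

    # gen_settings.filters now holds the lm-format-enforcer token filter and can
    # be passed to the generator as usual.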
@@ -850,6 +850,13 @@ class ExllamaV2Container:
                 json_schema, gen_settings, self.model, self.tokenizer
             )

+        # Add regex filter if it exists
+        regex_pattern = unwrap(kwargs.get("regex_pattern"))
+        if regex_pattern:
+            grammar_handler.add_regex_filter(
+                regex_pattern, gen_settings, self.tokenizer
+            )
+
         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
         if grammar_string:
@@ -138,6 +138,10 @@ class BaseSamplerRequest(BaseModel):
         default_factory=lambda: get_default_sampler_value("json_schema"),
     )

+    regex_pattern: Optional[str] = Field(
+        default_factory=lambda: get_default_sampler_value("regex_pattern"),
+    )
+
     grammar_string: Optional[str] = Field(
         default_factory=lambda: get_default_sampler_value("grammar_string"),
     )
@@ -312,6 +316,7 @@ class BaseSamplerRequest(BaseModel):
             "cfg_scale": self.cfg_scale,
             "negative_prompt": self.negative_prompt,
             "json_schema": self.json_schema,
+            "regex_pattern": self.regex_pattern,
             "grammar_string": self.grammar_string,
             "speculative_ngram": self.speculative_ngram,
         }
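Taken together, the new field flows from the request model into the generation kwargs that the container reads via kwargs.get("regex_pattern"). A hypothetical round trip; the method name to_gen_params is a placeholder for whichever method builds the dict in the last hunk, and constructing the request with only regex_pattern assumes every other sampler field has a default:

    # Placeholder method name; only the "regex_pattern" key appears in the diff.
    req = BaseSamplerRequest(regex_pattern=r"(yes|no)")
    params = req.to_gen_params()
    assert params["regex_pattern"] == r"(yes|no)"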