diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py index 731970c..f34c112 100644 --- a/backends/exllamav2/grammar.py +++ b/backends/exllamav2/grammar.py @@ -97,6 +97,49 @@ class ExLlamaV2Grammar: gen_settings.filters.extend([lmfilter, prefix_filter]) gen_settings.filter_prefer_eos = True + def add_regex_filter( + self, + pattern: str, + gen_settings: ExLlamaV2Sampler.Settings, + tokenizer: ExLlamaV2Tokenizer, + ): + """Adds an ExllamaV2 filter based on regular expressions.""" + + # Import optional dependencies + try: + from lmformatenforcer import RegexParser + from lmformatenforcer.integrations.exllamav2 import ( + ExLlamaV2TokenEnforcerFilter, + ) + except ImportError: + logger.error( + "Skipping regex parsing because " + "lm-format-enforcer is not installed.\n" + "Please run the following command in your environment " + "to reinstall dependencies:\n" + "pip install -U ." + ) + + return + + # Create the parser + try: + pattern_parser = RegexParser(pattern) + except Exception: + traceback.print_exc() + logger.error( + "Skipping because the regex pattern couldn't be parsed. " + "Please read the above error for more information." + ) + + return + + lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer) + + # Append the filters + gen_settings.filters.extend([lmfilter]) + gen_settings.filter_prefer_eos = True + def add_ebnf_filter( self, ebnf_string: str, diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index ce08259..2246110 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -850,6 +850,13 @@ class ExllamaV2Container: json_schema, gen_settings, self.model, self.tokenizer ) + # Add regex filter if it exists + regex_pattern = unwrap(kwargs.get("regex_pattern")) + if regex_pattern: + grammar_handler.add_regex_filter( + regex_pattern, gen_settings, self.tokenizer + ) + # Add EBNF filter if it exists grammar_string = unwrap(kwargs.get("grammar_string")) if grammar_string: diff --git a/common/sampling.py b/common/sampling.py index e2201c1..0d808be 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -138,6 +138,10 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("json_schema"), ) + regex_pattern: Optional[str] = Field( + default_factory=lambda: get_default_sampler_value("regex_pattern"), + ) + grammar_string: Optional[str] = Field( default_factory=lambda: get_default_sampler_value("grammar_string"), ) @@ -312,6 +316,7 @@ class BaseSamplerRequest(BaseModel): "cfg_scale": self.cfg_scale, "negative_prompt": self.negative_prompt, "json_schema": self.json_schema, + "regex_pattern": self.regex_pattern, "grammar_string": self.grammar_string, "speculative_ngram": self.speculative_ngram, }