diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py
index 1b2ed3c..3ab1403 100644
--- a/backends/exllamav2/grammar.py
+++ b/backends/exllamav2/grammar.py
@@ -1,10 +1,10 @@
 import traceback
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
-from exllamav2.generator import ExLlamaV2Sampler
 from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
 from lmformatenforcer import JsonSchemaParser, RegexParser
 from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
 from loguru import logger
+from typing import List
 
 
 class OutlinesTokenizerWrapper:
@@ -54,10 +54,14 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
 class ExLlamaV2Grammar:
     """ExLlamaV2 class for various grammar filters/parsers."""
 
+    filters: List[ExLlamaV2Filter]
+
+    def __init__(self):
+        self.filters = []
+
     def add_json_schema_filter(
         self,
         json_schema: dict,
-        gen_settings: ExLlamaV2Sampler.Settings,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
@@ -79,13 +83,11 @@ class ExLlamaV2Grammar:
         prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{")
 
         # Append the filters
-        gen_settings.filters.extend([lmfilter, prefix_filter])
-        gen_settings.filter_prefer_eos = True
+        self.filters.extend([lmfilter, prefix_filter])
 
     def add_regex_filter(
         self,
         pattern: str,
-        gen_settings: ExLlamaV2Sampler.Settings,
         tokenizer: ExLlamaV2Tokenizer,
     ):
         """Adds an ExllamaV2 filter based on regular expressions."""
@@ -105,13 +107,11 @@ class ExLlamaV2Grammar:
         lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer)
 
         # Append the filters
-        gen_settings.filters.extend([lmfilter])
-        gen_settings.filter_prefer_eos = True
+        self.filters.extend([lmfilter])
 
     def add_ebnf_filter(
         self,
         ebnf_string: str,
-        gen_settings: ExLlamaV2Sampler.Settings,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
@@ -132,5 +132,4 @@ class ExLlamaV2Grammar:
 
             return
 
-        gen_settings.filters.append(ebnf_filter)
-        gen_settings.filter_prefer_eos = True
+        self.filters.append(ebnf_filter)
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index c1275df..3f302f8 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -856,28 +856,23 @@ class ExllamaV2Container:
 
         # Initialize grammar handler
         grammar_handler = ExLlamaV2Grammar()
-        gen_settings.filters = []
 
         # Add JSON schema filter if it exists
         json_schema = unwrap(kwargs.get("json_schema"))
         if json_schema:
             grammar_handler.add_json_schema_filter(
-                json_schema, gen_settings, self.model, self.tokenizer
+                json_schema, self.model, self.tokenizer
             )
 
         # Add regex filter if it exists
         regex_pattern = unwrap(kwargs.get("regex_pattern"))
         if regex_pattern:
-            grammar_handler.add_regex_filter(
-                regex_pattern, gen_settings, self.tokenizer
-            )
+            grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
 
         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
         if grammar_string:
-            grammar_handler.add_ebnf_filter(
-                grammar_string, gen_settings, self.model, self.tokenizer
-            )
+            grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
 
         # Fetch EOS tokens from generation_config if they exist
         eos_tokens = (
@@ -971,6 +966,7 @@ class ExllamaV2Container:
             banned_tokens=banned_tokens,
             banned_strings=banned_strings,
             logit_bias=logit_bias,
+            filters=grammar_handler.filters,
         )
 
         # Log prompt to console
@@ -994,6 +990,8 @@ class ExllamaV2Container:
             gen_settings=gen_settings,
             stop_conditions=stop_conditions,
             decode_special_tokens=decode_special_tokens,
+            filters=grammar_handler.filters,
+            filter_prefer_eos=bool(grammar_handler.filters),
             return_probs=request_logprobs > 0,
             return_top_tokens=request_logprobs,
             return_logits=request_logprobs > 0,