Model: Add filter support to dynamic gen

The dynamic generator takes in filters differently. Adjust the grammar
handler to store its filter list on the class rather than setting it
inside the generation function.

Signed-off-by: kingbri <bdashore3@proton.me>
Author: kingbri <bdashore3@proton.me>
Date: 2024-05-23 00:55:22 -04:00
Committed by: Brian Dashore
parent 8ccd8fe5f8
commit 32ae62feac
2 changed files with 15 additions and 18 deletions
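
In practice, callers now collect filters on the grammar handler and pass them to the dynamic generator per request instead of writing them into gen_settings. Below is a minimal sketch of the new call pattern; it assumes exllamav2's ExLlamaV2DynamicJob accepts the filters and filter_prefer_eos keyword arguments that the diff passes through, and the build_job helper and the import path for ExLlamaV2Grammar are hypothetical:

from typing import Optional

from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2DynamicJob, ExLlamaV2Sampler

# NOTE: the import path for the grammar handler is an assumption
from common.grammar import ExLlamaV2Grammar


def build_job(
    input_ids,
    max_tokens: int,
    gen_settings: ExLlamaV2Sampler.Settings,
    model: ExLlamaV2,
    tokenizer: ExLlamaV2Tokenizer,
    json_schema: Optional[dict] = None,
    regex_pattern: Optional[str] = None,
    grammar_string: Optional[str] = None,
) -> ExLlamaV2DynamicJob:
    """Collect grammar filters on the handler, then hand them to the job."""
    grammar_handler = ExLlamaV2Grammar()

    # Each add_* method appends to grammar_handler.filters internally
    if json_schema:
        grammar_handler.add_json_schema_filter(json_schema, model, tokenizer)
    if regex_pattern:
        grammar_handler.add_regex_filter(regex_pattern, tokenizer)
    if grammar_string:
        grammar_handler.add_ebnf_filter(grammar_string, model, tokenizer)

    # Filters ride along with the job instead of living on gen_settings
    return ExLlamaV2DynamicJob(
        input_ids=input_ids,
        max_new_tokens=max_tokens,
        gen_settings=gen_settings,
        filters=grammar_handler.filters,
        filter_prefer_eos=bool(grammar_handler.filters),
    )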


@@ -1,10 +1,10 @@
 import traceback
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
-from exllamav2.generator import ExLlamaV2Sampler
 from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
 from lmformatenforcer import JsonSchemaParser, RegexParser
 from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
 from loguru import logger
+from typing import List
 class OutlinesTokenizerWrapper:
@@ -54,10 +54,14 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
 class ExLlamaV2Grammar:
     """ExLlamaV2 class for various grammar filters/parsers."""
+    filters: List[ExLlamaV2Filter]
+
+    def __init__(self):
+        self.filters = []
+
     def add_json_schema_filter(
         self,
         json_schema: dict,
-        gen_settings: ExLlamaV2Sampler.Settings,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
@@ -79,13 +83,11 @@ class ExLlamaV2Grammar:
         prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{")
 
         # Append the filters
-        gen_settings.filters.extend([lmfilter, prefix_filter])
-        gen_settings.filter_prefer_eos = True
+        self.filters.extend([lmfilter, prefix_filter])
 
     def add_regex_filter(
         self,
         pattern: str,
-        gen_settings: ExLlamaV2Sampler.Settings,
         tokenizer: ExLlamaV2Tokenizer,
     ):
         """Adds an ExllamaV2 filter based on regular expressions."""
@@ -105,13 +107,11 @@ class ExLlamaV2Grammar:
         lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer)
 
         # Append the filters
-        gen_settings.filters.extend([lmfilter])
-        gen_settings.filter_prefer_eos = True
+        self.filters.extend([lmfilter])
 
     def add_ebnf_filter(
         self,
         ebnf_string: str,
-        gen_settings: ExLlamaV2Sampler.Settings,
         model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
@@ -132,5 +132,4 @@ class ExLlamaV2Grammar:
             return
 
-        gen_settings.filters.append(ebnf_filter)
-        gen_settings.filter_prefer_eos = True
+        self.filters.append(ebnf_filter)
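
Taken together, the grammar handler ends up with roughly the following shape after this change. Only the JSON schema path is shown; the regex and EBNF methods follow the same self.filters pattern, the error handling from the full file is omitted, and the ExLlamaV2TokenEnforcerFilter call simply mirrors the one in the diff (a condensed sketch, not the complete module):

from typing import List

from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
from lmformatenforcer import JsonSchemaParser
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter


class ExLlamaV2Grammar:
    """ExLlamaV2 class for various grammar filters/parsers."""

    filters: List[ExLlamaV2Filter]

    def __init__(self):
        self.filters = []

    def add_json_schema_filter(
        self,
        json_schema: dict,
        model: ExLlamaV2,
        tokenizer: ExLlamaV2Tokenizer,
    ):
        """Build the lmformatenforcer and prefix filters, then store them."""
        schema_parser = JsonSchemaParser(json_schema)
        lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)

        # Constrain the first token so output opens a JSON object
        prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{")

        # Filters accumulate on the handler instead of on gen_settings
        self.filters.extend([lmfilter, prefix_filter])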


@@ -856,28 +856,23 @@ class ExllamaV2Container:
         # Initialize grammar handler
         grammar_handler = ExLlamaV2Grammar()
-        gen_settings.filters = []
 
         # Add JSON schema filter if it exists
         json_schema = unwrap(kwargs.get("json_schema"))
         if json_schema:
             grammar_handler.add_json_schema_filter(
-                json_schema, gen_settings, self.model, self.tokenizer
+                json_schema, self.model, self.tokenizer
             )
 
         # Add regex filter if it exists
         regex_pattern = unwrap(kwargs.get("regex_pattern"))
         if regex_pattern:
-            grammar_handler.add_regex_filter(
-                regex_pattern, gen_settings, self.tokenizer
-            )
+            grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
 
         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
         if grammar_string:
-            grammar_handler.add_ebnf_filter(
-                grammar_string, gen_settings, self.model, self.tokenizer
-            )
+            grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
 
         # Fetch EOS tokens from generation_config if they exist
         eos_tokens = (
@@ -971,6 +966,7 @@ class ExllamaV2Container:
             banned_tokens=banned_tokens,
             banned_strings=banned_strings,
             logit_bias=logit_bias,
+            filters=grammar_handler.filters,
         )
 
         # Log prompt to console
@@ -994,6 +990,8 @@ class ExllamaV2Container:
             gen_settings=gen_settings,
             stop_conditions=stop_conditions,
             decode_special_tokens=decode_special_tokens,
+            filters=grammar_handler.filters,
+            filter_prefer_eos=bool(grammar_handler.filters),
             return_probs=request_logprobs > 0,
             return_top_tokens=request_logprobs,
             return_logits=request_logprobs > 0,