Model: Add filter support to dynamic gen
Dynamic gen takes in filters differently. Adjust to set the filter list per class rather than in the generation function. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
8ccd8fe5f8
commit
32ae62feac
2 changed files with 15 additions and 18 deletions
|
|
@ -1,10 +1,10 @@
|
|||
import traceback
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
||||
from exllamav2.generator import ExLlamaV2Sampler
|
||||
from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
|
||||
from lmformatenforcer import JsonSchemaParser, RegexParser
|
||||
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
|
||||
from loguru import logger
|
||||
from typing import List
|
||||
|
||||
|
||||
class OutlinesTokenizerWrapper:
|
||||
|
|
@ -54,10 +54,14 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
|
|||
class ExLlamaV2Grammar:
|
||||
"""ExLlamaV2 class for various grammar filters/parsers."""
|
||||
|
||||
filters: List[ExLlamaV2Filter]
|
||||
|
||||
def __init__(self):
    """Create a grammar handler with an empty per-instance filter list.

    Each `add_*_filter` method appends its constructed filter(s) to
    `self.filters`, so the generation code can pass the accumulated
    list to the generator (rather than mutating sampler settings).
    """
    self.filters = []
|
||||
|
||||
def add_json_schema_filter(
|
||||
self,
|
||||
json_schema: dict,
|
||||
gen_settings: ExLlamaV2Sampler.Settings,
|
||||
model: ExLlamaV2,
|
||||
tokenizer: ExLlamaV2Tokenizer,
|
||||
):
|
||||
|
|
@ -79,13 +83,11 @@ class ExLlamaV2Grammar:
|
|||
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{")
|
||||
|
||||
# Append the filters
|
||||
gen_settings.filters.extend([lmfilter, prefix_filter])
|
||||
gen_settings.filter_prefer_eos = True
|
||||
self.filters.extend([lmfilter, prefix_filter])
|
||||
|
||||
def add_regex_filter(
|
||||
self,
|
||||
pattern: str,
|
||||
gen_settings: ExLlamaV2Sampler.Settings,
|
||||
tokenizer: ExLlamaV2Tokenizer,
|
||||
):
|
||||
"""Adds an ExllamaV2 filter based on regular expressions."""
|
||||
|
|
@ -105,13 +107,11 @@ class ExLlamaV2Grammar:
|
|||
lmfilter = ExLlamaV2TokenEnforcerFilter(pattern_parser, tokenizer)
|
||||
|
||||
# Append the filters
|
||||
gen_settings.filters.extend([lmfilter])
|
||||
gen_settings.filter_prefer_eos = True
|
||||
self.filters.extend([lmfilter])
|
||||
|
||||
def add_ebnf_filter(
|
||||
self,
|
||||
ebnf_string: str,
|
||||
gen_settings: ExLlamaV2Sampler.Settings,
|
||||
model: ExLlamaV2,
|
||||
tokenizer: ExLlamaV2Tokenizer,
|
||||
):
|
||||
|
|
@ -132,5 +132,4 @@ class ExLlamaV2Grammar:
|
|||
|
||||
return
|
||||
|
||||
gen_settings.filters.append(ebnf_filter)
|
||||
gen_settings.filter_prefer_eos = True
|
||||
self.filters.append(ebnf_filter)
|
||||
|
|
|
|||
|
|
@ -856,28 +856,23 @@ class ExllamaV2Container:
|
|||
|
||||
# Initialize grammar handler
|
||||
grammar_handler = ExLlamaV2Grammar()
|
||||
gen_settings.filters = []
|
||||
|
||||
# Add JSON schema filter if it exists
|
||||
json_schema = unwrap(kwargs.get("json_schema"))
|
||||
if json_schema:
|
||||
grammar_handler.add_json_schema_filter(
|
||||
json_schema, gen_settings, self.model, self.tokenizer
|
||||
json_schema, self.model, self.tokenizer
|
||||
)
|
||||
|
||||
# Add regex filter if it exists
|
||||
regex_pattern = unwrap(kwargs.get("regex_pattern"))
|
||||
if regex_pattern:
|
||||
grammar_handler.add_regex_filter(
|
||||
regex_pattern, gen_settings, self.tokenizer
|
||||
)
|
||||
grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
|
||||
|
||||
# Add EBNF filter if it exists
|
||||
grammar_string = unwrap(kwargs.get("grammar_string"))
|
||||
if grammar_string:
|
||||
grammar_handler.add_ebnf_filter(
|
||||
grammar_string, gen_settings, self.model, self.tokenizer
|
||||
)
|
||||
grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
|
||||
|
||||
# Fetch EOS tokens from generation_config if they exist
|
||||
eos_tokens = (
|
||||
|
|
@ -971,6 +966,7 @@ class ExllamaV2Container:
|
|||
banned_tokens=banned_tokens,
|
||||
banned_strings=banned_strings,
|
||||
logit_bias=logit_bias,
|
||||
filters=grammar_handler.filters,
|
||||
)
|
||||
|
||||
# Log prompt to console
|
||||
|
|
@ -994,6 +990,8 @@ class ExllamaV2Container:
|
|||
gen_settings=gen_settings,
|
||||
stop_conditions=stop_conditions,
|
||||
decode_special_tokens=decode_special_tokens,
|
||||
filters=grammar_handler.filters,
|
||||
filter_prefer_eos=bool(grammar_handler.filters),
|
||||
return_probs=request_logprobs > 0,
|
||||
return_top_tokens=request_logprobs,
|
||||
return_logits=request_logprobs > 0,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue