Model: Disable banned strings if grammar is used

ExLlamaV2 filters don't allow rewinding, which is what banned strings
rely on. Therefore, banned strings are not compatible with constrained
generation via LMFE or outlines for now.

Signed-off-by: kingbri <bdashore3@proton.me>
kingbri 2024-08-05 11:08:58 -04:00
parent 34281c2e14
commit 63650d2c3c
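
The change below can be read in isolation: when any grammar filter is registered, the banned string list is emptied and a warning is logged; otherwise it passes through untouched. A minimal standalone sketch of that guard (the function name, logger setup, and example values are illustrative only and are not part of the project's code):

import logging

logger = logging.getLogger(__name__)


def resolve_banned_strings(banned_strings: list[str], grammar_filters: list) -> list[str]:
    # Mirror of the guard added in this commit: grammar filters win,
    # and banned strings are dropped with a warning.
    if banned_strings and len(grammar_filters) > 0:
        logger.warning(
            "Disabling banned_strings because "
            "they cannot be used with grammar filters."
        )
        return []
    return banned_strings


# A JSON schema filter is active, so the bans are discarded:
resolve_banned_strings(["As an AI"], grammar_filters=["<json schema filter>"])  # -> []
# No filters, so banned strings pass through unchanged:
resolve_banned_strings(["As an AI"], grammar_filters=[])  # -> ["As an AI"]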

@@ -1018,8 +1018,37 @@ class ExllamaV2Container:
             kwargs.get("repetition_decay"), fallback_decay, 0
         )
 
-        stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
-
+        # Initialize grammar handler
+        grammar_handler = ExLlamaV2Grammar()
+
+        # Add JSON schema filter if it exists
+        json_schema = unwrap(kwargs.get("json_schema"))
+        if json_schema:
+            grammar_handler.add_json_schema_filter(
+                json_schema, self.model, self.tokenizer
+            )
+
+        # Add regex filter if it exists
+        regex_pattern = unwrap(kwargs.get("regex_pattern"))
+        if regex_pattern:
+            grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
+
+        # Add EBNF filter if it exists
+        grammar_string = unwrap(kwargs.get("grammar_string"))
+        if grammar_string:
+            grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
+
+        # Set banned strings
+        banned_strings: List[str] = unwrap(kwargs.get("banned_strings"), [])
+        if banned_strings and len(grammar_handler.filters) > 0:
+            logger.warning(
+                "Disabling banned_strings because "
+                "they cannot be used with grammar filters."
+            )
+
+            banned_strings = []
+
+        stop_conditions: List[Union[str, int]] = unwrap(kwargs.get("stop"), [])
         add_bos_token = unwrap(kwargs.get("add_bos_token"), True)
         ban_eos_token = unwrap(kwargs.get("ban_eos_token"), False)
         logit_bias = kwargs.get("logit_bias")
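
Neither hunk shows where these values are consumed downstream, so for context only: a sketch of how the assembled grammar filters and the (possibly emptied) banned string list would typically be handed to exllamav2's dynamic generator. It assumes a recent exllamav2 where ExLlamaV2DynamicJob accepts filters and banned_strings keyword arguments; build_job and its parameters are hypothetical and not code from this commit.

from exllamav2.generator import ExLlamaV2DynamicJob


def build_job(tokenizer, prompt, gen_settings, grammar_handler, banned_strings, stop_conditions):
    # Grammar filters constrain sampling token by token, while banned_strings
    # works by rewinding tokens that were already emitted; the filters cannot
    # follow such a rewind. That is why the guard above empties banned_strings
    # whenever grammar_handler has any filters registered.
    return ExLlamaV2DynamicJob(
        input_ids=tokenizer.encode(prompt, add_bos=True),
        max_new_tokens=256,
        gen_settings=gen_settings,        # an ExLlamaV2Sampler.Settings instance
        stop_conditions=stop_conditions,
        filters=grammar_handler.filters,  # from ExLlamaV2Grammar above
        banned_strings=banned_strings,    # [] whenever any grammar filter is active
    )

Dropping the bans at request time rather than rejecting the request keeps combined requests serviceable while surfacing the limitation through the warning log.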
@@ -1067,26 +1096,6 @@ class ExllamaV2Container:
                         "in the model's vocab. Skipping."
                     )
 
-        # Initialize grammar handler
-        grammar_handler = ExLlamaV2Grammar()
-
-        # Add JSON schema filter if it exists
-        json_schema = unwrap(kwargs.get("json_schema"))
-        if json_schema:
-            grammar_handler.add_json_schema_filter(
-                json_schema, self.model, self.tokenizer
-            )
-
-        # Add regex filter if it exists
-        regex_pattern = unwrap(kwargs.get("regex_pattern"))
-        if regex_pattern:
-            grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
-
-        # Add EBNF filter if it exists
-        grammar_string = unwrap(kwargs.get("grammar_string"))
-        if grammar_string:
-            grammar_handler.add_ebnf_filter(grammar_string, self.model, self.tokenizer)
-
         # Fetch EOS tokens from generation_config if they exist
         eos_tokens = (
             self.generation_config.eos_tokens()