161 lines
4.8 KiB
Python
161 lines
4.8 KiB
Python
import traceback
|
|
import typing
|
|
from functools import lru_cache
|
|
from typing import List
|
|
|
|
import torch
|
|
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
|
from exllamav2.generator.filters import ExLlamaV2Filter
|
|
from formatron.extractor import NonterminalExtractor
|
|
from formatron.formatter import FormatterBuilder
|
|
from formatron.integrations.exllamav2 import FormatterFilter, create_engine_vocabulary
|
|
from formatron.schemas import json_schema
|
|
from loguru import logger
|
|
|
|
|
|
class ExLlamaV2Grammar:
|
|
"""ExLlamaV2 class for various grammar filters/parsers."""
|
|
|
|
filters: List[ExLlamaV2Filter]
|
|
|
|
def __init__(self):
|
|
self.filters = []
|
|
|
|
def add_json_schema_filter(
|
|
self,
|
|
schema: dict,
|
|
model: ExLlamaV2,
|
|
tokenizer: ExLlamaV2Tokenizer,
|
|
):
|
|
"""Adds an ExllamaV2 filter based on a JSON schema."""
|
|
|
|
# Create the parser
|
|
try:
|
|
# Add fields required by formatron if not present
|
|
if "$id" not in schema:
|
|
schema["$id"] = "https://example.com/example.json"
|
|
if "$schema" not in schema:
|
|
schema["$schema"] = "http://json-schema.org/draft-07/schema#"
|
|
|
|
# Validate schema and create formatter
|
|
schema = json_schema.create_schema(schema)
|
|
f = FormatterBuilder()
|
|
f.append_line(f"{f.json(schema)}")
|
|
except Exception:
|
|
traceback.print_exc()
|
|
logger.error(
|
|
"Skipping because the JSON schema couldn't be parsed. "
|
|
"Please read the above error for more information."
|
|
)
|
|
|
|
return
|
|
|
|
lmfilter = _create_formatter_filter(model, tokenizer, f)
|
|
|
|
# Append the filters
|
|
self.filters.append(lmfilter)
|
|
|
|
def add_regex_filter(
|
|
self,
|
|
pattern: str,
|
|
model: ExLlamaV2,
|
|
tokenizer: ExLlamaV2Tokenizer,
|
|
):
|
|
"""Adds an ExllamaV2 filter based on regular expressions."""
|
|
|
|
# Create the parser
|
|
try:
|
|
# Validate regex and create formatter
|
|
f = FormatterBuilder()
|
|
f.append_line(f"{f.regex(pattern)}")
|
|
except Exception:
|
|
traceback.print_exc()
|
|
logger.error(
|
|
"Skipping because the regex pattern couldn't be parsed. "
|
|
"Please read the above error for more information."
|
|
)
|
|
|
|
return
|
|
|
|
lmfilter = _create_formatter_filter(model, tokenizer, f)
|
|
|
|
# Append the filters
|
|
self.filters.append(lmfilter)
|
|
|
|
def add_kbnf_filter(
|
|
self,
|
|
kbnf_string: str,
|
|
model: ExLlamaV2,
|
|
tokenizer: ExLlamaV2Tokenizer,
|
|
):
|
|
"""Adds an ExllamaV2 filter based on KBNF grammar."""
|
|
|
|
# Create the parser
|
|
try:
|
|
# Validate KBNF and create formatter
|
|
f = FormatterBuilder()
|
|
f.append_line(
|
|
f"""{
|
|
f.extractor(
|
|
lambda nonterminal: CFGExtractor(nonterminal, kbnf_string)
|
|
)
|
|
}"""
|
|
)
|
|
except Exception:
|
|
logger.error(
|
|
"Skipping because the KBNF string couldn't be parsed. "
|
|
"Please read the above error for more information."
|
|
)
|
|
|
|
return
|
|
|
|
lmfilter = _create_formatter_filter(model, tokenizer, f)
|
|
|
|
# Append the filters
|
|
self.filters.append(lmfilter)
|
|
|
|
|
|
class CFGExtractor(NonterminalExtractor):
|
|
"""Extractor class for KBNF context-free grammar"""
|
|
|
|
def __init__(self, nonterminal: str, kbnf_string: str):
|
|
super().__init__(nonterminal)
|
|
self.kbnf_string = kbnf_string
|
|
|
|
# Return the entire input string as the extracted string
|
|
def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
|
|
return "", input_str
|
|
|
|
@property
|
|
def kbnf_definition(self) -> str:
|
|
return self.kbnf_string.replace("start", self.nonterminal)
|
|
|
|
|
|
@lru_cache(1)
|
|
def _create_cached_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer):
|
|
"""Build and cache engine vocabulary on first grammar run"""
|
|
|
|
return create_engine_vocabulary(tokenizer)
|
|
|
|
|
|
def _create_formatter_filter(
|
|
model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer, formatter_builder: FormatterBuilder
|
|
) -> ExLlamaV2Filter:
|
|
"""
|
|
Create a formatter filter for the ExLlamaV2 engine.
|
|
Minimalist clone of formatron.integrations.exllamav2.create_formatter_filter
|
|
with lru_cache enabled for engine vocabulary
|
|
"""
|
|
|
|
vocab = _create_cached_engine_vocabulary(tokenizer)
|
|
f = formatter_builder.build(
|
|
vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens))
|
|
)
|
|
return FormatterFilter(model, tokenizer, f)
|
|
|
|
|
|
def clear_grammar_func_cache():
|
|
"""Flush tokenizer_data cache to avoid holding references to
|
|
tokenizers after unloading a model"""
|
|
|
|
_create_cached_engine_vocabulary.cache_clear()
|