# tabbyAPI-ollama/backends/exllamav2/grammar.py
import traceback
import typing
from functools import lru_cache
from typing import List

import torch
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
from exllamav2.generator.filters import ExLlamaV2Filter
from formatron.extractor import NonterminalExtractor
from formatron.formatter import FormatterBuilder
from formatron.integrations.exllamav2 import FormatterFilter, create_engine_vocabulary
from formatron.schemas import json_schema
from loguru import logger


class ExLlamaV2Grammar:
    """ExLlamaV2 class for various grammar filters/parsers."""

    filters: List[ExLlamaV2Filter]

    def __init__(self):
        self.filters = []

    def add_json_schema_filter(
        self,
        schema: dict,
        model: ExLlamaV2,
        tokenizer: ExLlamaV2Tokenizer,
    ):
        """Adds an ExLlamaV2 filter based on a JSON schema."""

# Create the parser
try:
# Add fields required by formatron if not present
if "$id" not in schema:
schema["$id"] = "https://example.com/example.json"
if "$schema" not in schema:
schema["$schema"] = "http://json-schema.org/draft-07/schema#"
# Validate schema and create formatter
schema = json_schema.create_schema(schema)
f = FormatterBuilder()
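            # f.json() returns an extractor; interpolating it into the
            # appended line marks where the schema-constrained JSON must
            # appear in the generated output.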
f.append_line(f"{f.json(schema)}")
except Exception:
traceback.print_exc()
logger.error(
"Skipping because the JSON schema couldn't be parsed. "
"Please read the above error for more information."
)
return

        lmfilter = _create_formatter_filter(model, tokenizer, f)

        # Append the filter
        self.filters.append(lmfilter)

    def add_regex_filter(
        self,
        pattern: str,
        model: ExLlamaV2,
        tokenizer: ExLlamaV2Tokenizer,
    ):
        """Adds an ExLlamaV2 filter based on regular expressions."""

# Create the parser
try:
# Validate regex and create formatter
f = FormatterBuilder()
f.append_line(f"{f.regex(pattern)}")
except Exception:
traceback.print_exc()
logger.error(
"Skipping because the regex pattern couldn't be parsed. "
"Please read the above error for more information."
)
return

        lmfilter = _create_formatter_filter(model, tokenizer, f)

        # Append the filter
        self.filters.append(lmfilter)

    def add_kbnf_filter(
        self,
        kbnf_string: str,
        model: ExLlamaV2,
        tokenizer: ExLlamaV2Tokenizer,
    ):
        """Adds an ExLlamaV2 filter based on a KBNF grammar."""

# Create the parser
try:
# Validate KBNF and create formatter
f = FormatterBuilder()
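            # f.extractor() registers a custom extractor: formatron calls
            # the lambda with a generated nonterminal name, and the
            # CFGExtractor below supplies the grammar for that nonterminal.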
f.append_line(
f"""{
f.extractor(
lambda nonterminal: CFGExtractor(nonterminal, kbnf_string)
)
}"""
)
        except Exception:
            traceback.print_exc()
            logger.error(
                "Skipping because the KBNF string couldn't be parsed. "
                "Please read the above error for more information."
            )
            return

        lmfilter = _create_formatter_filter(model, tokenizer, f)

        # Append the filter
        self.filters.append(lmfilter)
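

# Illustrative usage (a sketch; schema_dict, model, and tokenizer are
# assumed to be in scope, and the exact generator wiring is up to the
# caller, not this module):
#
#   grammar = ExLlamaV2Grammar()
#   grammar.add_json_schema_filter(schema_dict, model, tokenizer)
#   # grammar.filters now holds the constraint filters to hand to the
#   # ExLlamaV2 generator for this request.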


class CFGExtractor(NonterminalExtractor):
    """Extractor class for KBNF context-free grammars."""

    def __init__(self, nonterminal: str, kbnf_string: str):
        super().__init__(nonterminal)
        self.kbnf_string = kbnf_string

    # Consume the whole input: empty remainder, entire string extracted
    def extract(self, input_str: str) -> typing.Optional[tuple[str, typing.Any]]:
        return "", input_str

@property
def kbnf_definition(self) -> str:
return self.kbnf_string.replace("start", self.nonterminal)


@lru_cache(1)
def _create_cached_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer):
    """Build and cache the engine vocabulary on the first grammar run."""

    return create_engine_vocabulary(tokenizer)


def _create_formatter_filter(
    model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer, formatter_builder: FormatterBuilder
) -> ExLlamaV2Filter:
    """
    Create a formatter filter for the ExLlamaV2 engine.

    Minimalist clone of formatron.integrations.exllamav2.create_formatter_filter
    with lru_cache enabled for the engine vocabulary.
    """

vocab = _create_cached_engine_vocabulary(tokenizer)
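    # formatron needs a decode callable; ExLlamaV2Tokenizer.decode expects
    # a tensor rather than a raw id list, hence the torch.tensor wrapping.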
f = formatter_builder.build(
vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens))
)
return FormatterFilter(model, tokenizer, f)


def clear_grammar_func_cache():
    """Flush the cached engine vocabulary to avoid holding references to
    tokenizers after unloading a model."""

    _create_cached_engine_vocabulary.cache_clear()