Grammar: Cache the engine vocabulary

* Avoid rebuilding the KBNF engine vocabulary on every grammar-enabled request
DocShotgun 2024-12-05 21:36:37 -08:00
parent 8ccd7a12a2
commit 7f899734c0
2 changed files with 40 additions and 4 deletions
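
For context, the fix is standard functools.lru_cache memoization: the expensive engine-vocabulary build runs once per tokenizer and is reused by every later grammar-enabled request. A minimal, self-contained sketch of the pattern (toy names, not code from this repo):

from functools import lru_cache

@lru_cache(1)  # keep a single entry: only one model/tokenizer is loaded at a time
def build_vocab(tokenizer_name: str) -> dict:
    # Stand-in for the expensive walk over the tokenizer's full vocabulary
    print(f"building vocabulary for {tokenizer_name}")
    return {"name": tokenizer_name}

build_vocab("llama")       # first grammar request: builds the vocabulary
build_vocab("llama")       # later requests: cache hit, no rebuild
build_vocab.cache_clear()  # on model unload: drop the cached entry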

backends/exllamav2/grammar.py

@@ -1,12 +1,14 @@
 import traceback
 import typing
+from functools import lru_cache
 from typing import List
 
 import torch
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter
 from formatron.extractor import NonterminalExtractor
 from formatron.formatter import FormatterBuilder
-from formatron.integrations.exllamav2 import create_formatter_filter
+from formatron.integrations.exllamav2 import FormatterFilter, create_engine_vocabulary
 from formatron.schemas import json_schema
 from loguru import logger
@@ -48,7 +50,7 @@ class ExLlamaV2Grammar:
             return
 
-        lmfilter = create_formatter_filter(model, tokenizer, f)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)
 
         # Append the filters
         self.filters.append(lmfilter)
@@ -75,7 +77,7 @@
             return
 
-        lmfilter = create_formatter_filter(model, tokenizer, f)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)
 
         # Append the filters
         self.filters.append(lmfilter)
@@ -104,7 +106,7 @@
             return
 
-        lmfilter = create_formatter_filter(model, tokenizer, f)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)
 
         # Append the filters
         self.filters.append(lmfilter)
@@ -124,3 +126,33 @@ class CFGExtractor(NonterminalExtractor):
     @property
     def kbnf_definition(self) -> str:
         return self.kbnf_string.replace("start", self.nonterminal)
+
+
+@lru_cache(1)
+def _create_cached_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer):
+    """Build and cache the engine vocabulary on the first grammar run"""
+
+    return create_engine_vocabulary(tokenizer)
+
+
+def _create_formatter_filter(
+    model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer, formatter_builder: FormatterBuilder
+) -> ExLlamaV2Filter:
+    """
+    Create a formatter filter for the ExLlamaV2 engine.
+    Minimalist clone of formatron.integrations.exllamav2.create_formatter_filter
+    with lru_cache enabled for the engine vocabulary.
+    """
+
+    vocab = _create_cached_engine_vocabulary(tokenizer)
+    f = formatter_builder.build(
+        vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens))
+    )
+    return FormatterFilter(model, tokenizer, f)
+
+
+def clear_grammar_func_cache():
+    """Flush the cached engine vocabulary to avoid holding references to
+    tokenizers after unloading a model"""
+
+    _create_cached_engine_vocabulary.cache_clear()
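
The size-1 cache is keyed on the tokenizer object itself, so it holds a strong reference to that tokenizer until the entry is evicted or cleared; that is why the explicit clear_grammar_func_cache() hook exists. A self-contained illustration of the problem it prevents (toy class standing in for the real tokenizer):

import gc
import weakref
from functools import lru_cache

class Tok:
    """Stand-in for ExLlamaV2Tokenizer."""

@lru_cache(1)
def vocab_for(tok):
    return object()  # pretend this is the engine vocabulary

t = Tok()
alive = weakref.ref(t)
vocab_for(t)

del t
gc.collect()
assert alive() is not None   # the cache key still pins the tokenizer

vocab_for.cache_clear()      # what clear_grammar_func_cache() does
gc.collect()
assert alive() is None       # tokenizer can now be garbage-collected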

backends/exllamav2/model.py

@@ -39,6 +39,7 @@ from common.health import HealthManager
 from backends.exllamav2.grammar import (
     ExLlamaV2Grammar,
+    clear_grammar_func_cache,
 )
 from backends.exllamav2.utils import (
     exllama_disabled_flash_attn,
@@ -832,6 +833,9 @@ class ExllamaV2Container:
         # Wait for other jobs to finish
         await self.wait_for_jobs(kwargs.get("skip_wait"))
 
+        # Delete references held in the grammar module
+        clear_grammar_func_cache()
+
         # Clear the image embedding cache
         clear_image_embedding_cache()
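
Note the ordering in the unload path: the cache is cleared only after wait_for_jobs completes, so no in-flight grammar filter can still be using the cached vocabulary when its reference is dropped. The next grammar-enabled request against a newly loaded model then rebuilds the vocabulary lazily on first use.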