Grammar: Cache the engine vocabulary
* Avoid rebuilding the KBNF engine vocabulary on every grammar-enabled request
parent 8ccd7a12a2
commit 7f899734c0

2 changed files with 40 additions and 4 deletions
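For context: the engine vocabulary depends only on the tokenizer, so rebuilding it for every grammar-enabled request is wasted work. The fix is plain functools.lru_cache memoization keyed on the tokenizer object. A minimal self-contained sketch of the pattern (Tokenizer, build_vocabulary, and the builds counter are illustrative stand-ins, not names from this commit):

    from functools import lru_cache

    class Tokenizer:
        """Stand-in for ExLlamaV2Tokenizer; instances hash by identity, which lru_cache relies on."""

    builds = 0

    @lru_cache(1)  # maxsize=1: assumes only one model/tokenizer is loaded at a time
    def build_vocabulary(tokenizer: Tokenizer) -> str:
        """Pretend this performs the costly KBNF engine vocabulary build."""
        global builds
        builds += 1
        return "engine vocabulary"

    tok = Tokenizer()
    build_vocabulary(tok)  # first grammar-enabled request pays the build cost
    build_vocabulary(tok)  # later requests with the same tokenizer are cache hits
    assert builds == 1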
backends/exllamav2/grammar.py

@@ -1,12 +1,14 @@
 import traceback
 import typing
+from functools import lru_cache
 from typing import List

+import torch
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter
 from formatron.extractor import NonterminalExtractor
 from formatron.formatter import FormatterBuilder
-from formatron.integrations.exllamav2 import create_formatter_filter
+from formatron.integrations.exllamav2 import FormatterFilter, create_engine_vocabulary
 from formatron.schemas import json_schema
 from loguru import logger

@@ -48,7 +50,7 @@ class ExLlamaV2Grammar:

             return

-        lmfilter = create_formatter_filter(model, tokenizer, f)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)

         # Append the filters
         self.filters.append(lmfilter)

@@ -75,7 +77,7 @@ class ExLlamaV2Grammar:

             return

-        lmfilter = create_formatter_filter(model, tokenizer, f)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)

         # Append the filters
         self.filters.append(lmfilter)

@@ -104,7 +106,7 @@ class ExLlamaV2Grammar:

             return

-        lmfilter = create_formatter_filter(model, tokenizer, f)
+        lmfilter = _create_formatter_filter(model, tokenizer, f)

         # Append the filters
         self.filters.append(lmfilter)
@@ -124,3 +126,33 @@ class CFGExtractor(NonterminalExtractor):
     @property
     def kbnf_definition(self) -> str:
         return self.kbnf_string.replace("start", self.nonterminal)
+
+
+@lru_cache(1)
+def _create_cached_engine_vocabulary(tokenizer: ExLlamaV2Tokenizer):
+    """Build and cache engine vocabulary on first grammar run"""
+
+    return create_engine_vocabulary(tokenizer)
+
+
+def _create_formatter_filter(
+    model: ExLlamaV2, tokenizer: ExLlamaV2Tokenizer, formatter_builder: FormatterBuilder
+) -> ExLlamaV2Filter:
+    """
+    Create a formatter filter for the ExLlamaV2 engine.
+    Minimalist clone of formatron.integrations.exllamav2.create_formatter_filter
+    with lru_cache enabled for engine vocabulary
+    """
+
+    vocab = _create_cached_engine_vocabulary(tokenizer)
+    f = formatter_builder.build(
+        vocab, lambda tokens: tokenizer.decode(torch.tensor(tokens))
+    )
+    return FormatterFilter(model, tokenizer, f)
+
+
+def clear_grammar_func_cache():
+    """Flush tokenizer_data cache to avoid holding references to
+    tokenizers after unloading a model"""
+
+    _create_cached_engine_vocabulary.cache_clear()
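A note on the cleanup hook: lru_cache keeps strong references to its arguments, so the cached entry would otherwise pin the ExLlamaV2Tokenizer (and whatever it holds) long after the model is unloaded; clear_grammar_func_cache exists to break that link. A sketch of the behavior being guarded against, using illustrative names rather than the commit's own:

    import gc
    import weakref
    from functools import lru_cache

    class Tokenizer:
        """Stand-in for a tokenizer object tied to a loaded model."""

    @lru_cache(1)
    def build_vocabulary(tokenizer: Tokenizer) -> str:
        return "engine vocabulary"

    tok = Tokenizer()
    probe = weakref.ref(tok)
    build_vocabulary(tok)

    del tok
    gc.collect()
    assert probe() is not None  # the cache key still holds the tokenizer alive

    build_vocabulary.cache_clear()  # what clear_grammar_func_cache() does internally
    gc.collect()
    assert probe() is None  # reference released; the tokenizer can be collected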
backends/exllamav2/model.py

@@ -39,6 +39,7 @@ from common.health import HealthManager

 from backends.exllamav2.grammar import (
     ExLlamaV2Grammar,
+    clear_grammar_func_cache,
 )
 from backends.exllamav2.utils import (
     exllama_disabled_flash_attn,
@@ -832,6 +833,9 @@ class ExllamaV2Container:
         # Wait for other jobs to finish
         await self.wait_for_jobs(kwargs.get("skip_wait"))

+        # Delete references held in the grammar module
+        clear_grammar_func_cache()
+
         # Clear the image embedding cache
         clear_image_embedding_cache()

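Placement in the unload path matters: the cache is cleared only after wait_for_jobs completes, so an in-flight grammar request cannot repopulate it with the outgoing tokenizer. If a regression check were wanted, lru_cache's built-in instrumentation is enough to assert nothing stays pinned after unload (a hypothetical test, not part of this commit):

    from functools import lru_cache

    @lru_cache(1)
    def build_vocabulary(tokenizer: object) -> str:
        return "engine vocabulary"

    build_vocabulary(object())
    assert build_vocabulary.cache_info().currsize == 1

    build_vocabulary.cache_clear()
    assert build_vocabulary.cache_info().currsize == 0  # cache empty after the clear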