diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py new file mode 100644 index 0000000..61c78b3 --- /dev/null +++ b/backends/exllamav2/grammar.py @@ -0,0 +1,61 @@ +from common.logger import init_logger +from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer +from exllamav2.generator import ExLlamaV2Sampler + +# Temporary, remove once the exllama version is bumped +try: + from exllamav2.generator.filters import ExLlamaV2PrefixFilter + + _exllama_filter_available = True +except ImportError: + _exllama_filter_available = False + +try: + from lmformatenforcer import JsonSchemaParser + from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter + + _lmformatenforcer_available = True +except ImportError: + _lmformatenforcer_available = False + + +logger = init_logger(__name__) + + +class ExLlamaV2Grammar: + """ExLlamaV2 class for various grammar filters/parsers.""" + + def add_json_schema_filter( + self, + json_schema: dict, + gen_settings: ExLlamaV2Sampler.Settings, + model: ExLlamaV2, + tokenizer: ExLlamaV2Tokenizer, + ): + """Adds an ExllamaV2 filter based on a JSON schema.""" + + # Check if the required dependencies can be imported + if not _exllama_filter_available: + logger.warning( + "ExllamaV2PrefixFilter is not available " + "in the currently installed ExllamaV2 version." + ) + + return + + if not _lmformatenforcer_available: + logger.error( + "lmformatenforcer must be installed to parse a json schema.\n" + "Please run the following command: pip install lm-format-enforcer" + ) + + return + + # Create the parser + schema_parser = JsonSchemaParser(json_schema) + lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer) + prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{") + + # Append the filters + gen_settings.filters += [lmfilter, prefix_filter] + gen_settings.filter_prefer_eos = True diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 5575931..e602149 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -16,6 +16,7 @@ from exllamav2 import ( from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler from typing import List, Optional, Union +from backends.exllamav2.grammar import ExLlamaV2Grammar from common.gen_logging import log_generation_params, log_prompt, log_response from common.templating import ( PromptTemplate, @@ -758,6 +759,16 @@ class ExllamaV2Container: "in the model's vocab. Skipping." ) + # Initialize grammar handler + grammar_handler = ExLlamaV2Grammar() + + # Add JSON schema filter if it exists + json_schema = unwrap(kwargs.get("json_schema")) + if json_schema: + grammar_handler.add_json_schema_filter( + json_schema, gen_settings, self.model, self.tokenizer + ) + # Ban the EOS token if specified. If not, append to stop conditions # as well. # Set this below logging to avoid polluting the stop strings array diff --git a/common/sampling.py b/common/sampling.py index 9b78824..e0d20c7 100644 --- a/common/sampling.py +++ b/common/sampling.py @@ -118,6 +118,10 @@ class BaseSamplerRequest(BaseModel): default_factory=lambda: get_default_sampler_value("negative_prompt") ) + json_schema: Optional[object] = Field( + default_factory=lambda: get_default_sampler_value("json_schema"), + ) + # Aliased variables typical: Optional[float] = Field( default_factory=lambda: get_default_sampler_value("typical", 1.0), @@ -261,6 +265,7 @@ class BaseSamplerRequest(BaseModel): "mirostat_eta": self.mirostat_eta, "cfg_scale": self.cfg_scale, "negative_prompt": self.negative_prompt, + "json_schema": self.json_schema, } return {**gen_params, **kwargs}