API + Model: Add support for JSON schema constraints
Add the ability to constrain the return value of a model to JSON, using the JSON schema standard to define the properties of what the model should return. Because it is built on lmformatenforcer, this feature should be more accurate than using GBNF/EBNF to achieve the same results. GBNF/EBNF support will be added in a separate commit/branch. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
ccd41d720d
commit
57b3d69949
3 changed files with 77 additions and 0 deletions
61
backends/exllamav2/grammar.py
Normal file
61
backends/exllamav2/grammar.py
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
from common.logger import init_logger
|
||||
from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
|
||||
from exllamav2.generator import ExLlamaV2Sampler
|
||||
|
||||
# Temporary, remove once the exllama version is bumped
|
||||
try:
|
||||
from exllamav2.generator.filters import ExLlamaV2PrefixFilter
|
||||
|
||||
_exllama_filter_available = True
|
||||
except ImportError:
|
||||
_exllama_filter_available = False
|
||||
|
||||
try:
|
||||
from lmformatenforcer import JsonSchemaParser
|
||||
from lmformatenforcer.integrations.exllamav2 import ExLlamaV2TokenEnforcerFilter
|
||||
|
||||
_lmformatenforcer_available = True
|
||||
except ImportError:
|
||||
_lmformatenforcer_available = False
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ExLlamaV2Grammar:
|
||||
"""ExLlamaV2 class for various grammar filters/parsers."""
|
||||
|
||||
def add_json_schema_filter(
|
||||
self,
|
||||
json_schema: dict,
|
||||
gen_settings: ExLlamaV2Sampler.Settings,
|
||||
model: ExLlamaV2,
|
||||
tokenizer: ExLlamaV2Tokenizer,
|
||||
):
|
||||
"""Adds an ExllamaV2 filter based on a JSON schema."""
|
||||
|
||||
# Check if the required dependencies can be imported
|
||||
if not _exllama_filter_available:
|
||||
logger.warning(
|
||||
"ExllamaV2PrefixFilter is not available "
|
||||
"in the currently installed ExllamaV2 version."
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
if not _lmformatenforcer_available:
|
||||
logger.error(
|
||||
"lmformatenforcer must be installed to parse a json schema.\n"
|
||||
"Please run the following command: pip install lm-format-enforcer"
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
# Create the parser
|
||||
schema_parser = JsonSchemaParser(json_schema)
|
||||
lmfilter = ExLlamaV2TokenEnforcerFilter(schema_parser, tokenizer)
|
||||
prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, "{")
|
||||
|
||||
# Append the filters
|
||||
gen_settings.filters += [lmfilter, prefix_filter]
|
||||
gen_settings.filter_prefer_eos = True
|
||||
|
|
@ -16,6 +16,7 @@ from exllamav2 import (
|
|||
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from backends.exllamav2.grammar import ExLlamaV2Grammar
|
||||
from common.gen_logging import log_generation_params, log_prompt, log_response
|
||||
from common.templating import (
|
||||
PromptTemplate,
|
||||
|
|
@ -758,6 +759,16 @@ class ExllamaV2Container:
|
|||
"in the model's vocab. Skipping."
|
||||
)
|
||||
|
||||
# Initialize grammar handler
|
||||
grammar_handler = ExLlamaV2Grammar()
|
||||
|
||||
# Add JSON schema filter if it exists
|
||||
json_schema = unwrap(kwargs.get("json_schema"))
|
||||
if json_schema:
|
||||
grammar_handler.add_json_schema_filter(
|
||||
json_schema, gen_settings, self.model, self.tokenizer
|
||||
)
|
||||
|
||||
# Ban the EOS token if specified. If not, append to stop conditions
|
||||
# as well.
|
||||
# Set this below logging to avoid polluting the stop strings array
|
||||
|
|
|
|||
|
|
@ -118,6 +118,10 @@ class BaseSamplerRequest(BaseModel):
|
|||
default_factory=lambda: get_default_sampler_value("negative_prompt")
|
||||
)
|
||||
|
||||
json_schema: Optional[object] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("json_schema"),
|
||||
)
|
||||
|
||||
# Aliased variables
|
||||
typical: Optional[float] = Field(
|
||||
default_factory=lambda: get_default_sampler_value("typical", 1.0),
|
||||
|
|
@ -261,6 +265,7 @@ class BaseSamplerRequest(BaseModel):
|
|||
"mirostat_eta": self.mirostat_eta,
|
||||
"cfg_scale": self.cfg_scale,
|
||||
"negative_prompt": self.negative_prompt,
|
||||
"json_schema": self.json_schema,
|
||||
}
|
||||
|
||||
return {**gen_params, **kwargs}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue