diff --git a/backends/exllamav2/grammar.py b/backends/exllamav2/grammar.py
index adff61f..3ad2f44 100644
--- a/backends/exllamav2/grammar.py
+++ b/backends/exllamav2/grammar.py
@@ -1,9 +1,13 @@
 import traceback
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
-from lmformatenforcer import JsonSchemaParser, RegexParser
+from lmformatenforcer import (
+    JsonSchemaParser,
+    RegexParser,
+    TokenEnforcer,
+    CharacterLevelParser,
+)
 from lmformatenforcer.integrations.exllamav2 import (
-    ExLlamaV2TokenEnforcerFilter,
     build_token_enforcer_tokenizer_data,
 )
 from loguru import logger
@@ -54,12 +58,48 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
     def next(self):
         return self.fsm.allowed_token_ids(self.state), set()
 
+    def use_background_worker(self):
+        return True
+
 
 @lru_cache(10)
 def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
     return build_token_enforcer_tokenizer_data(tokenizer)
 
 
+class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
+    """Filter class for LMFE"""
+
+    token_sequence: List[int]
+
+    def __init__(
+        self,
+        model: ExLlamaV2,
+        tokenizer: ExLlamaV2Tokenizer,
+        character_level_parser: CharacterLevelParser,
+    ):
+        super().__init__(model, tokenizer)
+        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
+        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
+        self.token_sequence = []
+
+    def begin(self, prefix_str: str):
+        self.token_sequence = []
+
+    def feed(self, token):
+        self.token_sequence.append(int(token[0][0]))
+
+    def next(self):
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
+        if not hasattr(self, "allow_return_type_list"):
+            return set(allowed_tokens), set()
+        else:
+            return sorted(allowed_tokens), []
+
+    def use_background_worker(self):
+        return True
+
+
 def clear_grammar_func_cache():
     """Flush tokenizer_data cache to avoid holding references
     to tokenizers after unloading a model"""
@@ -98,9 +138,7 @@ class ExLlamaV2Grammar:
         # Allow JSON objects or JSON arrays at the top level
         json_prefixes = ["[", "{"]
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(
-            schema_parser, _get_lmfe_tokenizer_data(tokenizer)
-        )
+        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
         prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
 
         # Append the filters
@@ -109,6 +147,7 @@ class ExLlamaV2Grammar:
     def add_regex_filter(
         self,
         pattern: str,
+        model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
         """Adds an ExllamaV2 filter based on regular expressions."""
@@ -125,9 +164,7 @@ class ExLlamaV2Grammar:
 
             return
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(
-            pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
-        )
+        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
 
         # Append the filters
         self.filters.append(lmfilter)
diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 4aedf75..207232a 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -17,6 +17,7 @@ from exllamav2 import (
     ExLlamaV2Cache_Q4,
     ExLlamaV2Cache_Q6,
     ExLlamaV2Cache_Q8,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
 )
@@ -55,14 +56,6 @@ from common.templating import (
 from common.transformers_utils import GenerationConfig, HuggingFaceConfig
 from common.utils import coalesce, unwrap
 
-# Dynamic imports
-try:
-    from exllamav2 import ExLlamaV2Cache_TP
-
-    has_tp = True
-except ImportError:
-    has_tp = False
-
 
 class ExllamaV2Container:
     """The model container class for ExLlamaV2 models."""
@@ -197,17 +190,10 @@ class ExllamaV2Container:
         else:
             # Set tensor parallel
            if use_tp:
-                if has_tp:
-                    self.use_tp = True
+                self.use_tp = True
 
-                    # TP has its own autosplit loader
-                    self.gpu_split_auto = False
-                else:
-                    # TODO: Remove conditional with exl2 v0.1.9 release
-                    logger.warning(
-                        "Tensor parallelism is not supported in the "
-                        "current ExllamaV2 version."
-                    )
+                # TP has its own autosplit loader
+                self.gpu_split_auto = False
 
             # Enable manual GPU split if provided
             if gpu_split:
@@ -703,7 +689,7 @@ class ExllamaV2Container:
     ):
         """Utility function to create a model cache."""
 
-        if has_tp and use_tp:
+        if use_tp:
             return ExLlamaV2Cache_TP(
                 model,
                 base=cache_class,
@@ -967,14 +953,6 @@ class ExllamaV2Container:
         Meant for dev wheels!
         """
 
-        if unwrap(kwargs.get("dry_allowed_length"), 0) > 0 and not hasattr(
-            ExLlamaV2Sampler.Settings, "dry_multiplier"
-        ):
-            logger.warning(
-                "DRY sampling is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
         return kwargs
 
     async def generate_gen(
@@ -1141,7 +1119,7 @@ class ExllamaV2Container:
         # Add regex filter if it exists
         regex_pattern = unwrap(kwargs.get("regex_pattern"))
         if regex_pattern:
-            grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
+            grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
 
         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
diff --git a/backends/exllamav2/utils.py b/backends/exllamav2/utils.py
index 4c192b2..b7a9f54 100644
--- a/backends/exllamav2/utils.py
+++ b/backends/exllamav2/utils.py
@@ -8,7 +8,7 @@ from loguru import logger
 def check_exllama_version():
     """Verifies the exllama version"""
 
-    required_version = version.parse("0.2.1")
+    required_version = version.parse("0.2.2")
     current_version = version.parse(package_version("exllamav2").split("+")[0])
 
     unsupported_message = (
diff --git a/pyproject.toml b/pyproject.toml
index 19fcbce..ad6f945 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -68,12 +68,12 @@ cu121 = [
     "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Windows FA2 from https://github.com/bdashore3/flash-attention/releases
     "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
@@ -95,12 +95,12 @@ cu118 = [
     "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
     "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
@@ -119,9 +119,9 @@ amd = [
     "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
 
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
 ]
 
 # MARK: Ruff options