Merge branch 'main' into pydantic-config

commit f05229bce4
4 changed files with 67 additions and 52 deletions

@@ -1,9 +1,13 @@
 import traceback
 from exllamav2 import ExLlamaV2, ExLlamaV2Tokenizer
 from exllamav2.generator.filters import ExLlamaV2Filter, ExLlamaV2PrefixFilter
-from lmformatenforcer import JsonSchemaParser, RegexParser
+from lmformatenforcer import (
+    JsonSchemaParser,
+    RegexParser,
+    TokenEnforcer,
+    CharacterLevelParser,
+)
 from lmformatenforcer.integrations.exllamav2 import (
-    ExLlamaV2TokenEnforcerFilter,
     build_token_enforcer_tokenizer_data,
 )
 from loguru import logger
@@ -54,12 +58,48 @@ class ExLlamaV2EbnfFilter(ExLlamaV2Filter):
     def next(self):
         return self.fsm.allowed_token_ids(self.state), set()
 
+    def use_background_worker(self):
+        return True
+
 
 @lru_cache(10)
 def _get_lmfe_tokenizer_data(tokenizer: ExLlamaV2Tokenizer):
     return build_token_enforcer_tokenizer_data(tokenizer)
 
 
+class ExLlamaV2TokenEnforcerFilter(ExLlamaV2Filter):
+    """Filter class for LMFE"""
+
+    token_sequence: List[int]
+
+    def __init__(
+        self,
+        model: ExLlamaV2,
+        tokenizer: ExLlamaV2Tokenizer,
+        character_level_parser: CharacterLevelParser,
+    ):
+        super().__init__(model, tokenizer)
+        tokenizer_data = _get_lmfe_tokenizer_data(tokenizer)
+        self.token_enforcer = TokenEnforcer(tokenizer_data, character_level_parser)
+        self.token_sequence = []
+
+    def begin(self, prefix_str: str):
+        self.token_sequence = []
+
+    def feed(self, token):
+        self.token_sequence.append(int(token[0][0]))
+
+    def next(self):
+        allowed_tokens = self.token_enforcer.get_allowed_tokens(self.token_sequence)
+        if not hasattr(self, "allow_return_type_list"):
+            return set(allowed_tokens), set()
+        else:
+            return sorted(allowed_tokens), []
+
+    def use_background_worker(self):
+        return True
+
+
 def clear_grammar_func_cache():
     """Flush tokenizer_data cache to avoid holding references to
     tokenizers after unloading a model"""
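
Note (not part of the diff): a minimal sketch of the vendored filter's contract, assuming a loaded model and tokenizer are in scope. exllamav2 drives begin() once per generation, next() before each sampling step, and feed() after each sampled token; here the same calls are made by hand with a hypothetical regex parser.

    from lmformatenforcer import RegexParser

    parser = RegexParser(r"\d{4}-\d{2}-\d{2}")  # hypothetical: constrain output to a date
    lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, parser)

    lmfilter.begin("")                     # reset per-generation state
    allowed, end_tokens = lmfilter.next()  # token ids the enforcer currently permits
    # the sampler picks `token` (a 1x1 tensor) from `allowed`, then:
    # lmfilter.feed(token)                 # advance the enforcer by one token
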
@@ -98,9 +138,7 @@ class ExLlamaV2Grammar:
         # Allow JSON objects or JSON arrays at the top level
         json_prefixes = ["[", "{"]
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(
-            schema_parser, _get_lmfe_tokenizer_data(tokenizer)
-        )
+        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, schema_parser)
         prefix_filter = ExLlamaV2PrefixFilter(model, tokenizer, json_prefixes)
 
         # Append the filters
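
For context, a hedged sketch of where a schema_parser can come from: JsonSchemaParser accepts a plain JSON-schema dict, which a Pydantic v2 model can supply. The model below is hypothetical, for illustration only.

    from pydantic import BaseModel
    from lmformatenforcer import JsonSchemaParser

    class Answer(BaseModel):  # hypothetical schema for illustration
        city: str
        population: int

    schema_parser = JsonSchemaParser(Answer.model_json_schema())
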
@@ -109,6 +147,7 @@ class ExLlamaV2Grammar:
     def add_regex_filter(
         self,
         pattern: str,
+        model: ExLlamaV2,
         tokenizer: ExLlamaV2Tokenizer,
     ):
         """Adds an ExllamaV2 filter based on regular expressions."""
@@ -125,9 +164,7 @@
 
         return
 
-        lmfilter = ExLlamaV2TokenEnforcerFilter(
-            pattern_parser, _get_lmfe_tokenizer_data(tokenizer)
-        )
+        lmfilter = ExLlamaV2TokenEnforcerFilter(model, tokenizer, pattern_parser)
 
         # Append the filters
         self.filters.append(lmfilter)
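
Call sites now thread the model through as well; a sketch of the new shape, assuming ExLlamaV2Grammar() constructs without arguments (the pattern is illustrative):

    grammar_handler = ExLlamaV2Grammar()
    grammar_handler.add_regex_filter(r"[A-Z]{3}-\d{4}", model, tokenizer)
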
@@ -17,6 +17,7 @@ from exllamav2 import (
     ExLlamaV2Cache_Q4,
     ExLlamaV2Cache_Q6,
     ExLlamaV2Cache_Q8,
+    ExLlamaV2Cache_TP,
     ExLlamaV2Tokenizer,
     ExLlamaV2Lora,
 )
@@ -55,14 +56,6 @@ from common.templating import (
 from common.transformers_utils import GenerationConfig, HuggingFaceConfig
 from common.utils import coalesce, unwrap
 
-# Dynamic imports
-try:
-    from exllamav2 import ExLlamaV2Cache_TP
-
-    has_tp = True
-except ImportError:
-    has_tp = False
-
 
 class ExllamaV2Container:
     """The model container class for ExLlamaV2 models."""
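
This deletion is safe because the dependency floor moves to exllamav2 0.2.2 (see the version-check hunk further down): ExLlamaV2Cache_TP can no longer fail to import, so the try/except probe and the has_tp flag it set are replaced by the unconditional import added in the previous hunk, and every has_tp guard below collapses to a plain use_tp check.
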
@@ -197,17 +190,10 @@ class ExllamaV2Container:
         else:
             # Set tensor parallel
             if use_tp:
-                if has_tp:
-                    self.use_tp = True
+                self.use_tp = True
 
-                    # TP has its own autosplit loader
-                    self.gpu_split_auto = False
-                else:
-                    # TODO: Remove conditional with exl2 v0.1.9 release
-                    logger.warning(
-                        "Tensor parallelism is not supported in the "
-                        "current ExllamaV2 version."
-                    )
+                # TP has its own autosplit loader
+                self.gpu_split_auto = False
 
         # Enable manual GPU split if provided
         if gpu_split:
@@ -703,7 +689,7 @@ class ExllamaV2Container:
     ):
         """Utility function to create a model cache."""
 
-        if has_tp and use_tp:
+        if use_tp:
             return ExLlamaV2Cache_TP(
                 model,
                 base=cache_class,
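
A condensed sketch of what the helper now does, reusing the diff's names (cache_class, use_tp); the Q4 choice is illustrative:

    from exllamav2 import ExLlamaV2Cache_Q4, ExLlamaV2Cache_TP

    cache_class = ExLlamaV2Cache_Q4  # e.g. a Q4-quantized KV cache
    if use_tp:
        cache = ExLlamaV2Cache_TP(model, base=cache_class)  # TP cache is split across GPUs
    else:
        cache = cache_class(model)
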
@@ -967,14 +953,6 @@ class ExllamaV2Container:
         Meant for dev wheels!
         """
 
-        if unwrap(kwargs.get("dry_allowed_length"), 0) > 0 and not hasattr(
-            ExLlamaV2Sampler.Settings, "dry_multiplier"
-        ):
-            logger.warning(
-                "DRY sampling is not supported by the currently "
-                "installed ExLlamaV2 version."
-            )
-
         return kwargs
 
     async def generate_gen(
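
The deleted block probed ExLlamaV2Sampler.Settings with hasattr to warn when an older wheel lacked DRY sampling; with 0.2.2 as the floor, those fields are assumed present. A sketch using only the attribute names the old guard referenced (values illustrative):

    from exllamav2.generator import ExLlamaV2Sampler

    settings = ExLlamaV2Sampler.Settings()
    settings.dry_allowed_length = 2
    settings.dry_multiplier = 0.8
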
@@ -1141,7 +1119,7 @@ class ExllamaV2Container:
         # Add regex filter if it exists
         regex_pattern = unwrap(kwargs.get("regex_pattern"))
         if regex_pattern:
-            grammar_handler.add_regex_filter(regex_pattern, self.tokenizer)
+            grammar_handler.add_regex_filter(regex_pattern, self.model, self.tokenizer)
 
         # Add EBNF filter if it exists
         grammar_string = unwrap(kwargs.get("grammar_string"))
 
@@ -8,7 +8,7 @@ from loguru import logger
 def check_exllama_version():
     """Verifies the exllama version"""
 
-    required_version = version.parse("0.2.1")
+    required_version = version.parse("0.2.2")
     current_version = version.parse(package_version("exllamav2").split("+")[0])
 
     unsupported_message = (
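
A sketch of the check's mechanics: the local-version suffix a wheel adds (e.g. "+cu121") is stripped before comparison, so a wheel tag can never fail the gate.

    from importlib.metadata import version as package_version
    from packaging import version

    current = version.parse(package_version("exllamav2").split("+")[0])  # "0.2.2+cu121" -> "0.2.2"
    print(current >= version.parse("0.2.2"))
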
@@ -68,12 +68,12 @@ cu121 = [
     "torch @ https://download.pytorch.org/whl/cu121/torch-2.3.1%2Bcu121-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu121.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Windows FA2 from https://github.com/bdashore3/flash-attention/releases
     "flash_attn @ https://github.com/bdashore3/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu123torch2.3.1cxx11abiFALSE-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
@@ -95,12 +95,12 @@ cu118 = [
     "torch @ https://download.pytorch.org/whl/cu118/torch-2.3.1%2Bcu118-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp312-cp312-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp311-cp311-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp310-cp310-win_amd64.whl ; platform_system == 'Windows' and python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp311-cp311-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+cu118.torch2.3.1-cp310-cp310-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.10'",
 
     # Linux FA2 from https://github.com/Dao-AILab/flash-attention/releases
     "flash_attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.6.3/flash_attn-2.6.3+cu118torch2.3cxx11abiFALSE-cp312-cp312-linux_x86_64.whl ; platform_system == 'Linux' and platform_machine == 'x86_64' and python_version == '3.12'",
@@ -119,9 +119,9 @@ amd = [
     "torch @ https://download.pytorch.org/whl/rocm6.0/torch-2.3.1%2Brocm6.0-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
 
     # Exl2
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
-    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.1/exllamav2-0.2.1+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+rocm6.0.torch2.3.1-cp312-cp312-linux_x86_64.whl ; python_version == '3.12'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+rocm6.0.torch2.3.1-cp311-cp311-linux_x86_64.whl ; python_version == '3.11'",
+    "exllamav2 @ https://github.com/turboderp/exllamav2/releases/download/v0.2.2/exllamav2-0.2.2+rocm6.0.torch2.3.1-cp310-cp310-linux_x86_64.whl ; python_version == '3.10'",
 ]
 
 # MARK: Ruff options
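
These are PEP 508 direct-URL requirements; the text after each ";" is an environment marker the installer evaluates so exactly one wheel matches a given platform and interpreter. A sketch of that evaluation using packaging:

    from packaging.markers import Marker

    marker = Marker("platform_system == 'Linux' and python_version == '3.12'")
    print(marker.evaluate())  # True only on Linux under CPython 3.12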