Model: Enable draft model q-cache in Exl3
* Remove unneeded default fp16 cache layer import
This commit is contained in:
parent
58e34ba4c5
commit
a635a719d7
1 changed files with 28 additions and 5 deletions
|
|
@ -20,7 +20,7 @@ from exllamav3 import (
|
|||
Model,
|
||||
Tokenizer,
|
||||
)
|
||||
from exllamav3.cache import CacheLayer_fp16, CacheLayer_quant
|
||||
from exllamav3.cache import CacheLayer_quant
|
||||
from loguru import logger
|
||||
|
||||
from backends.base_model_container import BaseModelContainer
|
||||
|
|
@ -76,6 +76,7 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
max_seq_len: int = 4096
|
||||
cache_size: int = 4096
|
||||
cache_mode: str = "FP16"
|
||||
draft_cache_mode: str = "FP16"
|
||||
chunk_size: int = 2048
|
||||
max_batch_size: Optional[int] = None
|
||||
|
||||
|
|
@ -245,13 +246,35 @@ class ExllamaV3Container(BaseModelContainer):
|
|||
v_bits=v_bits,
|
||||
)
|
||||
else:
|
||||
self.cache = Cache(
|
||||
self.model, max_num_tokens=self.cache_size, layer_type=CacheLayer_fp16
|
||||
)
|
||||
self.cache = Cache(self.model, max_num_tokens=self.cache_size)
|
||||
|
||||
# Draft cache
|
||||
if self.use_draft_model:
|
||||
self.draft_cache = Cache(self.draft_model, max_num_tokens = self.cache_size)
|
||||
# Set draft cache mode
|
||||
self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
|
||||
|
||||
# Alias Exl2 q-cache settings
|
||||
match self.draft_cache_mode:
|
||||
case "Q4":
|
||||
self.draft_cache_mode = "4,4"
|
||||
case "Q6":
|
||||
self.draft_cache_mode = "6,6"
|
||||
case "Q8":
|
||||
self.draft_cache_mode = "8,8"
|
||||
|
||||
split_draft_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", self.draft_cache_mode)
|
||||
if split_draft_cache_mode:
|
||||
draft_k_bits = int(split_draft_cache_mode.group(1))
|
||||
draft_v_bits = int(split_draft_cache_mode.group(2))
|
||||
self.draft_cache = Cache(
|
||||
self.draft_model,
|
||||
max_num_tokens=self.cache_size,
|
||||
layer_type=CacheLayer_quant,
|
||||
k_bits=draft_k_bits,
|
||||
v_bits=draft_v_bits,
|
||||
)
|
||||
else:
|
||||
self.draft_cache = Cache(self.draft_model, max_num_tokens = self.cache_size)
|
||||
|
||||
# Max batch size
|
||||
self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue