Model: Move cache creation to a common function

Avoids repeated code by moving Cache creation into a common helper method.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
kingbri 2025-05-08 23:10:03 -04:00
parent 9dcde59c57
commit 638eef401a


@@ -168,7 +168,7 @@ class ExllamaV3Container(BaseModelContainer):
             logger.info(f"Using draft model: {str(draft_model_path.resolve())}")
         else:
             self.draft_model = None
-            self.craft_cache = None
+            self.draft_cache = None
 
         # Turn off GPU split if the user is using 1 GPU
         gpu_count = torch.cuda.device_count()
@@ -222,61 +222,15 @@ class ExllamaV3Container(BaseModelContainer):
         user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
         self.cache_size = self.adjust_cache_size(user_cache_size)
         self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
-
-        # Alias Exl2 q-cache settings
-        match self.cache_mode:
-            case "Q4":
-                self.cache_mode = "4,4"
-            case "Q6":
-                self.cache_mode = "6,6"
-            case "Q8":
-                self.cache_mode = "8,8"
-
-        split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", self.cache_mode)
-        if split_cache_mode:
-            k_bits = int(split_cache_mode.group(1))
-            v_bits = int(split_cache_mode.group(2))
-            self.cache = Cache(
-                self.model,
-                max_num_tokens=self.cache_size,
-                layer_type=CacheLayer_quant,
-                k_bits=k_bits,
-                v_bits=v_bits,
-            )
-        else:
-            self.cache = Cache(self.model, max_num_tokens=self.cache_size)
+        self.cache = self.create_cache(self.cache_mode, self.model)
 
         # Draft cache
         if self.use_draft_model:
             # Set draft cache mode
             self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
-
-            # Alias Exl2 q-cache settings
-            match self.draft_cache_mode:
-                case "Q4":
-                    self.draft_cache_mode = "4,4"
-                case "Q6":
-                    self.draft_cache_mode = "6,6"
-                case "Q8":
-                    self.draft_cache_mode = "8,8"
-
-            split_draft_cache_mode = re.search(
-                r"^([2-8])\s*,\s*([2-8])$", self.draft_cache_mode
+            self.draft_cache = self.create_cache(
+                self.draft_cache_mode, self.draft_model
             )
-            if split_draft_cache_mode:
-                draft_k_bits = int(split_draft_cache_mode.group(1))
-                draft_v_bits = int(split_draft_cache_mode.group(2))
-                self.draft_cache = Cache(
-                    self.draft_model,
-                    max_num_tokens=self.cache_size,
-                    layer_type=CacheLayer_quant,
-                    k_bits=draft_k_bits,
-                    v_bits=draft_v_bits,
-                )
-            else:
-                self.draft_cache = Cache(
-                    self.draft_model, max_num_tokens=self.cache_size
-                )
 
         # Max batch size
         self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)
@@ -355,6 +309,33 @@ class ExllamaV3Container(BaseModelContainer):
 
         return chunk_size
 
+    def create_cache(self, raw_cache_mode: str, model: Model):
+        # Cast exl2 cache types to exl3
+        match raw_cache_mode:
+            case "Q4":
+                raw_cache_mode = "4,4"
+            case "Q6":
+                raw_cache_mode = "6,6"
+            case "Q8":
+                raw_cache_mode = "8,8"
+
+        split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", raw_cache_mode)
+        if split_cache_mode:
+            k_bits = int(split_cache_mode.group(1))
+            v_bits = int(split_cache_mode.group(2))
+
+            cache = Cache(
+                model,
+                max_num_tokens=self.cache_size,
+                layer_type=CacheLayer_quant,
+                k_bits=k_bits,
+                v_bits=v_bits,
+            )
+        else:
+            cache = Cache(model, max_num_tokens=self.cache_size)
+
+        return cache
+
     def model_info(self) -> ModelCard:
         """
         Returns a dictionary of the current model's configuration parameters.
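For reference, a minimal standalone sketch (not part of the commit) of the cache-mode parsing that create_cache now centralizes. The aliases and regex mirror the diff above; parse_cache_mode is a hypothetical name used only for illustration.

import re

def parse_cache_mode(raw_cache_mode: str):
    # Cast exl2-style q-cache names to exl3 "k,v" bit pairs
    aliases = {"Q4": "4,4", "Q6": "6,6", "Q8": "8,8"}
    raw_cache_mode = aliases.get(raw_cache_mode, raw_cache_mode)

    # "k,v" with each side in 2-8 selects a quantized cache layer
    split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", raw_cache_mode)
    if split_cache_mode:
        return int(split_cache_mode.group(1)), int(split_cache_mode.group(2))

    # Anything else (e.g. "FP16") falls through to an unquantized cache
    return None

assert parse_cache_mode("Q6") == (6, 6)    # exl2 alias -> 6-bit keys/values
assert parse_cache_mode("3, 5") == (3, 5)  # mixed widths, whitespace allowed
assert parse_cache_mode("FP16") is None    # default full-precision cache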