Model: Move cache creation to a common function
Prevents repetitiveness while also creating a Cache class.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
parent 9dcde59c57
commit 638eef401a

1 changed file with 31 additions and 50 deletions
@@ -168,7 +168,7 @@ class ExllamaV3Container(BaseModelContainer):
             logger.info(f"Using draft model: {str(draft_model_path.resolve())}")
         else:
             self.draft_model = None
-            self.craft_cache = None
+            self.draft_cache = None
 
         # Turn off GPU split if the user is using 1 GPU
         gpu_count = torch.cuda.device_count()
@@ -222,61 +222,15 @@ class ExllamaV3Container(BaseModelContainer):
         user_cache_size = unwrap(kwargs.get("cache_size"), self.max_seq_len)
         self.cache_size = self.adjust_cache_size(user_cache_size)
         self.cache_mode = unwrap(kwargs.get("cache_mode"), "FP16")
-
-        # Alias Exl2 q-cache settings
-        match self.cache_mode:
-            case "Q4":
-                self.cache_mode = "4,4"
-            case "Q6":
-                self.cache_mode = "6,6"
-            case "Q8":
-                self.cache_mode = "8,8"
-
-        split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", self.cache_mode)
-        if split_cache_mode:
-            k_bits = int(split_cache_mode.group(1))
-            v_bits = int(split_cache_mode.group(2))
-            self.cache = Cache(
-                self.model,
-                max_num_tokens=self.cache_size,
-                layer_type=CacheLayer_quant,
-                k_bits=k_bits,
-                v_bits=v_bits,
-            )
-        else:
-            self.cache = Cache(self.model, max_num_tokens=self.cache_size)
+        self.cache = self.create_cache(self.cache_mode, self.model)
 
         # Draft cache
         if self.use_draft_model:
             # Set draft cache mode
             self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
-
-            # Alias Exl2 q-cache settings
-            match self.draft_cache_mode:
-                case "Q4":
-                    self.draft_cache_mode = "4,4"
-                case "Q6":
-                    self.draft_cache_mode = "6,6"
-                case "Q8":
-                    self.draft_cache_mode = "8,8"
-
-            split_draft_cache_mode = re.search(
-                r"^([2-8])\s*,\s*([2-8])$", self.draft_cache_mode
+            self.draft_cache = self.create_cache(
+                self.draft_cache_mode, self.draft_model
             )
-            if split_draft_cache_mode:
-                draft_k_bits = int(split_draft_cache_mode.group(1))
-                draft_v_bits = int(split_draft_cache_mode.group(2))
-                self.draft_cache = Cache(
-                    self.draft_model,
-                    max_num_tokens=self.cache_size,
-                    layer_type=CacheLayer_quant,
-                    k_bits=draft_k_bits,
-                    v_bits=draft_v_bits,
-                )
-            else:
-                self.draft_cache = Cache(
-                    self.draft_model, max_num_tokens=self.cache_size
-                )
 
         # Max batch size
         self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)
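For context between the two hunks: the cache-mode strings accepted here follow a small grammar, either an exl2-style alias (Q4/Q6/Q8) or an explicit "k,v" bit pair with each side in the 2-8 range. A minimal standalone sketch of that parsing, with a hypothetical name (parse_cache_bits is not part of this codebase, it just mirrors what create_cache does in the hunk below):

import re

def parse_cache_bits(raw_cache_mode: str):
    # Alias exl2-style quant names to explicit "k,v" bit pairs
    aliases = {"Q4": "4,4", "Q6": "6,6", "Q8": "8,8"}
    raw_cache_mode = aliases.get(raw_cache_mode, raw_cache_mode)

    # Accept "k,v" where each value is a digit from 2 to 8, spaces allowed
    split_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", raw_cache_mode)
    if split_mode:
        return int(split_mode.group(1)), int(split_mode.group(2))

    # Anything else (e.g. "FP16") falls through to an unquantized cache
    return None

assert parse_cache_bits("Q6") == (6, 6)
assert parse_cache_bits("4, 8") == (4, 8)
assert parse_cache_bits("FP16") is None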
@@ -355,6 +309,33 @@ class ExllamaV3Container(BaseModelContainer):
 
         return chunk_size
 
+    def create_cache(self, raw_cache_mode: str, model: Model):
+        # Cast exl2 types to exl3
+        match raw_cache_mode:
+            case "Q4":
+                raw_cache_mode = "4,4"
+            case "Q6":
+                raw_cache_mode = "6,6"
+            case "Q8":
+                raw_cache_mode = "8,8"
+
+        split_cache_mode = re.search(r"^([2-8])\s*,\s*([2-8])$", raw_cache_mode)
+
+        if split_cache_mode:
+            draft_k_bits = int(split_cache_mode.group(1))
+            draft_v_bits = int(split_cache_mode.group(2))
+            cache = Cache(
+                model,
+                max_num_tokens=self.cache_size,
+                layer_type=CacheLayer_quant,
+                k_bits=draft_k_bits,
+                v_bits=draft_v_bits,
+            )
+        else:
+            cache = Cache(model, max_num_tokens=self.cache_size)
+
+        return cache
+
     def model_info(self) -> ModelCard:
         """
         Returns a dictionary of the current model's configuration parameters.
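The new helper resolves to one of two Cache constructions. As a rough sketch of those two shapes, using only the constructor arguments visible in the diff; the import paths and the model/cache_size variables here are assumptions, not taken from this commit:

from exllamav3 import Cache  # import path assumed
from exllamav3.cache import CacheLayer_quant  # import path assumed

# Default ("FP16" or any unrecognized mode): full-precision cache
cache = Cache(model, max_num_tokens=cache_size)

# Quantized mode, e.g. "Q8" aliased to "8,8": 8-bit keys and values
cache = Cache(
    model,
    max_num_tokens=cache_size,
    layer_type=CacheLayer_quant,
    k_bits=8,
    v_bits=8,
)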