Model: Split cache creation into a common function
Unifies the switch statement across both draft and model caches.

Signed-off-by: kingbri <bdashore3@proton.me>
parent ecaddec48a
commit 5002617eac
1 changed file with 40 additions and 52 deletions
@@ -548,30 +548,11 @@ class ExllamaV2Container:
         if not self.quiet:
             logger.info("Loading draft model: " + self.draft_config.model_dir)
 
-        if self.draft_cache_mode == "Q4":
-            self.draft_cache = ExLlamaV2Cache_Q4(
-                self.draft_model,
-                max_seq_len=self.cache_size,
-                lazy=True,
-            )
-        elif self.draft_cache_mode == "Q6":
-            self.draft_cache = ExLlamaV2Cache_Q6(
-                self.draft_model,
-                max_seq_len=self.cache_size,
-                lazy=True,
-            )
-        elif self.draft_cache_mode == "Q8":
-            self.draft_cache = ExLlamaV2Cache_Q8(
-                self.draft_model,
-                max_seq_len=self.cache_size,
-                lazy=True,
-            )
-        else:
-            self.draft_cache = ExLlamaV2Cache(
-                self.draft_model,
-                max_seq_len=self.cache_size,
-                lazy=True,
-            )
+        self.draft_cache = self.create_cache(
+            cache_mode=self.draft_cache_mode,
+            autosplit=True,
+        )
+
         for value in self.draft_model.load_autosplit_gen(
             self.draft_cache,
             reserve_vram=autosplit_reserve,
@@ -601,34 +582,10 @@ class ExllamaV2Container:
             if value:
                 yield value
 
-        if self.cache_mode == "Q4":
-            self.cache = ExLlamaV2Cache_Q4(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
-        elif self.cache_mode == "Q6":
-            self.cache = ExLlamaV2Cache_Q6(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
-        elif self.cache_mode == "Q8":
-            self.cache = ExLlamaV2Cache_Q8(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
-        else:
-            self.cache = ExLlamaV2Cache(
-                self.model,
-                max_seq_len=self.cache_size,
-                lazy=self.gpu_split_auto,
-                batch_size=1,
-            )
+        self.cache = self.create_cache(
+            cache_mode=self.cache_mode,
+            autosplit=self.gpu_split_auto,
+        )
 
         # Load model with autosplit
         if self.gpu_split_auto:
@@ -647,6 +604,37 @@ class ExllamaV2Container:
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
         self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
 
+    def create_cache(self, cache_mode: str, autosplit: bool):
+        match cache_mode:
+            case "Q4":
+                return ExLlamaV2Cache_Q4(
+                    self.model,
+                    max_seq_len=self.cache_size,
+                    lazy=autosplit,
+                    batch_size=1,
+                )
+            case "Q6":
+                return ExLlamaV2Cache_Q6(
+                    self.model,
+                    max_seq_len=self.cache_size,
+                    lazy=autosplit,
+                    batch_size=1,
+                )
+            case "Q8":
+                return ExLlamaV2Cache_Q8(
+                    self.model,
+                    max_seq_len=self.cache_size,
+                    lazy=autosplit,
+                    batch_size=1,
+                )
+            case _:
+                return ExLlamaV2Cache(
+                    self.model,
+                    max_seq_len=self.cache_size,
+                    lazy=autosplit,
+                    batch_size=1,
+                )
+
     async def create_generator(self):
         try:
             # Don't acquire locks unless a model is loaded
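For readers who want to try the dispatch idea in isolation: below is a minimal, self-contained Python sketch (Python 3.10+ for match) of the same pattern, i.e. match on a cache-mode string, share the constructor arguments, and fall back to the unquantized cache for unknown modes. The FakeCache class and the standalone create_cache signature are stand-ins invented for this sketch; the real helper above is a container method returning exllamav2 cache objects.

# Minimal sketch of the mode -> cache dispatch; FakeCache is a stand-in, not tabbyAPI's API.
from dataclasses import dataclass


@dataclass
class FakeCache:
    bits: int          # 4/6/8 for the quantized modes, 16 for the default cache
    max_seq_len: int
    lazy: bool         # mirrors lazy=autosplit in the helper above
    batch_size: int = 1


def create_cache(cache_mode: str, max_seq_len: int, autosplit: bool) -> FakeCache:
    match cache_mode:
        case "Q4":
            return FakeCache(bits=4, max_seq_len=max_seq_len, lazy=autosplit)
        case "Q6":
            return FakeCache(bits=6, max_seq_len=max_seq_len, lazy=autosplit)
        case "Q8":
            return FakeCache(bits=8, max_seq_len=max_seq_len, lazy=autosplit)
        case _:
            # Unknown modes fall back to the full-precision cache, like the `case _` branch above.
            return FakeCache(bits=16, max_seq_len=max_seq_len, lazy=autosplit)


# Call sites mirror the diff: the draft cache always passes autosplit=True,
# while the main cache passes the container's gpu_split_auto flag.
print(create_cache("Q4", max_seq_len=4096, autosplit=True))
print(create_cache("FP16", max_seq_len=4096, autosplit=False))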