Model: Split cache creation into a common function

Unifies the duplicated cache-creation branching for the draft and model caches into a single helper.

Signed-off-by: kingbri <bdashore3@proton.me>
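
For context, the change below collapses two duplicated if/elif chains into one helper. The following is a minimal sketch of the same idea, not the commit's code: the ExLlamaV2Cache* classes are the real exllamav2 cache types used in the diff, but the standalone function, its parameters, and the dict lookup (in place of the diff's match statement) are illustrative; the commit's helper is a method that reads self.model and self.cache_size.

# Minimal sketch of the consolidation, assuming the ExLlamaV2Cache* classes
# are importable from the exllamav2 package as they are elsewhere in this file;
# the standalone function and the dict lookup are illustrative only.
from exllamav2 import (
    ExLlamaV2Cache,
    ExLlamaV2Cache_Q4,
    ExLlamaV2Cache_Q6,
    ExLlamaV2Cache_Q8,
)

# Quantized cache modes mapped to their classes; any other mode falls back
# to the full-precision ExLlamaV2Cache, mirroring the diff's "case _" branch.
CACHE_CLASSES = {
    "Q4": ExLlamaV2Cache_Q4,
    "Q6": ExLlamaV2Cache_Q6,
    "Q8": ExLlamaV2Cache_Q8,
}

def create_cache(model, cache_mode: str, cache_size: int, autosplit: bool):
    cache_cls = CACHE_CLASSES.get(cache_mode, ExLlamaV2Cache)
    return cache_cls(
        model,
        max_seq_len=cache_size,
        lazy=autosplit,  # defer VRAM allocation when the GPU split is automatic
        batch_size=1,
    )

# Both call sites then reduce to a single call, e.g.:
#   cache = create_cache(model, cache_mode, cache_size, autosplit=gpu_split_auto)

A lookup table keeps the constructor arguments in one place; the commit achieves the same consolidation with a match statement that repeats them per branch.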
kingbri 2024-08-16 15:17:03 -04:00 committed by Brian Dashore
parent ecaddec48a
commit 5002617eac

@@ -548,30 +548,11 @@ class ExllamaV2Container:
if not self.quiet:
logger.info("Loading draft model: " + self.draft_config.model_dir)
if self.draft_cache_mode == "Q4":
self.draft_cache = ExLlamaV2Cache_Q4(
self.draft_model,
max_seq_len=self.cache_size,
lazy=True,
)
elif self.draft_cache_mode == "Q6":
self.draft_cache = ExLlamaV2Cache_Q6(
self.draft_model,
max_seq_len=self.cache_size,
lazy=True,
)
elif self.draft_cache_mode == "Q8":
self.draft_cache = ExLlamaV2Cache_Q8(
self.draft_model,
max_seq_len=self.cache_size,
lazy=True,
)
else:
self.draft_cache = ExLlamaV2Cache(
self.draft_model,
max_seq_len=self.cache_size,
lazy=True,
)
self.draft_cache = self.create_cache(
cache_mode=self.draft_cache_mode,
autosplit=True,
)
for value in self.draft_model.load_autosplit_gen(
self.draft_cache,
reserve_vram=autosplit_reserve,
@@ -601,34 +582,10 @@ class ExllamaV2Container:
if value:
yield value
if self.cache_mode == "Q4":
self.cache = ExLlamaV2Cache_Q4(
self.model,
max_seq_len=self.cache_size,
lazy=self.gpu_split_auto,
batch_size=1,
)
elif self.cache_mode == "Q6":
self.cache = ExLlamaV2Cache_Q6(
self.model,
max_seq_len=self.cache_size,
lazy=self.gpu_split_auto,
batch_size=1,
)
elif self.cache_mode == "Q8":
self.cache = ExLlamaV2Cache_Q8(
self.model,
max_seq_len=self.cache_size,
lazy=self.gpu_split_auto,
batch_size=1,
)
else:
self.cache = ExLlamaV2Cache(
self.model,
max_seq_len=self.cache_size,
lazy=self.gpu_split_auto,
batch_size=1,
)
self.cache = self.create_cache(
cache_mode=self.cache_mode,
autosplit=self.gpu_split_auto,
)
# Load model with autosplit
if self.gpu_split_auto:
@@ -647,6 +604,37 @@ class ExllamaV2Container:
input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
self.model.forward(input_ids, cache=self.cache, preprocess_only=True)
def create_cache(self, cache_mode: str, autosplit: bool):
match cache_mode:
case "Q4":
return ExLlamaV2Cache_Q4(
self.model,
max_seq_len=self.cache_size,
lazy=autosplit,
batch_size=1,
)
case "Q6":
return ExLlamaV2Cache_Q6(
self.model,
max_seq_len=self.cache_size,
lazy=autosplit,
batch_size=1,
)
case "Q8":
return ExLlamaV2Cache_Q8(
self.model,
max_seq_len=self.cache_size,
lazy=autosplit,
batch_size=1,
)
case _:
return ExLlamaV2Cache(
self.model,
max_seq_len=self.cache_size,
lazy=autosplit,
batch_size=1,
)
async def create_generator(self):
try:
# Don't acquire locks unless a model is loaded