add draft_gpu_split option

Author: lucy, 2024-11-27 02:52:19 +01:00 (committed by GitHub)
Parent: aa4ccd03d4
Commit: ab1f4b7a6a


```diff
@@ -90,6 +90,7 @@ class ExllamaV2Container:
     # GPU split vars
     gpu_split: Optional[list] = None
+    draft_gpu_split: Optional[list] = None
     gpu_split_auto: bool = True
     autosplit_reserve: List[float] = [96 * 1024**2]
     use_tp: bool = False
```
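
For orientation, these split lists follow ExLlamaV2's `gpu_split` convention: one float per visible CUDA device, giving that device's VRAM budget in GB. A minimal sketch of plausible runtime values (the numbers are illustrative, not taken from this commit):

```python
# Illustrative values only; nothing in this commit sets these numbers.
gpu_split = [20.0, 24.0]            # main model: GB of VRAM per device
draft_gpu_split = [4.0, 0.0]        # draft model: device 1 excluded (0 GB)
gpu_split_auto = False              # a manual split turns autosplit off
autosplit_reserve = [96 * 1024**2]  # default: reserve 96 MiB (stored in bytes)
```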
```diff
@@ -180,6 +181,7 @@ class ExllamaV2Container:
             )
             draft_model_path = draft_model_path / draft_model_name
+            self.draft_gpu_split = draft_args.get("draft_gpu_split")
             self.draft_model_dir = draft_model_path
             self.draft_config.model_dir = str(draft_model_path.resolve())
             self.draft_config.prepare()
```
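
Because the value is pulled straight from `draft_args`, enabling the option is just a matter of adding a `draft_gpu_split` list next to the existing draft keys. A hypothetical example of the mapping as consumed above (the model name and numbers are made up; only the two key names are grounded in this diff):

```python
draft_args = {
    "draft_model_name": "tiny-draft-exl2",  # illustrative model name
    "draft_gpu_split": [4.0, 8.0],          # illustrative GB-per-device split
}

draft_gpu_split = draft_args.get("draft_gpu_split")
print(draft_gpu_split)  # [4.0, 8.0]; None if the key is absent
```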
```diff
@@ -232,6 +234,16 @@ class ExllamaV2Container:
                 for value in autosplit_reserve_megabytes
             ]
 
+        if self.draft_gpu_split:
+            self.gpu_split_auto = False
+            self.gpu_split = gpu_split
+
+            gpu_device_list = [
+                device_idx
+                for device_idx, memory in enumerate(self.draft_gpu_split)
+                if memory > 0
+            ]
+
         # Hardcode max output length to 16
         self.config.max_output_len = 16
```
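
The new block disables autosplit and derives the set of usable devices from the split: any device given a zero budget is dropped. The comprehension can be checked in isolation (values hypothetical):

```python
# Standalone check of the device-list logic from the hunk above.
draft_gpu_split = [0, 24, 20]  # hypothetical: leave device 0 untouched

gpu_device_list = [
    device_idx
    for device_idx, memory in enumerate(draft_gpu_split)
    if memory > 0
]

assert gpu_device_list == [1, 2]  # only devices with a nonzero allocation
```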
```diff
@@ -617,21 +629,37 @@ class ExllamaV2Container:
         # Draft uses the autosplit loader, so create a cache that reflects this
         draft_cache_class = self.get_cache_class(self.draft_cache_mode)
-        self.draft_cache = self.create_cache(
-            cache_class=draft_cache_class,
-            autosplit=True,
-            use_tp=False,
-            model=self.draft_model,
-        )
-
-        for value in self.draft_model.load_autosplit_gen(
-            self.draft_cache,
-            reserve_vram=autosplit_reserve,
-            last_id_only=True,
-            callback_gen=progress_callback,
-        ):
-            if value:
-                yield value
+        if self.draft_gpu_split:
+            for value in self.draft_model.load_gen(
+                self.draft_gpu_split,
+                callback_gen=progress_callback,
+            ):
+                if value:
+                    yield value
+
+            self.draft_cache = self.create_cache(
+                cache_class=draft_cache_class,
+                autosplit=False,
+                use_tp=False,
+                model=self.draft_model,
+            )
+        else:
+            self.draft_cache = self.create_cache(
+                cache_class=draft_cache_class,
+                autosplit=True,
+                use_tp=False,
+                model=self.draft_model,
+            )
+
+            for value in self.draft_model.load_autosplit_gen(
+                self.draft_cache,
+                reserve_vram=autosplit_reserve,
+                last_id_only=True,
+                callback_gen=progress_callback,
+            ):
+                if value:
+                    yield value
 
         # Test VRAM allocation with a full-length forward pass
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
```
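
Note the ordering difference between the two branches: with a manual split, `load_gen` places the weights on the requested devices first and the cache is created afterwards with `autosplit=False`, whereas the autosplit path must create the cache before `load_autosplit_gen` runs so it can be grown across devices alongside the layers. A runnable sketch of just that control flow, using stand-in objects instead of ExLlamaV2 (the method names mirror the diff; everything else is invented for illustration):

```python
# Stand-ins so the branch logic runs outside tabbyAPI/ExLlamaV2.
class FakeDraftModel:
    def load_gen(self, gpu_split, callback_gen=None):
        print(f"manual load across devices, split={gpu_split}")
        yield from ()  # the real generator yields progress values

    def load_autosplit_gen(self, cache, reserve_vram=None,
                           last_id_only=False, callback_gen=None):
        print(f"autosplit load, reserving {reserve_vram} bytes")
        yield from ()


def create_cache(autosplit, model):
    print(f"cache created, autosplit={autosplit}")
    return object()


draft_model = FakeDraftModel()
draft_gpu_split = [4.0, 8.0]  # set to None to exercise the autosplit path

if draft_gpu_split:
    # Weights first, cache second.
    for value in draft_model.load_gen(draft_gpu_split):
        pass
    draft_cache = create_cache(autosplit=False, model=draft_model)
else:
    # Cache first, so autosplit can distribute it while loading.
    draft_cache = create_cache(autosplit=True, model=draft_model)
    for value in draft_model.load_autosplit_gen(
        draft_cache, reserve_vram=[96 * 1024**2]
    ):
        pass
```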