Merge pull request #254 from lucyknada/main
add draft_gpu_split option for spec decoding
commit 2e491472d1

3 changed files with 59 additions and 16 deletions

backends/exllamav2/model.py
@@ -89,7 +89,8 @@ class ExllamaV2Container:
     generation_config: Optional[GenerationConfig] = None
 
     # GPU split vars
-    gpu_split: Optional[list] = None
+    gpu_split: List[float] = []
+    draft_gpu_split: List[float] = []
     gpu_split_auto: bool = True
     autosplit_reserve: List[float] = [96 * 1024**2]
     use_tp: bool = False
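A note on the attribute change above: the old `Optional[list] = None` sentinel is replaced by typed `List[float]` defaults, where an empty list now stands for "autosplit" and a non-empty list gives GBs of VRAM per GPU index. A minimal standalone sketch of that convention (illustrative names, not code from this repo):

```python
from typing import List


def describe_split(gpu_split: List[float]) -> str:
    """Empty split -> autosplit; otherwise entry i is GBs of VRAM for GPU i."""
    if not gpu_split:
        return "autosplit across all visible GPUs"
    return ", ".join(f"GPU {i}: {gb} GB" for i, gb in enumerate(gpu_split))


print(describe_split([]))           # autosplit across all visible GPUs
print(describe_split([20.0, 4.0]))  # GPU 0: 20.0 GB, GPU 1: 4.0 GB
```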
@@ -180,6 +181,7 @@ class ExllamaV2Container:
             )
             draft_model_path = draft_model_path / draft_model_name
 
+            self.draft_gpu_split = draft_args.get("draft_gpu_split")
             self.draft_model_dir = draft_model_path
             self.draft_config.model_dir = str(draft_model_path.resolve())
             self.draft_config.prepare()
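Worth noting for the line added above: `draft_args.get("draft_gpu_split")` returns `None` when the key is absent, even though the class attribute defaults to `[]`. The guards later in this diff test plain truthiness, so `None` and `[]` behave identically and both fall back to autosplit. A tiny standalone illustration, with a plain dict standing in for `draft_args`:

```python
draft_args = {}  # stand-in: config with no draft_gpu_split key

draft_gpu_split = draft_args.get("draft_gpu_split")
print(draft_gpu_split)  # None, not []

# Truthiness treats None and [] the same, so autosplit stays the default:
if draft_gpu_split:
    print("manual split")
else:
    print("autosplit")  # this branch runs
```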
@@ -232,6 +234,15 @@ class ExllamaV2Container:
             for value in autosplit_reserve_megabytes
         ]
 
+        # Change the GPU device list only if gpu_split's list is too small
+        # This allows for an uneven list specification
+        if self.draft_gpu_split and len(self.draft_gpu_split) > len(self.gpu_split):
+            gpu_device_list = [
+                device_idx
+                for device_idx, memory in enumerate(self.draft_gpu_split)
+                if memory > 0
+            ]
+
         # Hardcode max output length to 16
         self.config.max_output_len = 16
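The comprehension added above derives the device list from `draft_gpu_split` when it names more GPUs than the main `gpu_split`, skipping any device given 0 GB. The same logic run standalone, with example values:

```python
draft_gpu_split = [0.0, 20.5, 4.0]  # example: skip GPU 0, use GPUs 1 and 2

gpu_device_list = [
    device_idx
    for device_idx, memory in enumerate(draft_gpu_split)
    if memory > 0
]
print(gpu_device_list)  # [1, 2]
```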
@@ -375,6 +386,7 @@ class ExllamaV2Container:
             # Set draft cache mode
             self.draft_cache_mode = unwrap(draft_args.get("draft_cache_mode"), "FP16")
 
+            # Edit the draft config size
             if chunk_size:
                 self.draft_config.max_input_len = chunk_size
                 self.draft_config.max_attention_size = chunk_size**2
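For scale on the sizes set above: `max_attention_size` is the square of the chunk size, so a 2048-token chunk budgets 2048 * 2048 = 4,194,304 attention entries. A quick check:

```python
chunk_size = 2048  # example value
max_input_len = chunk_size
max_attention_size = chunk_size**2
print(max_input_len, max_attention_size)  # 2048 4194304
```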
@@ -619,21 +631,41 @@ class ExllamaV2Container:
 
             # Draft uses the autosplit loader, so create a cache that reflects this
             draft_cache_class = self.get_cache_class(self.draft_cache_mode)
-            self.draft_cache = self.create_cache(
-                cache_class=draft_cache_class,
-                autosplit=True,
-                use_tp=False,
-                model=self.draft_model,
-            )
 
-            for value in self.draft_model.load_autosplit_gen(
-                self.draft_cache,
-                reserve_vram=autosplit_reserve,
-                last_id_only=True,
-                callback_gen=progress_callback,
-            ):
-                if value:
-                    yield value
+            if self.draft_gpu_split:
+                logger.info("Loading with a manual GPU split (or a one GPU setup)")
+
+                for value in self.draft_model.load_gen(
+                    self.draft_gpu_split,
+                    callback_gen=progress_callback,
+                ):
+                    if value:
+                        yield value
+
+                self.draft_cache = self.create_cache(
+                    cache_class=draft_cache_class,
+                    autosplit=False,
+                    use_tp=False,
+                    model=self.draft_model,
+                )
+            else:
+                logger.info("Loading with autosplit")
+
+                self.draft_cache = self.create_cache(
+                    cache_class=draft_cache_class,
+                    autosplit=True,
+                    use_tp=False,
+                    model=self.draft_model,
+                )
+
+                for value in self.draft_model.load_autosplit_gen(
+                    self.draft_cache,
+                    reserve_vram=autosplit_reserve,
+                    last_id_only=True,
+                    callback_gen=progress_callback,
+                ):
+                    if value:
+                        yield value
 
         # Test VRAM allocation with a full-length forward pass
         input_ids = torch.zeros((1, self.config.max_input_len), dtype=torch.long)
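The hunk above splits the draft-model load into two paths: a non-empty `draft_gpu_split` uses the manual `load_gen` loader and creates the cache after the weights are placed, while the empty default keeps the original `load_autosplit_gen` flow, where the cache must exist before loading starts. A condensed, runnable sketch of that ordering with stand-in objects (these are not the ExLlamaV2 APIs):

```python
from typing import Iterator, List, Optional


class FakeDraftModel:
    """Stand-in draft model whose loaders yield progress values."""

    def load_gen(self, gpu_split: List[float]) -> Iterator[int]:
        yield from range(3)  # pretend per-module load progress

    def load_autosplit_gen(self, cache: object) -> Iterator[int]:
        yield from range(3)


def load_draft(model: FakeDraftModel, draft_gpu_split: List[float]) -> Iterator[int]:
    cache: Optional[object] = None
    if draft_gpu_split:
        # Manual split: weights first, then a non-autosplit cache.
        yield from model.load_gen(draft_gpu_split)
        cache = object()
    else:
        # Autosplit: the cache is created up front because the
        # autosplit loader needs it while placing weights.
        cache = object()
        yield from model.load_autosplit_gen(cache)


list(load_draft(FakeDraftModel(), [20.0, 4.0]))  # manual path
list(load_draft(FakeDraftModel(), []))           # autosplit path
```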
common/config_models.py
@@ -351,6 +351,13 @@ class DraftModelConfig(BaseConfigModel):
             f"Possible values: {str(CACHE_SIZES)[15:-1]}."
         ),
     )
+    draft_gpu_split: List[float] = Field(
+        default_factory=list,
+        description=(
+            "An integer array of GBs of VRAM to split between GPUs (default: []).\n"
+            "If this isn't filled in, the draft model is autosplit."
+        ),
+    )
 
 
 class LoraInstanceModel(BaseConfigModel):
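The field above uses `default_factory=list` rather than a literal `[]`, so every config instance gets its own list. Also, despite the "integer array" wording in the description, the `List[float]` annotation accepts floats (ints are coerced). A minimal pydantic sketch of the same pattern, standalone rather than the project's full model:

```python
from typing import List

from pydantic import BaseModel, Field


class DraftSplitConfig(BaseModel):
    """Sketch of just the field added in this diff."""

    draft_gpu_split: List[float] = Field(
        default_factory=list,
        description="GBs of VRAM per GPU; empty means autosplit.",
    )


a = DraftSplitConfig()
b = DraftSplitConfig(draft_gpu_split=[20, 4.5])  # ints coerce to floats
print(a.draft_gpu_split)  # []
print(b.draft_gpu_split)  # [20.0, 4.5]
assert a.draft_gpu_split is not b.draft_gpu_split  # no shared mutable default
```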
config_sample.yml
@@ -20,7 +20,7 @@ network:
   # Turn on this option if you are ONLY connecting from localhost.
   disable_auth: false
 
-  # Disable fetching external content in response to requests, such as images from URLs.
+  # Disable fetching external content in response to requests,such as images from URLs.
   disable_fetch_requests: false
 
   # Send tracebacks over the API (default: False).
@@ -166,6 +166,10 @@ draft_model:
   # Possible values: 'FP16', 'Q8', 'Q6', 'Q4'.
   draft_cache_mode: FP16
 
+  # An integer array of GBs of VRAM to split between GPUs (default: []).
+  # If this isn't filled in, the draft model is autosplit.
+  draft_gpu_split: []
+
 # Options for Loras
 lora:
   # Directory to look for LoRAs (default: loras).
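To opt in, an operator replaces the empty list with per-GPU GB values; leaving it as `[]` keeps autosplit. A quick PyYAML round-trip showing how such a section parses (example values only):

```python
import yaml  # PyYAML

sample = """
draft_model:
  draft_cache_mode: FP16
  # 20 GB on GPU 0, 4 GB on GPU 1; leave as [] to autosplit.
  draft_gpu_split: [20, 4]
"""

config = yaml.safe_load(sample)
draft_gpu_split = config["draft_model"].get("draft_gpu_split") or []
print(draft_gpu_split)  # [20, 4] -> manual split; [] would mean autosplit
```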