Model + API: GPU split updates and fixes

For the TP loader, GPU split cannot be an empty array. However,
defaulting the parameter to an empty array makes it easier to calculate
the device list. Therefore, cast an empty array to None using
falsy comparisons at load time.

Also add draft_gpu_split to the load request.

Signed-off-by: kingbri <8082010+kingbri1@users.noreply.github.com>
This commit is contained in:
kingbri 2025-02-15 21:50:14 -05:00
parent 304df16543
commit 9f649647f0
2 changed files with 10 additions and 5 deletions

View file

@ -181,7 +181,7 @@ class ExllamaV2Container:
)
draft_model_path = draft_model_path / draft_model_name
self.draft_gpu_split = draft_args.get("draft_gpu_split")
self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), [])
self.draft_model_dir = draft_model_path
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
@ -195,7 +195,7 @@ class ExllamaV2Container:
gpu_count = torch.cuda.device_count()
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
use_tp = unwrap(kwargs.get("tensor_parallel"), False)
gpu_split = kwargs.get("gpu_split")
gpu_split = unwrap(kwargs.get("gpu_split"), [])
gpu_device_list = list(range(0, gpu_count))
# Set GPU split options
@ -691,8 +691,10 @@ class ExllamaV2Container:
if self.use_tp:
logger.info("Loading with tensor parallel")
# GPU split must be None if the array is empty
# Otherwise the TP loader fails
for value in self.model.load_tp_gen(
self.gpu_split,
self.gpu_split or None,
callback_gen=progress_callback,
expect_cache_base=cache_class,
expect_cache_tokens=self.cache_size,

View file

@ -63,7 +63,10 @@ class DraftModelLoadRequest(BaseModel):
default=None,
examples=[1.0],
)
draft_cache_mode: Optional[str] = None
draft_gpu_split: Optional[List[float]] = Field(
default_factory=list,
examples=[[24.0, 20.0]],
)
class ModelLoadRequest(BaseModel):
@ -94,7 +97,7 @@ class ModelLoadRequest(BaseModel):
gpu_split_auto: Optional[bool] = None
autosplit_reserve: Optional[List[float]] = None
gpu_split: Optional[List[float]] = Field(
default=None,
default_factory=list,
examples=[[24.0, 20.0]],
)
rope_scale: Optional[float] = Field(