fixup: auto split

parent 14fb573371
commit acb3adb953

1 changed file with 36 additions and 6 deletions
@@ -1,5 +1,6 @@
 import asyncio
 import gc
+import math
 import pathlib
 from loguru import logger
 from typing import (
@@ -46,8 +47,11 @@ class ExllamaV3Container(BaseModelContainer):
     cache: Cache
     tokenizer: Tokenizer
     config: Config
-    gpu_split: List[float] = []
+    gpu_split: List[float] | None = None
+    gpu_split_auto: bool = True
+    autosplit_reserve: List[float] = [96 * 1024**2]
     max_seq_len: int
+    use_tp: bool = False
 
     # Required methods
     @classmethod
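For context on the new `autosplit_reserve` default (`[96 * 1024**2]`): it is a single entry of 96 MiB expressed in bytes, and the same MiB-to-bytes conversion reappears in the next hunk via `math.ceil`. A minimal sketch of that arithmetic, using an illustrative helper name that is not part of the commit:

import math

def reserve_bytes(megabytes: list[float]) -> list[int]:
    """Convert per-GPU reserve values from MiB to whole bytes."""
    return [int(math.ceil(value * 1024**2)) for value in megabytes]

# The commit's default of one 96 MiB entry becomes 100663296 bytes.
assert reserve_bytes([96]) == [96 * 1024**2]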
@@ -76,17 +80,43 @@ class ExllamaV3Container(BaseModelContainer):
 
         self.max_seq_len = kwargs.get("max_seq_len")
         self.cache = Cache(self.model, max_num_tokens=self.max_seq_len)
-        gpu_split = unwrap(kwargs.get("gpu_split"), [])
 
-        # Set GPU split options
-        # Enable manual GPU split if provided
-        if gpu_split:
-            self.gpu_split = gpu_split
         # Try to set prompt template
         self.prompt_template = await find_prompt_template(
             kwargs.get("prompt_template"), model_directory
         )
 
+        # Turn off GPU split if the user is using 1 GPU
+        gpu_count = torch.cuda.device_count()
+        gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
+        gpu_split = unwrap(kwargs.get("gpu_split"), None)
+
+        # Set GPU split options
+        if gpu_count == 1:
+            self.gpu_split_auto = False
+            logger.info("Disabling GPU split because one GPU is in use.")
+        else:
+            # TODO: Set tensor parallel
+
+            # Set GPU split options
+            # Enable manual GPU split if provided
+            if gpu_split:
+                self.gpu_split = gpu_split
+            elif gpu_split_auto and not self.use_tp:
+                # Otherwise fallback to autosplit settings
+                self.gpu_split_auto = gpu_split_auto
+
+                autosplit_reserve_megabytes = unwrap(
+                    kwargs.get("autosplit_reserve"), [96]
+                )
+
+                # Reserve VRAM for each GPU
+                self.autosplit_reserve = [
+                    int(math.ceil(value * 1024**2))
+                    for value in autosplit_reserve_megabytes
+                ]
+        # TODO: speculative decoding
+
         return self
 
     async def load(self, progress_callback=None, **kwargs):
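Taken together, the hunk above picks one of three modes: no split when a single GPU is visible, a manual split when `gpu_split` is provided, otherwise autosplit with a per-GPU VRAM reserve. A simplified standalone sketch of that precedence, assuming hypothetical names and ignoring the tensor-parallel TODO (none of this is part of the commit):

import math

def resolve_split(
    gpu_count: int,
    gpu_split: list[float] | None = None,
    gpu_split_auto: bool = True,
    use_tp: bool = False,
    autosplit_reserve_mb: list[float] | None = None,
) -> dict:
    # One visible GPU: splitting is pointless, so disable it entirely
    if gpu_count == 1:
        return {"gpu_split": None, "auto": False, "reserve": []}

    # A manual split list takes priority over autosplit
    if gpu_split:
        return {"gpu_split": gpu_split, "auto": False, "reserve": []}

    # Otherwise fall back to autosplit, reserving VRAM on each GPU
    if gpu_split_auto and not use_tp:
        reserve_mb = autosplit_reserve_mb or [96]
        reserve = [int(math.ceil(mb * 1024**2)) for mb in reserve_mb]
        return {"gpu_split": None, "auto": True, "reserve": reserve}

    return {"gpu_split": None, "auto": False, "reserve": []}

With two GPUs and no manual list, this returns autosplit mode with the default 96 MiB reserve; the in-class version in the diff differs mainly in storing the results on `self` rather than returning them.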