Model: Add draft model/speculative decoding

turboderp 2025-05-04 01:27:42 +02:00
parent 1db2cb99cb
commit 92ea7ee7cd


@@ -58,10 +58,11 @@ class ExllamaV3Container(BaseModelContainer):
# Exl3 vars
model: Optional[Model]
cache: Optional[Cache]
draft_model: Optional[Model]
draft_cache: Optional[Cache]
tokenizer: Optional[Tokenizer]
config: Optional[Config]
generator: Optional[AsyncGenerator] = None
tokenizer_config: Optional[TokenizerConfig] = None
draft_config: Optional[Config]
generator: Optional[AsyncGenerator]
tokenizer_config: Optional[TokenizerConfig]
@@ -93,8 +94,11 @@ class ExllamaV3Container(BaseModelContainer):
self.model = None
self.cache = None
self.draft_model = None
self.draft_cache = None
self.tokenizer = None
self.config = None
self.draft_config = None
self.generator = None
self.tokenizer_config = None
@@ -135,6 +139,35 @@ class ExllamaV3Container(BaseModelContainer):
# Fallback to 4096 since exl3 can't fetch from HF's config.json
self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
# Prepare the draft model config if necessary
draft_args = unwrap(kwargs.get("draft_model"), {})
draft_model_name = draft_args.get("draft_model_name")
self.use_draft_model = draft_args and draft_model_name
# Always disable draft if params are incorrectly configured
if draft_args and draft_model_name is None:
logger.warning(
"Draft model is disabled because a model name "
"wasn't provided. Please check your config.yml!"
)
self.use_draft_model = False
if self.use_draft_model:
draft_model_path = pathlib.Path(
unwrap(draft_args.get("draft_model_dir"), "models")
)
draft_model_path = draft_model_path / draft_model_name
self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), [])
self.draft_model_dir = draft_model_path
self.draft_config = Config.from_directory(str(draft_model_path.resolve()))
self.draft_model = Model.from_config(self.draft_config)
logger.info(f"Using draft model: {draft_model_path.resolve()}")
else:
self.draft_model = None
self.draft_cache = None
# Turn off GPU split if the user is using 1 GPU
gpu_count = torch.cuda.device_count()
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
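For reference, a minimal sketch of the kwargs shape the draft-model block above expects. Only the key names (draft_model, draft_model_name, draft_model_dir, draft_gpu_split, max_seq_len) come from the code; the values are illustrative placeholders, not recommendations.

# Hypothetical kwargs handed to the container; key names are taken from the
# diff above, values are placeholders.
kwargs = {
    "draft_model": {
        "draft_model_name": "my-draft-model-exl3",  # omitting this disables drafting with a warning
        "draft_model_dir": "models",                # unwrap() falls back to "models"
        "draft_gpu_split": [],                      # empty list -> no manual split
    },
    "max_seq_len": 4096,
}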
@@ -188,6 +221,10 @@ class ExllamaV3Container(BaseModelContainer):
self.cache_size = self.adjust_cache_size(user_cache_size)
self.cache = Cache(self.model, max_num_tokens=self.cache_size)
# Draft cache
if self.use_draft_model:
self.draft_cache = Cache(self.draft_model, max_num_tokens=self.cache_size)
# Max batch size
self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)
@@ -211,8 +248,6 @@ class ExllamaV3Container(BaseModelContainer):
"template wasn't provided or auto-detected."
)
# TODO: speculative decoding
return self
def adjust_cache_size(self, cache_size):
@@ -371,9 +406,16 @@ class ExllamaV3Container(BaseModelContainer):
async with self.load_condition:
self.load_condition.notify_all()
# TODO: Add draft loading
@torch.inference_mode()
def load_model_sync(self, progress_callback=None):
if self.use_draft_model:
for value in self.draft_model.load_gen(
reserve_per_device=self.autosplit_reserve,
callback=progress_callback,
):
if value:
yield value
for value in self.model.load_gen(
reserve_per_device=self.autosplit_reserve,
use_per_device=self.gpu_split,
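Note the load order above: the draft model is loaded first with only the per-device autosplit reserve, then the main model is loaded with the reserve plus any explicit gpu_split. Below is a hedged sketch of how a caller might drive this generator; the container instance and the callback signature are assumptions, only the progress_callback keyword comes from the code.

# Hypothetical driver for load_model_sync(); the callback signature is an
# assumption for illustration, not taken from the diff.
def on_progress(module, total_modules):
    print(f"loaded module {module}/{total_modules}")

for _ in container.load_model_sync(progress_callback=on_progress):
    pass  # truthy progress values from load_gen() are forwarded here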
@@ -397,6 +439,8 @@ class ExllamaV3Container(BaseModelContainer):
self.generator = AsyncGenerator(
model=self.model,
cache=self.cache,
draft_model=self.draft_model,
draft_cache=self.draft_cache,
tokenizer=self.tokenizer,
max_batch_size=self.max_batch_size,
max_chunk_size=self.chunk_size,
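Taken together, speculative decoding is enabled by handing the generator a second, smaller model and its own cache. Below is a condensed sketch of the same wiring outside the container, assuming exllamav3's top-level exports, Tokenizer.from_config, and a blocking load() call (the diff itself only shows the streaming load_gen() form); paths and cache size are placeholders.

import asyncio

from exllamav3 import AsyncGenerator, Cache, Config, Model, Tokenizer

async def main():
    # Main model, mirroring the Config/Model/Cache calls in the diff
    config = Config.from_directory("models/big-model-exl3")
    model = Model.from_config(config)
    cache = Cache(model, max_num_tokens=16384)

    # Draft model gets its own config and a cache sized like the main one
    draft_config = Config.from_directory("models/small-draft-exl3")
    draft_model = Model.from_config(draft_config)
    draft_cache = Cache(draft_model, max_num_tokens=16384)

    tokenizer = Tokenizer.from_config(config)  # assumed constructor

    draft_model.load()  # blocking load(); the container streams load_gen() for progress
    model.load()

    # Passing draft_model/draft_cache is what turns on speculative decoding
    generator = AsyncGenerator(
        model=model,
        cache=cache,
        draft_model=draft_model,
        draft_cache=draft_cache,
        tokenizer=tokenizer,
    )
    await generator.close()

asyncio.run(main())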
@@ -435,11 +479,16 @@ class ExllamaV3Container(BaseModelContainer):
self.model.unload()
self.model = None
self.config = None
self.cache = None
self.tokenizer = None
if self.use_draft_model:
self.draft_model.unload()
self.draft_model = None
self.draft_config = None
self.draft_cache = None
# Cleanup the generator from any pending jobs
if self.generator is not None:
await self.generator.close()