diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py
index ca8412f..fc10a3d 100644
--- a/backends/exllamav3/model.py
+++ b/backends/exllamav3/model.py
@@ -58,10 +58,11 @@ class ExllamaV3Container(BaseModelContainer):
     # Exl3 vars
     model: Optional[Model]
     cache: Optional[Cache]
+    draft_model: Optional[Model]
+    draft_cache: Optional[Cache]
     tokenizer: Optional[Tokenizer]
     config: Optional[Config]
-    generator: Optional[AsyncGenerator] = None
-    tokenizer_config: Optional[TokenizerConfig] = None
+    draft_config: Optional[Config]
     generator: Optional[AsyncGenerator]
     tokenizer_config: Optional[TokenizerConfig]
 
@@ -93,8 +94,11 @@ class ExllamaV3Container(BaseModelContainer):
 
         self.model = None
         self.cache = None
+        self.draft_model = None
+        self.draft_cache = None
         self.tokenizer = None
         self.config = None
+        self.draft_config = None
         self.generator = None
         self.tokenizer_config = None
 
@@ -135,6 +139,35 @@ class ExllamaV3Container(BaseModelContainer):
         # Fallback to 4096 since exl3 can't fetch from HF's config.json
         self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)
 
+        # Prepare the draft model config if necessary
+        draft_args = unwrap(kwargs.get("draft_model"), {})
+        draft_model_name = draft_args.get("draft_model_name")
+        self.use_draft_model = draft_args and draft_model_name
+
+        # Always disable draft if params are incorrectly configured
+        if draft_args and draft_model_name is None:
+            logger.warning(
+                "Draft model is disabled because a model name "
+                "wasn't provided. Please check your config.yml!"
+            )
+            self.use_draft_model = False
+
+        if self.use_draft_model:
+            draft_model_path = pathlib.Path(
+                unwrap(draft_args.get("draft_model_dir"), "models")
+            )
+            draft_model_path = draft_model_path / draft_model_name
+            self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), [])
+            self.draft_model_dir = draft_model_path
+            self.draft_config = Config.from_directory(str(draft_model_path.resolve()))
+            self.draft_model = Model.from_config(self.draft_config)
+            logger.info(
+                f"Using draft model: {str(draft_model_path.resolve())}"
+            )
+        else:
+            self.draft_model = None
+            self.draft_cache = None
+
         # Turn off GPU split if the user is using 1 GPU
         gpu_count = torch.cuda.device_count()
         gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
@@ -188,6 +221,10 @@ class ExllamaV3Container(BaseModelContainer):
         self.cache_size = self.adjust_cache_size(user_cache_size)
         self.cache = Cache(self.model, max_num_tokens=self.cache_size)
 
+        # Draft cache
+        if self.use_draft_model:
+            self.draft_cache = Cache(self.draft_model, max_num_tokens=self.cache_size)
+
         # Max batch size
         self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)
 
@@ -211,8 +248,6 @@ class ExllamaV3Container(BaseModelContainer):
                 "template wasn't provided or auto-detected."
             )
 
-        # TODO: speculative decoding
-
         return self
 
     def adjust_cache_size(self, cache_size):
@@ -371,9 +406,16 @@ class ExllamaV3Container(BaseModelContainer):
         async with self.load_condition:
             self.load_condition.notify_all()
 
-    # TODO: Add draft loading
     @torch.inference_mode()
     def load_model_sync(self, progress_callback=None):
+        if self.use_draft_model:
+            for value in self.draft_model.load_gen(
+                reserve_per_device=self.autosplit_reserve,
+                callback=progress_callback,
+            ):
+                if value:
+                    yield value
+
         for value in self.model.load_gen(
             reserve_per_device=self.autosplit_reserve,
             use_per_device=self.gpu_split,
@@ -397,6 +439,8 @@ class ExllamaV3Container(BaseModelContainer):
         self.generator = AsyncGenerator(
             model=self.model,
             cache=self.cache,
+            draft_model=self.draft_model,
+            draft_cache=self.draft_cache,
             tokenizer=self.tokenizer,
             max_batch_size=self.max_batch_size,
             max_chunk_size=self.chunk_size,
@@ -435,11 +479,16 @@ class ExllamaV3Container(BaseModelContainer):
 
         self.model.unload()
         self.model = None
-        self.config = None
         self.cache = None
         self.tokenizer = None
 
+        if self.use_draft_model:
+            self.draft_model.unload()
+            self.draft_model = None
+            self.draft_config = None
+            self.draft_cache = None
+
         # Cleanup the generator from any pending jobs
         if self.generator is not None:
            await self.generator.close()
 
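Note on configuration: below is a minimal, illustrative sketch of the draft_model options the new code path reads from kwargs. The key names and fallback values mirror what the diff consumes via draft_args.get() and unwrap(); the dict itself and the model name are hypothetical placeholders, not shipped defaults.

    # Hypothetical kwargs fragment for exercising the draft-model path above
    draft_kwargs = {
        "draft_model": {
            # Required: without this key the container logs a warning and
            # sets use_draft_model = False
            "draft_model_name": "SomeDraftModel-exl3",
            # Optional: falls back to "models" when omitted
            "draft_model_dir": "models",
            # Optional: falls back to [] (no manual GPU split)
            "draft_gpu_split": [],
        },
    }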