Model: Add draft model/speculative decoding
parent 1db2cb99cb
commit 92ea7ee7cd

1 changed file with 55 additions and 6 deletions
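For reference, the draft-model settings this commit reads come from the "draft_model" block of the loader kwargs (the draft_model section of config.yml mentioned in the warning below). A minimal sketch of that block, using only the keys the diff actually reads; the model name is hypothetical and the defaults shown are the ones the code falls back to:

draft_model_kwargs = {
    "draft_model": {
        "draft_model_name": "tiny-draft-exl3",  # hypothetical folder name under draft_model_dir
        "draft_model_dir": "models",            # falls back to "models" when omitted
        "draft_gpu_split": [],                  # falls back to [] (no manual split)
    }
}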
@@ -58,10 +58,11 @@ class ExllamaV3Container(BaseModelContainer):
    # Exl3 vars
    model: Optional[Model]
    cache: Optional[Cache]
    draft_model: Optional[Model]
    draft_cache: Optional[Cache]
    tokenizer: Optional[Tokenizer]
    config: Optional[Config]
    generator: Optional[AsyncGenerator] = None
    tokenizer_config: Optional[TokenizerConfig] = None
    draft_config: Optional[Config]
    generator: Optional[AsyncGenerator]
    tokenizer_config: Optional[TokenizerConfig]
@@ -93,8 +94,11 @@ class ExllamaV3Container(BaseModelContainer):

        self.model = None
        self.cache = None
        self.draft_model = None
        self.draft_cache = None
        self.tokenizer = None
        self.config = None
        self.draft_config = None
        self.generator = None
        self.tokenizer_config = None
@@ -135,6 +139,35 @@ class ExllamaV3Container(BaseModelContainer):
        # Fallback to 4096 since exl3 can't fetch from HF's config.json
        self.max_seq_len = unwrap(kwargs.get("max_seq_len"), 4096)

        # Prepare the draft model config if necessary
        draft_args = unwrap(kwargs.get("draft_model"), {})
        draft_model_name = draft_args.get("draft_model_name")
        self.use_draft_model = draft_args and draft_model_name

        # Always disable draft if params are incorrectly configured
        if draft_args and draft_model_name is None:
            logger.warning(
                "Draft model is disabled because a model name "
                "wasn't provided. Please check your config.yml!"
            )
            self.use_draft_model = False

        if self.use_draft_model:
            draft_model_path = pathlib.Path(
                unwrap(draft_args.get("draft_model_dir"), "models")
            )
            draft_model_path = draft_model_path / draft_model_name
            self.draft_gpu_split = unwrap(draft_args.get("draft_gpu_split"), [])
            self.draft_model_dir = draft_model_path
            self.draft_config = Config.from_directory(str(draft_model_path.resolve()))
            self.draft_model = Model.from_config(self.draft_config)
            logger.info(
                f"Using draft model: {str(draft_model_path.resolve())}"
            )
        else:
            self.draft_model = None
            self.draft_cache = None

        # Turn off GPU split if the user is using 1 GPU
        gpu_count = torch.cuda.device_count()
        gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
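The guard above turns speculative decoding off when a draft_model block is present but draft_model_name is missing, and otherwise resolves the draft folder as draft_model_dir joined with draft_model_name. A standalone sketch of that path resolution, written as a hypothetical helper purely for illustration (the container keeps this logic inline):

import pathlib
from typing import Optional

def resolve_draft_model_path(draft_args: dict) -> Optional[pathlib.Path]:
    # Hypothetical helper mirroring the inline guard in this hunk.
    draft_model_name = draft_args.get("draft_model_name")
    if not (draft_args and draft_model_name):
        # Misconfigured or absent draft_model block: speculative decoding stays disabled.
        return None
    draft_model_dir = draft_args.get("draft_model_dir") or "models"
    return pathlib.Path(draft_model_dir) / draft_model_name

# resolve_draft_model_path({"draft_model_name": "tiny-draft-exl3"})
# -> pathlib.Path("models") / "tiny-draft-exl3"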
@@ -188,6 +221,10 @@ class ExllamaV3Container(BaseModelContainer):
        self.cache_size = self.adjust_cache_size(user_cache_size)
        self.cache = Cache(self.model, max_num_tokens=self.cache_size)

        # Draft cache
        if self.use_draft_model:
            self.draft_cache = Cache(self.draft_model, max_num_tokens=self.cache_size)

        # Max batch size
        self.max_batch_size = unwrap(kwargs.get("max_batch_size"), 256)
@@ -211,8 +248,6 @@ class ExllamaV3Container(BaseModelContainer):
                "template wasn't provided or auto-detected."
            )

        # TODO: speculative decoding

        return self

    def adjust_cache_size(self, cache_size):
@@ -371,9 +406,16 @@ class ExllamaV3Container(BaseModelContainer):
        async with self.load_condition:
            self.load_condition.notify_all()

    # TODO: Add draft loading
    @torch.inference_mode()
    def load_model_sync(self, progress_callback=None):
        if self.use_draft_model:
            for value in self.draft_model.load_gen(
                reserve_per_device=self.autosplit_reserve,
                callback=progress_callback,
            ):
                if value:
                    yield value

        for value in self.model.load_gen(
            reserve_per_device=self.autosplit_reserve,
            use_per_device=self.gpu_split,
@@ -397,6 +439,8 @@ class ExllamaV3Container(BaseModelContainer):
        self.generator = AsyncGenerator(
            model=self.model,
            cache=self.cache,
            draft_model=self.draft_model,
            draft_cache=self.draft_cache,
            tokenizer=self.tokenizer,
            max_batch_size=self.max_batch_size,
            max_chunk_size=self.chunk_size,
@@ -435,11 +479,16 @@ class ExllamaV3Container(BaseModelContainer):

        self.model.unload()
        self.model = None

        self.config = None
        self.cache = None
        self.tokenizer = None

        if self.use_draft_model:
            self.draft_model.unload()
            self.draft_model = None
            self.draft_config = None
            self.draft_cache = None

        # Clean up any pending jobs in the generator
        if self.generator is not None:
            await self.generator.close()
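Taken together, the pieces this commit touches wire up roughly as in the sketch below. This is an outline under stated assumptions, not code from the repository: the exllamav3 import path, the Tokenizer.from_config call, the no-argument load_gen() calls, and the chunk size are assumptions, while Config.from_directory, Model.from_config, Cache(model, max_num_tokens=...), the AsyncGenerator keyword arguments, generator.close(), and unload() mirror the calls used in the diff:

import asyncio

# Assumed import path; the diff does not show the file's imports.
from exllamav3 import AsyncGenerator, Cache, Config, Model, Tokenizer

async def run_speculative(model_dir: str, draft_dir: str, cache_size: int = 4096):
    # Main model, loaded the same way the container loads it
    config = Config.from_directory(model_dir)
    model = Model.from_config(config)
    tokenizer = Tokenizer.from_config(config)  # assumption: tokenizer loading is outside this diff

    # Draft model, mirroring the draft_config/draft_model hunk above
    draft_config = Config.from_directory(draft_dir)
    draft_model = Model.from_config(draft_config)

    # Both caches share the same max_num_tokens, as in the draft-cache hunk
    cache = Cache(model, max_num_tokens=cache_size)
    draft_cache = Cache(draft_model, max_num_tokens=cache_size)

    # load_gen yields progress values; the container forwards them to a callback.
    # Assumption: calling it without arguments uses the default device split.
    for _ in draft_model.load_gen():
        pass
    for _ in model.load_gen():
        pass

    generator = AsyncGenerator(
        model=model,
        cache=cache,
        draft_model=draft_model,
        draft_cache=draft_cache,
        tokenizer=tokenizer,
        max_batch_size=256,   # same default the container uses
        max_chunk_size=2048,  # hypothetical; the container passes self.chunk_size
    )

    # ... generate with the speculative generator here ...

    await generator.close()
    draft_model.unload()
    model.unload()

# asyncio.run(run_speculative("models/Main-Model-exl3", "models/Tiny-Draft-exl3"))  # hypothetical paths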