Merge pull request #244 from DocShotgun/draft-flash-attn-fix

Fix draft model non-FA2 fallback
This commit is contained in:
Brian 2024-11-16 21:23:42 -05:00 committed by GitHub
commit dfc889952a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -159,7 +159,6 @@ class ExllamaV2Container:
if enable_draft:
self.draft_config = ExLlamaV2Config()
self.draft_config.no_flash_attn = self.config.no_flash_attn
draft_model_path = pathlib.Path(
unwrap(draft_args.get("draft_model_dir"), "models")
)
@@ -253,6 +252,8 @@ class ExllamaV2Container:
or not supports_paged_attn()
):
self.config.no_flash_attn = True
if self.draft_config:
self.draft_config.no_flash_attn = True
self.paged = False
self.max_batch_size = 1
torch.backends.cuda.enable_flash_sdp(False)