Add automatic NTK-aware alpha scaling to model
* Enables automatic calculation of NTK-aware alpha scaling for models when the rope_alpha argument is not passed in the config, using the same formula already used for draft models.
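For reference, the calculation this commit adds can be sketched on its own. The helper name and the example sequence lengths below are illustrative only, not part of the codebase:

# Standalone sketch of the NTK-aware alpha formula applied in this commit.
# base_seq_len is the model's native context length from its config;
# target_seq_len is the requested max_seq_len (names here are illustrative).
def auto_ntk_alpha(base_seq_len: int, target_seq_len: int) -> float:
    ratio = target_seq_len / base_seq_len
    return -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2

# Example: stretching a 4096-token model to 8192 tokens (ratio = 2.0)
print(auto_ntk_alpha(4096, 8192))  # ~2.63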
parent 61f6e51fdb
commit 1c398b0be7
1 changed file with 9 additions and 2 deletions
model.py | 11 +++++++++--
@@ -69,10 +69,17 @@ class ModelContainer:
         self.config = ExLlamaV2Config()
         self.config.model_dir = str(model_directory.resolve())
         self.config.prepare()
 
+        base_seq_len = self.config.max_seq_len
+
         if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"]
         if "rope_scale" in kwargs: self.config.scale_pos_emb = kwargs["rope_scale"]
-        if "rope_alpha" in kwargs: self.config.scale_alpha_value = kwargs["rope_alpha"]
+        if "rope_alpha" in kwargs:
+            self.config.scale_alpha_value = kwargs["rope_alpha"]
+        else:
+            ratio = self.config.max_seq_len / base_seq_len
+            alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
+            self.config.scale_alpha_value = alpha
         if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"]
 
         if "low_mem" in kwargs and kwargs["low_mem"]:
@@ -102,7 +109,7 @@ class ModelContainer:
         self.draft_config.prepare()
 
         if "draft_rope_alpha" in kwargs:
-            self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or 1
+            self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha")
         else:
             ratio = self.config.max_seq_len / self.draft_config.max_seq_len
             alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
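As a rough illustration of the draft-model branch above, which derives alpha from how far the main model's max_seq_len exceeds the draft model's native context (the concrete sequence lengths here are hypothetical, not taken from the repo):

# Hypothetical values: draft model with a native 4096 context serving a main
# model whose max_seq_len has been raised to 16384.
ratio = 16384 / 4096                                       # 4.0
alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
print(round(alpha, 2))                                     # 7.7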