diff --git a/model.py b/model.py
index 564965b..25bfb25 100644
--- a/model.py
+++ b/model.py
@@ -69,20 +69,18 @@ class ModelContainer:
         self.config = ExLlamaV2Config()
         self.config.model_dir = str(model_directory.resolve())
         self.config.prepare()
-
+
+        # Grab the base model's sequence length before overrides for rope calculations
         base_seq_len = self.config.max_seq_len
+
+        # Then override the max_seq_len if present
         if "max_seq_len" in kwargs: self.config.max_seq_len = kwargs["max_seq_len"]
         if "rope_scale" in kwargs: self.config.scale_pos_emb = kwargs["rope_scale"]
-        if "rope_alpha" in kwargs:
-            self.config.scale_alpha_value = kwargs["rope_alpha"]
-        else:
-            ratio = self.config.max_seq_len / base_seq_len
-            alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
-            if ratio == 1: alpha = 1.0
-            self.config.scale_alpha_value = alpha
-        if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"]
+        # Automatically calculate rope alpha
+        self.config.scale_alpha_value = kwargs.get("rope_alpha") or self.calculate_rope_alpha(base_seq_len)
+
+        if "no_flash_attn" in kwargs: self.config.no_flash_attn = kwargs["no_flash_attn"]
 
         if "low_mem" in kwargs and kwargs["low_mem"]: self.config.set_low_mem()
 
@@ -109,14 +107,7 @@ class ModelContainer:
             self.draft_config.model_dir = str(draft_model_path.resolve())
             self.draft_config.prepare()
 
-            if "draft_rope_alpha" in kwargs:
-                self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha")
-            else:
-                ratio = self.config.max_seq_len / self.draft_config.max_seq_len
-                alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
-                if ratio == 1: alpha = 1.0
-                self.draft_config.scale_alpha_value = alpha
-
+            self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
             self.draft_config.max_seq_len = self.config.max_seq_len
 
         if "chunk_size" in kwargs:
@@ -124,6 +115,13 @@ class ModelContainer:
             self.draft_config.max_attn_size = kwargs["chunk_size"] ** 2
 
+    def calculate_rope_alpha(self, base_seq_len):
+        ratio = self.config.max_seq_len / base_seq_len
+
+        # Default to an alpha of 1 if the ratio is ever less than or equal to 1
+        alpha = 1 if ratio <= 1.0 else -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
+        return alpha
+
     def get_model_path(self):
         model_path = pathlib.Path(self.config.model_dir)
         return model_path
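
For reference, a standalone sketch (not part of the diff) of the alpha curve the new calculate_rope_alpha helper applies. The polynomial and the ratio <= 1 guard come directly from the code above; the function name rope_alpha and the 4096-token base length are illustrative assumptions only.

def rope_alpha(target_seq_len: int, base_seq_len: int) -> float:
    # Ratio of the requested context length to the model's native context length
    ratio = target_seq_len / base_seq_len
    # No scaling needed when the target does not exceed the base context
    if ratio <= 1.0:
        return 1.0
    # Quadratic fit used in the diff above to pick a rope alpha for the ratio
    return -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2

# Doubling a 4096-token context (ratio 2) gives alpha ~= 2.63;
# quadrupling it (ratio 4) gives alpha ~= 7.70.
print(rope_alpha(8192, 4096))   # ~2.63
print(rope_alpha(16384, 4096))  # ~7.70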