diff --git a/config_sample.yml b/config_sample.yml index e20d6c9..263b08e 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -60,6 +60,7 @@ model: # Rope scale (default: 1.0) # Same thing as compress_pos_emb # Only use if your model was trained on long context with rope (check config.json) + # Leave blank to pull the value from the model rope_scale: 1.0 # Rope alpha (default: 1.0) diff --git a/model.py b/model.py index a1177ab..66b7eb5 100644 --- a/model.py +++ b/model.py @@ -126,7 +126,9 @@ class ModelContainer: self.config.max_seq_len = target_max_seq_len # Set the rope scale - self.config.scale_pos_emb = unwrap(kwargs.get("rope_scale"), 1.0) + self.config.scale_pos_emb = unwrap( + kwargs.get("rope_scale"), self.config.scale_pos_emb + ) # Automatically calculate rope alpha self.config.scale_alpha_value = unwrap(