diff --git a/TabbyAPI_Colab_Example.ipynb b/TabbyAPI_Colab_Example.ipynb index 9ad687b..c36fab5 100644 --- a/TabbyAPI_Colab_Example.ipynb +++ b/TabbyAPI_Colab_Example.ipynb @@ -50,7 +50,6 @@ "# @title # Install and download model { display-mode: \"form\" }\n", "# @markdown ---\n", "# @markdown Select model:\n", - "# Select model and branch\n", "repo_id = \"royallab/Noromaid-13b-v0.1.1-exl2\" # @param {type:\"string\"}\n", "revision = \"4bpw\" # @param {type:\"string\"}\n", "# @markdown ---\n", @@ -96,7 +95,8 @@ "RopeAlpha = 1.0 # @param {type:\"number\"}\n", "# @markdown ---\n", "# @markdown Draft model parameters (optional, for speculative decoding):\n", - "DraftRopeAlpha = None # @param {type:\"number\"}\n", + "DraftRopeScale = 1.0 # @param {type:\"number\"}\n", + "DraftRopeAlpha = 1.0 # @param {type:\"number\"}\n", "# @markdown ---\n", "# @markdown Misc options:\n", "CacheMode = \"FP16\" # @param [\"FP8\", \"FP16\"] {type:\"string\"}\n", @@ -105,7 +105,7 @@ "# @markdown ---\n", "# @markdown To connect, make note of the cloudflared URL and your auto-generated API key after launching and provide it to your preferred frontend.\n", "\n", - "# Setup Config - edit parameters to fit your needs\n", + "# Setup Config\n", "%cd /content/tabbyAPI/\n", "\n", "write = f'''\n", @@ -169,6 +169,7 @@ " draft_model_name: {draft_model}\n", "\n", " # Rope parameters for draft models (default: 1.0)\n", + " draft_rope_scale: {DraftRopeScale}\n", " draft_rope_alpha: {DraftRopeAlpha}\n", "'''\n", "with open(\"./config.yml\", \"w\") as file:\n", diff --git a/config_sample.yml b/config_sample.yml index fe723a5..f1524cc 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -58,4 +58,5 @@ model: draft_model_name: A model name # Rope parameters for draft models (default: 1.0) + draft_rope_scale: 1.0 draft_rope_alpha: 1.0 diff --git a/model.py b/model.py index 6d3d8c7..929916f 100644 --- a/model.py +++ b/model.py @@ -110,7 +110,8 @@ class ModelContainer: self.draft_config.model_dir = str(draft_model_path.resolve()) self.draft_config.prepare() - + + if "draft_rope_scale" in kwargs: self.draft_config.scale_pos_emb = kwargs["draft_rope_scale"] self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len) self.draft_config.max_seq_len = self.config.max_seq_len