Expose draft_rope_scale

This commit is contained in:
DocShotgun 2023-12-05 12:59:32 -08:00 committed by GitHub
parent e085b806e8
commit 39f7a2aabd
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 7 additions and 4 deletions

View file

@ -50,7 +50,6 @@
"# @title # Install and download model { display-mode: \"form\" }\n",
"# @markdown ---\n",
"# @markdown Select model:\n",
"# Select model and branch\n",
"repo_id = \"royallab/Noromaid-13b-v0.1.1-exl2\" # @param {type:\"string\"}\n",
"revision = \"4bpw\" # @param {type:\"string\"}\n",
"# @markdown ---\n",
@ -96,7 +95,8 @@
"RopeAlpha = 1.0 # @param {type:\"number\"}\n",
"# @markdown ---\n",
"# @markdown Draft model parameters (optional, for speculative decoding):\n",
"DraftRopeAlpha = None # @param {type:\"number\"}\n",
"DraftRopeScale = 1.0 # @param {type:\"number\"}\n",
"DraftRopeAlpha = 1.0 # @param {type:\"number\"}\n",
"# @markdown ---\n",
"# @markdown Misc options:\n",
"CacheMode = \"FP16\" # @param [\"FP8\", \"FP16\"] {type:\"string\"}\n",
@ -105,7 +105,7 @@
"# @markdown ---\n",
"# @markdown To connect, make note of the cloudflared URL and your auto-generated API key after launching and provide it to your preferred frontend.\n",
"\n",
"# Setup Config - edit parameters to fit your needs\n",
"# Setup Config\n",
"%cd /content/tabbyAPI/\n",
"\n",
"write = f'''\n",
@ -169,6 +169,7 @@
" draft_model_name: {draft_model}\n",
"\n",
" # Rope parameters for draft models (default: 1.0)\n",
" draft_rope_scale: {DraftRopeScale}\n",
" draft_rope_alpha: {DraftRopeAlpha}\n",
"'''\n",
"with open(\"./config.yml\", \"w\") as file:\n",

View file

@ -58,4 +58,5 @@ model:
draft_model_name: A model name
# Rope parameters for draft models (default: 1.0)
draft_rope_scale: 1.0
draft_rope_alpha: 1.0

View file

@ -110,7 +110,8 @@ class ModelContainer:
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
if "draft_rope_scale" in kwargs: self.draft_config.scale_pos_emb = kwargs["draft_rope_scale"]
self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
self.draft_config.max_seq_len = self.config.max_seq_len