Expose draft_rope_scale
This commit is contained in:
parent
e085b806e8
commit
39f7a2aabd
3 changed files with 7 additions and 4 deletions
|
|
@ -50,7 +50,6 @@
|
|||
"# @title # Install and download model { display-mode: \"form\" }\n",
|
||||
"# @markdown ---\n",
|
||||
"# @markdown Select model:\n",
|
||||
"# Select model and branch\n",
|
||||
"repo_id = \"royallab/Noromaid-13b-v0.1.1-exl2\" # @param {type:\"string\"}\n",
|
||||
"revision = \"4bpw\" # @param {type:\"string\"}\n",
|
||||
"# @markdown ---\n",
|
||||
|
|
@ -96,7 +95,8 @@
|
|||
"RopeAlpha = 1.0 # @param {type:\"number\"}\n",
|
||||
"# @markdown ---\n",
|
||||
"# @markdown Draft model parameters (optional, for speculative decoding):\n",
|
||||
"DraftRopeAlpha = None # @param {type:\"number\"}\n",
|
||||
"DraftRopeScale = 1.0 # @param {type:\"number\"}\n",
|
||||
"DraftRopeAlpha = 1.0 # @param {type:\"number\"}\n",
|
||||
"# @markdown ---\n",
|
||||
"# @markdown Misc options:\n",
|
||||
"CacheMode = \"FP16\" # @param [\"FP8\", \"FP16\"] {type:\"string\"}\n",
|
||||
|
|
@ -105,7 +105,7 @@
|
|||
"# @markdown ---\n",
|
||||
"# @markdown To connect, make note of the cloudflared URL and your auto-generated API key after launching and provide it to your preferred frontend.\n",
|
||||
"\n",
|
||||
"# Setup Config - edit parameters to fit your needs\n",
|
||||
"# Setup Config\n",
|
||||
"%cd /content/tabbyAPI/\n",
|
||||
"\n",
|
||||
"write = f'''\n",
|
||||
|
|
@ -169,6 +169,7 @@
|
|||
" draft_model_name: {draft_model}\n",
|
||||
"\n",
|
||||
" # Rope parameters for draft models (default: 1.0)\n",
|
||||
" draft_rope_scale: {DraftRopeScale}\n",
|
||||
" draft_rope_alpha: {DraftRopeAlpha}\n",
|
||||
"'''\n",
|
||||
"with open(\"./config.yml\", \"w\") as file:\n",
|
||||
|
|
|
|||
|
|
@ -58,4 +58,5 @@ model:
|
|||
draft_model_name: A model name
|
||||
|
||||
# Rope parameters for draft models (default: 1.0)
|
||||
draft_rope_scale: 1.0
|
||||
draft_rope_alpha: 1.0
|
||||
|
|
|
|||
3
model.py
3
model.py
|
|
@ -110,7 +110,8 @@ class ModelContainer:
|
|||
|
||||
self.draft_config.model_dir = str(draft_model_path.resolve())
|
||||
self.draft_config.prepare()
|
||||
|
||||
|
||||
if "draft_rope_scale" in kwargs: self.draft_config.scale_pos_emb = kwargs["draft_rope_scale"]
|
||||
self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or self.calculate_rope_alpha(self.draft_config.max_seq_len)
|
||||
self.draft_config.max_seq_len = self.config.max_seq_len
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue