Model: Add speculative decoding support via config
Speculative decoding uses a smaller draft model that proposes tokens ahead of the main model, which then verifies them, so both models are loaded together. Add options in the config to support this. API options will come in a separate commit.

Signed-off-by: kingbri <bdashore3@proton.me>
Parent: 78a6587b95
Commit: 27ebec3b35

3 changed files with 38 additions and 14 deletions
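For context, here is a minimal sketch of the greedy speculative decoding loop that a draft model enables. It is illustrative only: speculative_step, draft_next, and verify_all are hypothetical stand-ins, not tabbyAPI or exllamav2 APIs.

    # Hypothetical sketch of greedy speculative decoding, not tabbyAPI code.
    # The draft model cheaply proposes k tokens; the main model verifies all
    # of them in a single forward pass and keeps the longest agreeing prefix.
    def speculative_step(draft_next, verify_all, tokens, k=4):
        proposal = list(tokens)
        for _ in range(k):
            proposal.append(draft_next(proposal))      # k cheap draft calls

        verified = verify_all(proposal)                # one main-model pass;
                                                       # verified[j] is the next
                                                       # token after proposal[:j+1]
        accepted = list(tokens)
        for i in range(len(tokens), len(proposal)):
            if proposal[i] != verified[i - 1]:         # draft diverged here
                accepted.append(verified[i - 1])       # keep the main model's token
                break
            accepted.append(proposal[i])               # draft guess confirmed
        return accepted

    # Toy stand-ins so the sketch runs: both "models" count upward, but the
    # main model jumps to 99 after token 3, so only a prefix is accepted.
    draft = lambda ts: ts[-1] + 1
    main = lambda ts: [99 if t == 3 else t + 1 for t in ts]
    print(speculative_step(draft, main, [1, 2, 3]))    # [1, 2, 3, 99]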
config.yml (18 lines changed)

@@ -13,11 +13,11 @@ network:
 # Options for model overrides and loading
 model:
   # Overrides the directory to look for models (default: models)
-  # Windows users: DO NOT put this path in quotes! This directory will be invalid otherwise.
+  # Windows users, DO NOT put this path in quotes! This directory will be invalid otherwise.
   # model_dir: your model directory path
 
   # An initial model to load. Make sure the model is located in the model directory!
-  # A model can be loaded later via the API. This does not have to be specified
+  # A model can be loaded later via the API.
   # model_name: A model name
 
+  # Set the following to enable speculative decoding
@@ -46,4 +46,16 @@ model:
   low_mem: False
 
   # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
-  # cache_mode: "FP8"
+  # cache_mode: FP16
+
+  # Options for draft models (speculative decoding). This will use more VRAM!
+  # draft:
+    # Overrides the directory to look for draft models (default: models)
+    # draft_model_dir: Your draft model directory path
+
+    # An initial draft model to load. Make sure this model is located in the model directory!
+    # A draft model can be loaded later via the API.
+    # draft_model_name: A model name
+
+    # Rope parameters for draft models (default: 1.0)
+    # draft_rope_alpha: 1.0
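Because the draft options nest under the model section, they arrive inside ModelContainer's kwargs as a nested mapping rather than as top-level keys. A quick sketch of the parsed shape, assuming the config file is read with PyYAML (the model names below are placeholders):

    import yaml

    raw = """
    model:
      model_name: MyModel-exl2
      draft:
        draft_model_name: MyDraftModel-exl2
        draft_rope_alpha: 1.0
    """

    config = yaml.safe_load(raw)
    model_config = config.get("model") or {}

    # model_config is what main.py passes to ModelContainer as **kwargs,
    # so the draft settings live behind the nested "draft" key:
    draft_config = model_config.get("draft") or {}
    print(draft_config.get("draft_model_name"))  # MyDraftModel-exl2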
main.py (16 lines changed)
@@ -56,7 +56,7 @@ async def list_models():
     else:
         model_path = pathlib.Path("models")
 
-    models = get_model_list(model_path)
+    models = get_model_list(model_path.resolve())
 
     return models
 
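The repeated change from model_path to model_path.resolve() (here and in the load paths below) makes every path absolute and normalized before it is handed off, so downstream consumers no longer depend on the server's working directory. A small illustration; the directory names are examples:

    import pathlib

    model_path = pathlib.Path("models") / "MyModel-exl2"
    print(model_path)            # models/MyModel-exl2  (relative)
    print(model_path.resolve())  # e.g. /home/user/tabbyAPI/models/MyModel-exl2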
@@ -76,7 +76,7 @@ async def load_model(data: ModelLoadRequest):
     def generator():
         global model_container
 
-        model_config = config.get("model", {})
+        model_config = config.get("model") or {}
         if "model_dir" in model_config:
             model_path = pathlib.Path(model_config["model_dir"])
         else:
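The switch from config.get("model", {}) to config.get("model") or {} is not cosmetic: YAML parses a section header with no keys to None, and dict.get only substitutes its default when the key is absent, not when the stored value is None.

    config = {"model": None}  # what a bare "model:" line in config.yml parses to

    print(config.get("model", {}))    # None - the default is NOT applied
    print(config.get("model") or {})  # {}   - safe to call .get() on afterwards

The same pattern is applied to the model and network sections further down.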
@@ -84,7 +84,7 @@ async def load_model(data: ModelLoadRequest):
 
         model_path = model_path / data.name
 
-        model_container = ModelContainer(model_path, False, **data.dict())
+        model_container = ModelContainer(model_path.resolve(), False, **data.dict())
         load_status = model_container.load_gen(load_progress)
         for (module, modules) in load_status:
             if module == 0:
@@ -217,12 +217,12 @@ if __name__ == "__main__":
 
     # If an initial model name is specified, create a container and load the model
 
-    model_config = config.get("model", {})
+    model_config = config.get("model") or {}
     if "model_name" in model_config:
-        model_path = pathlib.Path(model_config.get("model_dir", "models"))
-        model_path = model_path / model_config["model_name"]
+        model_path = pathlib.Path(model_config.get("model_dir") or "models")
+        model_path = model_path / model_config.get("model_name")
 
-        model_container = ModelContainer(model_path, False, **model_config)
+        model_container = ModelContainer(model_path.resolve(), False, **model_config)
         load_status = model_container.load_gen(load_progress)
         for (module, modules) in load_status:
             if module == 0:
@@ -233,7 +233,7 @@ if __name__ == "__main__":
     else:
         loading_bar.next()
 
-    network_config = config.get("network", {})
+    network_config = config.get("network") or {}
     uvicorn.run(
         app,
         host=network_config.get("host", "127.0.0.1"),
model.py (18 lines changed)
@@ -82,17 +82,29 @@ class ModelContainer:
         self.config.max_input_len = chunk_size
         self.config.max_attn_size = chunk_size ** 2
 
-        self.draft_enabled = "draft_model_dir" in kwargs
+        draft_config = kwargs.get("draft") or {}
+        draft_model_name = draft_config.get("draft_model_name")
+        enable_draft = bool(draft_config) and draft_model_name is not None
+
+        if bool(draft_config) and draft_model_name is None:
+            print("A draft config was found but a model name was not given. Please check your config.yml! Skipping draft load.")
+            self.draft_enabled = False
+        else:
+            self.draft_enabled = enable_draft
 
         if self.draft_enabled:
+
             self.draft_config = ExLlamaV2Config()
-            self.draft_config.model_dir = kwargs["draft_model_dir"]
+            draft_model_path = pathlib.Path(kwargs.get("draft_model_dir") or "models")
+            draft_model_path = draft_model_path / draft_model_name
+
+            self.draft_config.model_dir = str(draft_model_path.resolve())
             self.draft_config.prepare()
 
             self.draft_config.max_seq_len = self.config.max_seq_len
 
             if "draft_rope_alpha" in kwargs:
-                self.draft_config.scale_alpha_value = kwargs["draft_rope_alpha"]
+                self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or 1
             else:
                 ratio = self.config.max_seq_len / self.draft_config.max_seq_len
                 alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
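The new enablement logic at the top of this hunk distinguishes three config states: no draft section, a draft section with a model name, and a draft section without one (which warns and skips). Condensed into a standalone function for clarity; this is a restatement of the gating above, not the shipped code:

    def draft_enabled(model_config):
        # Mirrors the gating in ModelContainer.__init__ above.
        draft_config = model_config.get("draft") or {}
        draft_model_name = draft_config.get("draft_model_name")
        return bool(draft_config) and draft_model_name is not None

    print(draft_enabled({}))                                    # False: no draft section
    print(draft_enabled({"draft": {"draft_model_name": "m"}}))  # True
    print(draft_enabled({"draft": {"draft_rope_alpha": 1.0}}))  # False: name missing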
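When draft_rope_alpha is not set, the else branch auto-derives an NTK rope alpha from how far the draft model's native context window is being stretched. A worked instance of the quadratic above; the 4096/2048 context sizes are placeholder values:

    # Main model runs at 4096 context; draft model natively supports 2048.
    ratio = 4096 / 2048                                         # 2.0
    alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
    print(round(alpha, 5))                                      # 2.62978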