Model: Add speculative decoding support via config

Speculative decoding uses a smaller draft model to propose tokens from the
prompt, which the main model then verifies.

Add options in the config to support this. API options will occur
in a different commit.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2023-11-18 01:38:54 -05:00
parent 78a6587b95
commit 27ebec3b35
3 changed files with 38 additions and 14 deletions

View file

@ -13,11 +13,11 @@ network:
# Options for model overrides and loading
model:
# Overrides the directory to look for models (default: models)
# Windows users: DO NOT put this path in quotes! This directory will be invalid otherwise.
# Windows users, DO NOT put this path in quotes! This directory will be invalid otherwise.
# model_dir: your model directory path
# An initial model to load. Make sure the model is located in the model directory!
# A model can be loaded later via the API. This does not have to be specified
# A model can be loaded later via the API.
# model_name: A model name
# Set the following to enable speculative decoding
@ -46,4 +46,16 @@ model:
low_mem: False
# Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
# cache_mode: "FP8"
# cache_mode: FP16
# Options for draft models (speculative decoding). This will use more VRAM!
# draft:
# Overrides the directory to look for draft models (default: models)
# draft_model_dir: Your draft model directory path
# An initial draft model to load. Make sure this model is located in the draft model directory!
# A draft model can be loaded later via the API.
# draft_model_name: A model name
# Rope parameters for draft models (default: 1.0)
# draft_rope_alpha: 1.0

16
main.py
View file

@ -56,7 +56,7 @@ async def list_models():
else:
model_path = pathlib.Path("models")
models = get_model_list(model_path)
models = get_model_list(model_path.resolve())
return models
@ -76,7 +76,7 @@ async def load_model(data: ModelLoadRequest):
def generator():
global model_container
model_config = config.get("model", {})
model_config = config.get("model") or {}
if "model_dir" in model_config:
model_path = pathlib.Path(model_config["model_dir"])
else:
@ -84,7 +84,7 @@ async def load_model(data: ModelLoadRequest):
model_path = model_path / data.name
model_container = ModelContainer(model_path, False, **data.dict())
model_container = ModelContainer(model_path.resolve(), False, **data.dict())
load_status = model_container.load_gen(load_progress)
for (module, modules) in load_status:
if module == 0:
@ -217,12 +217,12 @@ if __name__ == "__main__":
# If an initial model name is specified, create a container and load the model
model_config = config.get("model", {})
model_config = config.get("model") or {}
if "model_name" in model_config:
model_path = pathlib.Path(model_config.get("model_dir", "models"))
model_path = model_path / model_config["model_name"]
model_path = pathlib.Path(model_config.get("model_dir") or "models")
model_path = model_path / model_config.get("model_name")
model_container = ModelContainer(model_path, False, **model_config)
model_container = ModelContainer(model_path.resolve(), False, **model_config)
load_status = model_container.load_gen(load_progress)
for (module, modules) in load_status:
if module == 0:
@ -233,7 +233,7 @@ if __name__ == "__main__":
else:
loading_bar.next()
network_config = config.get("network", {})
network_config = config.get("network") or {}
uvicorn.run(
app,
host=network_config.get("host", "127.0.0.1"),

View file

@ -82,17 +82,29 @@ class ModelContainer:
self.config.max_input_len = chunk_size
self.config.max_attn_size = chunk_size ** 2
self.draft_enabled = "draft_model_dir" in kwargs
draft_config = kwargs.get("draft") or {}
draft_model_name = draft_config.get("draft_model_name")
enable_draft = bool(draft_config) and draft_model_name is not None
if bool(draft_config) and draft_model_name is None:
print("A draft config was found but a model name was not given. Please check your config.yml! Skipping draft load.")
self.draft_enabled = False
else:
self.draft_enabled = enable_draft
if self.draft_enabled:
self.draft_config = ExLlamaV2Config()
self.draft_config.model_dir = kwargs["draft_model_dir"]
draft_model_path = pathlib.Path(kwargs.get("draft_model_dir") or "models")
draft_model_path = draft_model_path / draft_model_name
self.draft_config.model_dir = str(draft_model_path.resolve())
self.draft_config.prepare()
self.draft_config.max_seq_len = self.config.max_seq_len
if "draft_rope_alpha" in kwargs:
self.draft_config.scale_alpha_value = kwargs["draft_rope_alpha"]
self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or 1
else:
ratio = self.config.max_seq_len / self.draft_config.max_seq_len
alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2