Merge pull request #4 from waldfee/config_samples

Adds draft model support to config.yml

Commit b2410a0436 · 3 changed files with 44 additions and 13 deletions
config.yml:

@@ -13,13 +13,17 @@ network:
 # Options for model overrides and loading
 model:
   # Overrides the directory to look for models (default: models)
-  # Windows users: DO NOT put this path in quotes! This directory will be invalid otherwise.
+  # Windows users, DO NOT put this path in quotes! This directory will be invalid otherwise.
   # model_dir: your model directory path
 
   # An initial model to load. Make sure the model is located in the model directory!
-  # A model can be loaded later via the API. This does not have to be specified
+  # A model can be loaded later via the API.
   # model_name: A model name
 
+  # Set the following to enable speculative decoding
+  # draft_model_dir: your model directory path to use as draft model (path is independent from model_dir)
+  # draft_rope_alpha: 1.0 (default: the draft model's alpha value is calculated automatically to scale to the size of the full model.)
+
   # The below parameters apply only if model_name is set
 
   # Maximum model context length (default: 4096)
@@ -40,3 +44,18 @@ model:
 
   # Enable low vram optimizations in exllamav2 (default: False)
   low_mem: False
+
+  # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
+  # cache_mode: FP16
+
+  # Options for draft models (speculative decoding). This will use more VRAM!
+  # draft:
+    # Overrides the directory to look for draft (default: models)
+    # draft_model_dir: Your draft model directory path
+
+    # An initial draft model to load. Make sure this model is located in the model directory!
+    # A draft model can be loaded later via the API.
+    # draft_model_name: A model name
+
+    # Rope parameters for draft models (default: 1.0)
+    # draft_rope_alpha: 1.0
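To make the new options concrete, here is a minimal sketch of how an uncommented draft section would be consumed. Both model names are placeholders (not models shipped with this project), PyYAML is assumed since config.yml is a YAML file, and the path handling mirrors the new logic in model.py below:

    import pathlib
    import yaml  # PyYAML assumed; config.yml is a YAML file

    # Hypothetical config with the draft options filled in.
    config = yaml.safe_load("""
    model:
      model_dir: models
      model_name: mistral-7b-exl2
      draft_model_dir: models
      draft:
        draft_model_name: tinyllama-1b-exl2
    """)

    model_config = config.get("model") or {}
    draft_config = model_config.get("draft") or {}
    draft_model_name = draft_config.get("draft_model_name")

    if draft_config and draft_model_name is None:
        # Mirrors model.py: a draft section without a model name skips draft loading.
        print("Draft config found but no draft_model_name given; skipping draft load.")
    elif draft_model_name:
        draft_path = pathlib.Path(model_config.get("draft_model_dir") or "models")
        print((draft_path / draft_model_name).resolve())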
main.py (16 lines changed):
@@ -56,7 +56,7 @@ async def list_models():
     else:
         model_path = pathlib.Path("models")
 
-    models = get_model_list(model_path)
+    models = get_model_list(model_path.resolve())
 
     return models
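The `.resolve()` calls added throughout this PR turn relative paths into absolute ones, so model lookups no longer depend on the process's current working directory. A quick illustration (the absolute path shown is just an example):

    from pathlib import Path

    model_path = Path("models") / "some-model"
    print(model_path)            # models/some-model
    print(model_path.resolve())  # e.g. /home/user/app/models/some-model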
@@ -76,7 +76,7 @@ async def load_model(data: ModelLoadRequest):
     def generator():
         global model_container
 
-        model_config = config.get("model", {})
+        model_config = config.get("model") or {}
         if "model_dir" in model_config:
             model_path = pathlib.Path(model_config["model_dir"])
         else:
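The change from `config.get("model", {})` to `config.get("model") or {}` guards against a YAML quirk: a section whose keys are all commented out parses to an explicit `None`, and `dict.get` only falls back to its default when the key is missing entirely. A minimal demonstration, assuming PyYAML:

    import yaml

    # Every key under "model" is commented out, so the section parses to None.
    config = yaml.safe_load("""
    model:
      # model_name: A model name
    """)

    print(config.get("model", {}))    # None -> "in" checks would raise TypeError
    print(config.get("model") or {})  # {}   -> the "or" also covers the None case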
@@ -84,7 +84,7 @@ async def load_model(data: ModelLoadRequest):
 
         model_path = model_path / data.name
 
-        model_container = ModelContainer(model_path, False, **data.dict())
+        model_container = ModelContainer(model_path.resolve(), False, **data.dict())
         load_status = model_container.load_gen(load_progress)
         for (module, modules) in load_status:
             if module == 0:
@@ -217,12 +217,12 @@ if __name__ == "__main__":
 
     # If an initial model name is specified, create a container and load the model
 
-    model_config = config.get("model", {})
+    model_config = config.get("model") or {}
     if "model_name" in model_config:
-        model_path = pathlib.Path(model_config.get("model_dir", "models"))
-        model_path = model_path / model_config["model_name"]
+        model_path = pathlib.Path(model_config.get("model_dir") or "models")
+        model_path = model_path / model_config.get("model_name")
 
-        model_container = ModelContainer(model_path, False, **model_config)
+        model_container = ModelContainer(model_path.resolve(), False, **model_config)
         load_status = model_container.load_gen(load_progress)
         for (module, modules) in load_status:
             if module == 0:
@@ -233,7 +233,7 @@ if __name__ == "__main__":
     else:
         loading_bar.next()
 
-    network_config = config.get("network", {})
+    network_config = config.get("network") or {}
     uvicorn.run(
         app,
         host=network_config.get("host", "127.0.0.1"),
model.py (18 lines changed):
@@ -82,17 +82,29 @@ class ModelContainer:
         self.config.max_input_len = chunk_size
         self.config.max_attn_size = chunk_size ** 2
 
-        self.draft_enabled = "draft_model_dir" in kwargs
+        draft_config = kwargs.get("draft") or {}
+        draft_model_name = draft_config.get("draft_model_name")
+        enable_draft = bool(draft_config) and draft_model_name is not None
+
+        if bool(draft_config) and draft_model_name is None:
+            print("A draft config was found but a model name was not given. Please check your config.yml! Skipping draft load.")
+            self.draft_enabled = False
+        else:
+            self.draft_enabled = enable_draft
 
         if self.draft_enabled:
             self.draft_config = ExLlamaV2Config()
-            self.draft_config.model_dir = kwargs["draft_model_dir"]
+            draft_model_path = pathlib.Path(kwargs.get("draft_model_dir") or "models")
+            draft_model_path = draft_model_path / draft_model_name
+
+            self.draft_config.model_dir = str(draft_model_path.resolve())
             self.draft_config.prepare()
 
             self.draft_config.max_seq_len = self.config.max_seq_len
 
             if "draft_rope_alpha" in kwargs:
-                self.draft_config.scale_alpha_value = kwargs["draft_rope_alpha"]
+                self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or 1
             else:
                 ratio = self.config.max_seq_len / self.draft_config.max_seq_len
                 alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
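When `draft_rope_alpha` is not set, this fallback branch fits the rope alpha from the ratio of the two sequence lengths with a quadratic curve, matching the config comment that the draft model's alpha is scaled automatically to the size of the full model. As a worked example (the sequence lengths below are illustrative values, not project defaults), a ratio of 2 gives alpha of roughly 2.63:

    # The quadratic fit from model.py's fallback branch.
    max_seq_len = 4096        # target context of the full model
    draft_native_len = 2048   # draft model's native context

    ratio = max_seq_len / draft_native_len                      # 2.0
    alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2   # ~2.63
    print(alpha)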