Merge pull request #4 from waldfee/config_samples

Adds draft model support to config.yml
Brian Dashore 2023-11-18 13:16:23 -05:00 committed by GitHub
commit b2410a0436
3 changed files with 44 additions and 13 deletions

config.yml

@@ -13,13 +13,17 @@ network:
 # Options for model overrides and loading
 model:
   # Overrides the directory to look for models (default: models)
-  # Windows users: DO NOT put this path in quotes! This directory will be invalid otherwise.
+  # Windows users, DO NOT put this path in quotes! This directory will be invalid otherwise.
   # model_dir: your model directory path

   # An initial model to load. Make sure the model is located in the model directory!
-  # A model can be loaded later via the API. This does not have to be specified
+  # A model can be loaded later via the API.
   # model_name: A model name

+  # Set the following to enable speculative decoding
+  # draft_model_dir: your model directory path to use as draft model (path is independent from model_dir)
+  # draft_rope_alpha: 1.0 (default: the draft model's alpha value is calculated automatically to scale to the size of the full model.)

   # The below parameters apply only if model_name is set
   # Maximum model context length (default: 4096)
@@ -40,3 +44,18 @@ model:
   # Enable low vram optimizations in exllamav2 (default: False)
   low_mem: False
+
+  # Enable 8 bit cache mode for VRAM savings (slight performance hit). Possible values FP16, FP8. (default: FP16)
+  # cache_mode: FP16
+
+  # Options for draft models (speculative decoding). This will use more VRAM!
+  # draft:
+    # Overrides the directory to look for draft models (default: models)
+    # draft_model_dir: Your draft model directory path
+
+    # An initial draft model to load. Make sure this model is located in the model directory!
+    # A draft model can be loaded later via the API.
+    # draft_model_name: A model name
+
+    # Rope parameters for draft models (default: 1.0)
+    # draft_rope_alpha: 1.0
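With the draft keys uncommented and filled in, the section parses into a nested dict that ModelContainer receives through **kwargs. A minimal sketch of that shape (hypothetical model names, assumes PyYAML; not code from this commit):

import yaml

# Hypothetical filled-in version of the sample above; the model names are placeholders.
sample = """
model:
  model_dir: models
  model_name: some-model-exl2
  draft:
    draft_model_dir: models
    draft_model_name: some-tiny-draft-model-exl2
    draft_rope_alpha: 1.0
"""

config = yaml.safe_load(sample)
model_config = config.get("model") or {}

# The nested draft section arrives as a plain dict:
print(model_config["draft"])
# {'draft_model_dir': 'models', 'draft_model_name': 'some-tiny-draft-model-exl2', 'draft_rope_alpha': 1.0}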

main.py (16 changes)

@@ -56,7 +56,7 @@ async def list_models():
     else:
         model_path = pathlib.Path("models")

-    models = get_model_list(model_path)
+    models = get_model_list(model_path.resolve())

     return models
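The change here resolves the models directory to an absolute path before listing. A quick illustration of what pathlib.Path.resolve() does (hypothetical working directory):

import pathlib

model_path = pathlib.Path("models")
print(model_path)            # models (relative)
print(model_path.resolve())  # e.g. /home/user/tabbyAPI/models (absolute, symlinks resolved)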
@@ -76,7 +76,7 @@ async def load_model(data: ModelLoadRequest):
     def generator():
         global model_container

-        model_config = config.get("model", {})
+        model_config = config.get("model") or {}
         if "model_dir" in model_config:
             model_path = pathlib.Path(model_config["model_dir"])
         else:
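The switch from config.get("model", {}) to config.get("model") or {} matters because of how YAML parses an empty section: a bare model: key with every child commented out loads as None, and dict.get only falls back to its default when the key is missing, not when the value is None. A minimal standalone sketch of the difference (not from this commit):

# A config.yml section that exists but is fully commented out parses to None.
config = {"model": None}

print(config.get("model", {}))    # None -- the default is NOT used, because the key exists
print(config.get("model") or {})  # {}   -- "or" normalizes None to an empty dict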
@@ -84,7 +84,7 @@ async def load_model(data: ModelLoadRequest):
         model_path = model_path / data.name

-        model_container = ModelContainer(model_path, False, **data.dict())
+        model_container = ModelContainer(model_path.resolve(), False, **data.dict())
         load_status = model_container.load_gen(load_progress)
         for (module, modules) in load_status:
             if module == 0:
@@ -217,12 +217,12 @@ if __name__ == "__main__":
     # If an initial model name is specified, create a container and load the model
-    model_config = config.get("model", {})
+    model_config = config.get("model") or {}
     if "model_name" in model_config:
-        model_path = pathlib.Path(model_config.get("model_dir", "models"))
-        model_path = model_path / model_config["model_name"]
+        model_path = pathlib.Path(model_config.get("model_dir") or "models")
+        model_path = model_path / model_config.get("model_name")

-        model_container = ModelContainer(model_path, False, **model_config)
+        model_container = ModelContainer(model_path.resolve(), False, **model_config)
         load_status = model_container.load_gen(load_progress)
         for (module, modules) in load_status:
             if module == 0:
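In both load paths, load_gen is consumed as a generator that yields (module, modules) progress tuples, with the first yield (module == 0) used to set up the progress bar before loading begins. A small sketch of that consumption pattern (hypothetical loader, not this repo's exact code):

def load_gen():
    # Hypothetical stand-in: yield (modules_loaded, total_modules) as loading progresses.
    total = 4
    for module in range(total + 1):
        yield (module, total)

for (module, modules) in load_gen():
    if module == 0:
        print(f"Starting load: 0/{modules} modules")  # set up the progress bar here
    else:
        print(f"Loaded module {module}/{modules}")    # advance the bar, e.g. loading_bar.next()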
@@ -233,7 +233,7 @@ if __name__ == "__main__":
             else:
                 loading_bar.next()

-    network_config = config.get("network", {})
+    network_config = config.get("network") or {}
     uvicorn.run(
         app,
         host=network_config.get("host", "127.0.0.1"),

model.py

@@ -82,17 +82,29 @@ class ModelContainer:
         self.config.max_input_len = chunk_size
         self.config.max_attn_size = chunk_size ** 2

-        self.draft_enabled = "draft_model_dir" in kwargs
+        draft_config = kwargs.get("draft") or {}
+        draft_model_name = draft_config.get("draft_model_name")
+        enable_draft = bool(draft_config) and draft_model_name is not None
+
+        if bool(draft_config) and draft_model_name is None:
+            print("A draft config was found but a model name was not given. Please check your config.yml! Skipping draft load.")
+            self.draft_enabled = False
+        else:
+            self.draft_enabled = enable_draft
+
         if self.draft_enabled:
             self.draft_config = ExLlamaV2Config()
-            self.draft_config.model_dir = kwargs["draft_model_dir"]
+            draft_model_path = pathlib.Path(kwargs.get("draft_model_dir") or "models")
+            draft_model_path = draft_model_path / draft_model_name
+            self.draft_config.model_dir = str(draft_model_path.resolve())
             self.draft_config.prepare()

             self.draft_config.max_seq_len = self.config.max_seq_len

             if "draft_rope_alpha" in kwargs:
-                self.draft_config.scale_alpha_value = kwargs["draft_rope_alpha"]
+                self.draft_config.scale_alpha_value = kwargs.get("draft_rope_alpha") or 1
             else:
                 ratio = self.config.max_seq_len / self.draft_config.max_seq_len
                 alpha = -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2
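When draft_rope_alpha is omitted, the else branch falls back to the quadratic fit shown above, deriving a rope alpha value from the context-length ratio. The same formula as a standalone sketch (hypothetical inputs, for illustration only):

def auto_rope_alpha(ratio: float) -> float:
    # Quadratic fit from the hunk above: maps a context-extension ratio
    # to a rope alpha scaling value.
    return -0.13436 + 0.80541 * ratio + 0.28833 * ratio ** 2

print(auto_rope_alpha(1.0))  # ~0.96 (roughly no scaling)
print(auto_rope_alpha(2.0))  # ~2.63 (doubling the context)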