Model: Fix model override application for draft args

These have to be merged beforehand and the updated version needs to be
re-fetched. It's possible to prevent the fetch of draft_args in the
beginning of init.

Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
kingbri 2024-08-31 22:56:49 -04:00 committed by Brian Dashore
parent 4aebe8a2a5
commit 4bf1a71d7b

View file

@ -347,6 +347,9 @@ class ExllamaV2Container:
# Set user-configured draft model values
if enable_draft:
# Fetch from the updated kwargs
draft_args = unwrap(kwargs.get("draft"), {})
self.draft_config.max_seq_len = self.config.max_seq_len
self.draft_config.scale_pos_emb = unwrap(
@ -378,9 +381,15 @@ class ExllamaV2Container:
return kwargs
with open(override_config_path, "r", encoding="utf8") as override_config_file:
override_config = unwrap(yaml.safe_load(override_config_file), {})
merged_kwargs = {**override_config, **kwargs}
override_args = unwrap(yaml.safe_load(override_config_file), {})
# Merge draft overrides beforehand
draft_override_args = unwrap(override_args.get("draft"), {})
if self.draft_config and draft_override_args:
kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")}
# Merge the override and model kwargs
merged_kwargs = {**override_args, **kwargs}
return merged_kwargs
def find_prompt_template(self, prompt_template_name, model_directory):