Model: Fix model override application for draft args
The draft overrides have to be merged beforehand, and the updated version then needs to be re-fetched. It's possible to avoid fetching draft_args at the beginning of init. Signed-off-by: kingbri <bdashore3@proton.me>
This commit is contained in:
parent
4aebe8a2a5
commit
4bf1a71d7b
1 changed file with 11 additions and 2 deletions
|
|
@@ -347,6 +347,9 @@ class ExllamaV2Container:
|
|||
|
||||
# Set user-configured draft model values
|
||||
if enable_draft:
|
||||
# Fetch from the updated kwargs
|
||||
draft_args = unwrap(kwargs.get("draft"), {})
|
||||
|
||||
self.draft_config.max_seq_len = self.config.max_seq_len
|
||||
|
||||
self.draft_config.scale_pos_emb = unwrap(
|
||||
|
|
@@ -378,9 +381,15 @@ class ExllamaV2Container:
|
|||
return kwargs
|
||||
|
||||
with open(override_config_path, "r", encoding="utf8") as override_config_file:
|
||||
override_config = unwrap(yaml.safe_load(override_config_file), {})
|
||||
merged_kwargs = {**override_config, **kwargs}
|
||||
override_args = unwrap(yaml.safe_load(override_config_file), {})
|
||||
|
||||
# Merge draft overrides beforehand
|
||||
draft_override_args = unwrap(override_args.get("draft"), {})
|
||||
if self.draft_config and draft_override_args:
|
||||
kwargs["draft"] = {**draft_override_args, **kwargs.get("draft")}
|
||||
|
||||
# Merge the override and model kwargs
|
||||
merged_kwargs = {**override_args, **kwargs}
|
||||
return merged_kwargs
|
||||
|
||||
def find_prompt_template(self, prompt_template_name, model_directory):
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue