Model: Wrap load in inference_mode
Some tensors were falling out of inference mode during each iteration of exllama's load_autosplit_gen, which causes errors since autograd is off. Therefore, give the shared load_gen_sync function an overarching inference_mode context to prevent forward issues. This should allow the generator to iterate across each thread call.

Signed-off-by: kingbri <bdashore3@proton.me>
parent 37a80334a8
commit 5055a98e41
1 changed file with 3 additions and 0 deletions
@@ -377,6 +377,7 @@ class ExllamaV2Container:
         async for value in iterate_in_threadpool(load_generator):
             yield value
 
+    @torch.inference_mode()
     def load_gen_sync(self, progress_callback=None):
         """
         Load model, generator function
@@ -385,6 +386,8 @@ class ExllamaV2Container:
             progress_callback (function, optional): A function to call for each
                 module loaded. Prototype:
                 def progress(loaded_modules: int, total_modules: int)
+
+        Runs under a shared inference mode context.
         """
 
         # Notify that the model is being loaded
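To illustrate the behaviour this change relies on, here is a minimal, self-contained sketch. It is not the project's code: load_gen_sync_sketch and iterate_in_threadpool_sketch are hypothetical stand-ins for load_gen_sync and (what appears to be) Starlette's iterate_in_threadpool helper. The point it demonstrates is that grad/inference mode is thread-local in PyTorch, but decorating a generator function with @torch.inference_mode() re-enters the context on every resumption, so tensors created between yields stay in inference mode even when each step runs on a thread-pool worker.

import asyncio

import torch


@torch.inference_mode()
def load_gen_sync_sketch(num_modules: int = 3):
    """Hypothetical stand-in for load_gen_sync: yields progress per module."""
    for i in range(num_modules):
        weight = torch.randn(4, 4)    # created under inference mode
        assert weight.is_inference()  # holds even when resumed from a worker thread
        yield (i + 1, num_modules)


async def iterate_in_threadpool_sketch(gen):
    """Rough stand-in for an iterate_in_threadpool helper: advances a sync
    generator one step at a time on the default thread-pool executor."""
    loop = asyncio.get_running_loop()
    sentinel = object()
    while True:
        value = await loop.run_in_executor(None, next, gen, sentinel)
        if value is sentinel:
            break
        yield value


async def main():
    async for loaded, total in iterate_in_threadpool_sketch(load_gen_sync_sketch()):
        print(f"loaded {loaded}/{total} modules")


asyncio.run(main())

Without the decorator, only tensors created before the first yield (or inside an explicit context manager) would be inference tensors; each subsequent next() call from the worker thread would run with default autograd state, which is the mismatch the commit message describes.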