Model: Wrap load in inference_mode

Some tensors were being taken out of inference mode during each
iteration of exllama's load_autosplit_gen, which caused errors because
autograd is off.

Therefore, give the shared load_gen_sync function an overarching
inference_mode context to prevent forward issues. This should let the
generator iterate across each thread call without leaving that context.
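
As a rough, standalone illustration (hypothetical names, not the actual
loader), the sketch below shows the pattern this relies on:
torch.inference_mode used as a decorator re-enters the context each time
a generator function resumes, so tensors created on every iteration stay
inference tensors even when iterations are driven from worker threads.

import torch

@torch.inference_mode()
def load_sketch():
    # Each resumption of this generator runs inside inference mode,
    # so new tensors are inference tensors and autograd never tracks them.
    for step in range(3):
        weight = torch.empty(4, 4)
        yield step, weight.is_inference()

for step, still_inference in load_sketch():
    print(step, still_inference)  # prints True for every step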

Signed-off-by: kingbri <bdashore3@proton.me>

@@ -377,6 +377,7 @@ class ExllamaV2Container:
         async for value in iterate_in_threadpool(load_generator):
             yield value
 
+    @torch.inference_mode()
     def load_gen_sync(self, progress_callback=None):
         """
         Load model, generator function
@@ -385,6 +386,8 @@ class ExllamaV2Container:
             progress_callback (function, optional): A function to call for each
                 module loaded. Prototype:
                 def progress(loaded_modules: int, total_modules: int)
+
+        Runs under a shared inference mode context.
         """
 
         # Notify that the model is being loaded