SpikingJelly version
0.0.0.0.15
Description
I am performing a hyperparameter search using grid search, which means that a different model is created and trained on every iteration of the loop. The model is instantiated and trained inside a function that is called with different configuration parameters on every run.
When an iteration of the loop finishes, the model and all of the variables created with it, including the optimizer, the dataset, and the STDPLearners, go out of scope, which should delete them. However, the current behaviour is that tensors created by STDPLearner are never deleted, filling up memory until the program crashes. The MWE crashes after only 3 iterations when the model includes a layer.Conv2D layer, and the problem still exists, although to a lesser extent, when the model only uses layer.Linear layers.
This happens with both the torch and cupy backends, and explicitly calling the garbage collector does not help. The problem does not show up when training with gradient descent, as shown in the minimal example by swapping train_model_stdp with the train_model_gd function, which links it to the STDPLearner class.
Minimal code to reproduce the error/bug
This is the basic training code that causes the issue. A single epoch per run is used because the problem only shows up when new runs are started. It includes a debugging function that reports how many tensors are alive.
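The original script is not included above, so the following is a minimal sketch of that kind of setup, assuming the STDPLearner usage pattern from the SpikingJelly activation-based tutorials; the toy model, the count_tensors helper, and the grid-search values are illustrative, not the original MWE:

```python
import gc
import itertools

import torch
import torch.nn as nn
from spikingjelly.activation_based import functional, layer, learning, neuron


def count_tensors() -> int:
    # Debug helper: count the torch tensors the garbage collector can still reach.
    n = 0
    for obj in gc.get_objects():
        try:
            if torch.is_tensor(obj):
                n += 1
        except Exception:
            pass  # some tracked objects do not like being inspected
    return n


def train_model_stdp(hidden: int, lr: float, T: int = 16) -> None:
    # A fresh toy model for every call: one plastic linear synapse followed by an IF neuron.
    net = nn.Sequential(
        layer.Linear(28 * 28, hidden, bias=False),
        neuron.IFNode(),
    )
    learner = learning.STDPLearner(step_mode='s', synapse=net[0], sn=net[1],
                                   tau_pre=2.0, tau_post=2.0)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr, momentum=0.0)

    # Single epoch on random spike-like input; the leak shows up across runs, not epochs.
    x = (torch.rand(64, 28 * 28) > 0.5).float()
    optimizer.zero_grad()
    for _ in range(T):
        net(x)
        learner.step(on_grad=True)  # accumulates -delta_w into the weight gradient
    optimizer.step()
    functional.reset_net(net)
    learner.reset()


if __name__ == '__main__':
    # Grid search: every configuration builds and trains its own model inside the function,
    # so everything created there should be collectable once the call returns.
    for hidden, lr in itertools.product([64, 128, 256], [1e-3, 1e-2]):
        train_model_stdp(hidden, lr)
        gc.collect()
        print(f'hidden={hidden} lr={lr} live tensors: {count_tensors()}')
```

By swapping the train_model_stdp function with the following one, which simulates training with gradient descent (again a sketch, reusing the imports above rather than the original function), the problem disappears:

```python
def train_model_gd(hidden: int, lr: float, T: int = 16) -> None:
    # Same toy model, trained with surrogate-gradient descent instead of STDP:
    # backward() is called, so the autograd graph is freed on every call.
    net = nn.Sequential(
        layer.Linear(28 * 28, hidden, bias=False),
        neuron.IFNode(),
    )
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)

    x = (torch.rand(64, 28 * 28) > 0.5).float()
    target = torch.rand(64, hidden)
    optimizer.zero_grad()
    out = 0.0
    for _ in range(T):
        out = out + net(x)
    loss = torch.nn.functional.mse_loss(out / T, target)
    loss.backward()
    optimizer.step()
    functional.reset_net(net)
```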
The out-of-memory problem you've encountered can be addressed by running STDPLearner.step() within a torch.no_grad() context. You may modify your STDP code as follows:
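Applied to the train_model_stdp sketch above, the only change is wrapping the learner.step() call in a torch.no_grad() block:

```python
# Inside train_model_stdp: keep the STDP update out of the autograd graph.
for _ in range(T):
    net(x)
    with torch.no_grad():
        learner.step(on_grad=True)  # still writes -delta_w into the weight gradient
optimizer.step()
```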
STDPLearner.step(on_grad=True) updates the weight by adding $-\Delta w$ to w.grad, where w is the weight tensor. When optimizer.step() is then called, $-\Delta w$ is subtracted from the weight, so the weight is updated as expected. Wrapping STDPLearner.step() in a torch.no_grad() context ensures that its internal computations are excluded from the PyTorch computational graph. Without torch.no_grad(), the operations inside STDPLearner.step() become part of the graph, and since backward() is never called in your STDP code, that graph is never freed, leading to the out-of-memory error.
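As a plain-PyTorch illustration of that accounting (not SpikingJelly code): writing $-\Delta w$ into .grad and then taking a vanilla SGD step adds $\mathrm{lr} \cdot \Delta w$ to the weight.

```python
import torch

w = torch.nn.Parameter(torch.zeros(2, 2))
opt = torch.optim.SGD([w], lr=0.1)

delta_w = torch.ones(2, 2)  # stand-in for the STDP update
w.grad = -delta_w           # what step(on_grad=True) effectively writes into w.grad
opt.step()                  # w <- w - lr * w.grad = w + lr * delta_w
print(w)                    # every entry is now 0.1
```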
Thank you, this seems to prevent the memory usage from growing uncontrollably, which is good enough for me. However, runs following the first one still take up more memory than a single run would.
According to the PyTorch docs, the reference to the computational graph is held by the tensor resulting from an operation, meaning that once that tensor goes out of scope the whole graph is freed. It appears that some reference to the STDPLearner tensors is still accessible from outside the function, probably at module level, otherwise they would be collected.
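One way to test that hypothesis (a rough diagnostic sketch, not part of the original code) is to ask the garbage collector which objects still refer to the surviving tensors after a run:

```python
import gc
import torch


def report_live_tensors(limit: int = 20) -> None:
    # Print surviving tensors together with the types of the objects that still refer to them.
    survivors = [obj for obj in gc.get_objects() if torch.is_tensor(obj)]
    for t in survivors[:limit]:
        owners = {type(r).__name__ for r in gc.get_referrers(t)}
        owners.discard('list')  # ignore the temporary survivors list built just above
        print(tuple(t.shape), '<-', sorted(owners))
```

If the owners turn out to be long-lived objects (for example something stored at module level) rather than locals of the training function, that would explain why the tensors are never collected.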