Skip to content

Commit

Permalink
adding cpu versions of the models
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigcd committed Jun 19, 2024
1 parent 64d2829 commit 8b55a01
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 5 deletions.
9 changes: 6 additions & 3 deletions examples/neuroai_tutorial/rnn_grid_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1529,8 +1529,8 @@
"sequence_length = 20\n",
"\n",
"# If you are loading a pre-trained model, make sure to use the same activation function\n",
"activation = \"relu\" # This will give you \"square grids\", some of them actually look hexagonal\n",
"# activation = \"relu\" # This will give you hexagonal grids"
"# activation = \"tanh\" # This will give you \"square grids\", some of them actually look hexagonal\n",
"activation = \"relu\" # This will give you hexagonal grids"
],
"id": "bc631ad84d0ce55f",
"outputs": [],
Expand Down Expand Up @@ -1628,7 +1628,10 @@
"source": [
"# Load pre-trained model\n",
"real_rnn.load_model(\"tmp_tutorial_model/pre_trained_relu\")\n",
"# real_rnn.load_model(\"tmp_tutorial_model/pre_trained_tanh\")"
"# real_rnn.load_model(\"tmp_tutorial_model/pre_trained_tanh\")\n",
"# If you don't have gpu try loading the cpu model\n",
"# real_rnn.load_model(\"tmp_tutorial_model/pre_trained_relu_cpu\")\n",
"# real_rnn.load_model(\"tmp_tutorial_model/pre_trained_tanh_cpu\")"
],
"id": "91a201d374ae0d4c",
"outputs": [
Expand Down
Binary file not shown.
Binary file not shown.
17 changes: 15 additions & 2 deletions neuralplayground/agents/Sorscher_2022.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,23 @@ def train_RNN(self, data_generator, training_steps):

return self.loss_hist, self.pos_err_hist

def save_model(self, path):
# torch.save(self.state_dict(), path+".torch")
def save_model(self, path, cpu=False):
    """Pickle this object's attribute dictionary to ``path + ".pkl"``.

    Parameters
    ----------
    path : str
        Destination path without extension; ".pkl" is appended.
    cpu : bool, optional
        When True, ``self.convert_params_to_cpu()`` is called before
        dumping. The conversion is applied to the live object and is not
        undone here; call ``convert_params_to_gpu()`` afterwards if the
        weights should go back to ``self.device``.

    Notes
    -----
    The whole ``self.__dict__`` is serialized with ``pickle``, so only
    load the resulting file from sources you trust.
    """
    if cpu:
        self.convert_params_to_cpu()
    # Use a context manager so the file handle is flushed and closed
    # deterministically, even if pickling raises part-way through.
    with open(path + ".pkl", "wb") as pkl_file:
        pickle.dump(self.__dict__, pkl_file)

def convert_params_to_cpu(self):
    """Rebind each weight attribute to the result of its ``.cpu()`` call.

    Touches ``encoder_W``, ``recurrent_W``, ``velocity_W`` and
    ``decoder_W``, in that order.
    """
    for weight_name in ("encoder_W", "recurrent_W", "velocity_W", "decoder_W"):
        setattr(self, weight_name, getattr(self, weight_name).cpu())

def convert_params_to_gpu(self):
    """Rebind each weight attribute to the result of ``.to(self.device)``.

    Touches ``encoder_W``, ``recurrent_W``, ``velocity_W`` and
    ``decoder_W``, in that order. Despite the name, the target is
    whatever ``self.device`` currently holds.
    """
    for weight_name in ("encoder_W", "recurrent_W", "velocity_W", "decoder_W"):
        # self.device is read per attribute, after the weight lookup,
        # mirroring the evaluation order of the explicit assignments.
        setattr(self, weight_name, getattr(self, weight_name).to(self.device))

def load_model(self, path):
# self.load_state_dict(torch.load(path))
# self.eval()
Expand Down

0 comments on commit 8b55a01

Please sign in to comment.