Skip to content

Commit

Permalink
minor update
Browse files Browse the repository at this point in the history
  • Loading branch information
LukeHollingsworth committed Oct 15, 2024
1 parent 1fe31ab commit 2a95433
Showing 1 changed file with 6 additions and 30 deletions.
36 changes: 6 additions & 30 deletions neuralplayground/agents/whittington_2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,39 +340,14 @@ def update(self):
self.accuracy_history["gt_accuracy"].append(acc_gt)
self.iter += 1

# Save accuracies periodically (e.g., every 100 iterations)
# if (self.iter) % 100 == 0:
# self.save_accuracies()
# # Also store the internal state (all learnable parameters) and the hyperparameters periodically
# if self.iter % self.pars["save_interval"] == 0:
# torch.save(self.tem.state_dict(), self.model_path + "/tem_" + str(self.iter) + ".pt")
# torch.save(self.tem.hyper, self.model_path + "/params_" + str(self.iter) + ".pt")

# # Save the final state of the model after training has finished
# if self.iter == self.pars["train_it"] - 1:
# torch.save(self.tem.state_dict(), self.model_path + "/tem_" + str(self.iter) + ".pt")
# torch.save(self.tem.hyper, self.model_path + "/params_" + str(self.iter) + ".pt")
# if self.iter % 10 == 0:
# print("Iteration: ", self.iter)
# print("Accuracies: ", acc_p, acc_g, acc_gt)

def initialise(self):
"""
Generate random distribution of objects and intialise optimiser, logger and relevant variables
"""
# Create directories for storing all information about the current run
# (
# self.run_path,
# self.train_path,
# self.model_path,
# self.save_path,
# self.script_path,
# self.envs_path,
# ) = utils.make_directories()
# # Save all python files in current directory to script directory
# self.save_files()
# # Save parameters
# np.save(os.path.join(self.save_path, "params"), self.pars)
# # Create a tensor board to stay updated on training progress. Start tensorboard with tensorboard --logdir=runs
# self.writer = SummaryWriter(self.train_path)
# Create a logger to write log output to file
current_dir = os.path.dirname(os.getcwd())
run_path = os.path.join(current_dir, "agent_examples", self.save_name)
run_path = os.path.normpath(run_path)
Expand Down Expand Up @@ -533,8 +508,9 @@ def collect_final_trajectory(self):
return final_model_input, history, environments

def plot_run(self, tem, model_input, environments):
with torch.no_grad():
forward = tem(model_input, prev_iter=None)
# with torch.no_grad():
# forward = tem(model_input, prev_iter=None)
forward = tem(model_input, prev_iter=None)
include_stay_still = True
# shiny_envs = [False, False, False, False]
# env_to_plot = 0
Expand Down

0 comments on commit 2a95433

Please sign in to comment.