diff --git a/examples/agent_examples/plotting_metrics.py b/examples/agent_examples/plotting_metrics.py
deleted file mode 100644
index 98b7144..0000000
--- a/examples/agent_examples/plotting_metrics.py
+++ /dev/null
@@ -1,247 +0,0 @@
-import re
-from collections import defaultdict
-
-import matplotlib.pyplot as plt
-
-
-def parse_and_plot_run_log(file_path):
- iterations = []
- losses = []
- accuracies_p = []
- accuracies_g = []
- accuracies_gt = []
- new_walks = []
-
- iter_pattern = r"Finished backprop iter (\d+)"
- loss_pattern = r"Loss: ([\d.]+)\." # Note the added \. to catch the trailing period
- accuracy_pattern = r"Accuracy:
([\d.]+)% ([\d.]+)% ([\d.]+)%"
- new_walk_pattern = r"Iteration (\d+): new walk"
-
- with open(file_path, "r") as file:
- for line in file:
- iter_match = re.search(iter_pattern, line)
- if iter_match:
- iterations.append(int(iter_match.group(1)))
-
- loss_match = re.search(loss_pattern, line)
- if loss_match:
- losses.append(float(loss_match.group(1)))
-
- accuracy_match = re.search(accuracy_pattern, line)
- if accuracy_match:
- accuracies_p.append(float(accuracy_match.group(1)))
- accuracies_g.append(float(accuracy_match.group(2)))
- accuracies_gt.append(float(accuracy_match.group(3)))
-
- new_walk_match = re.search(new_walk_pattern, line)
- if new_walk_match:
- new_walks.append(int(new_walk_match.group(1)))
-
- fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12))
-
- ax1.plot(iterations, losses)
- ax1.set_xlabel("Iteration")
- ax1.set_ylabel("Loss")
- ax1.set_title("Loss over Iterations")
-
- ax2.plot(iterations, accuracies_p, label="p accuracy")
- ax2.plot(iterations, accuracies_g, label="g accuracy")
- ax2.plot(iterations, accuracies_gt, label="gt accuracy")
- ax2.set_xlabel("Iteration")
- ax2.set_ylabel("Accuracy (%)")
- ax2.set_title("Accuracies over Iterations")
- ax2.legend()
-
- # Add vertical lines for new walks
- # for walk in new_walks:
- # ax1.axvline(x=walk, color='r', linestyle='--', alpha=0.5)
- # ax2.axvline(x=walk, color='r', linestyle='--', alpha=0.5)
-
- plt.tight_layout()
- plt.show()
-
-
-def analyse_log_file(file):
- # Regular expressions to match IDs and Objs lines
- id_pattern = re.compile(r"IDs: \[([^\]]+)\]")
- obj_pattern = re.compile(r"Objs: \[([^\]]+)\]")
- iter_pattern = re.compile(r"Finished backprop iter (\d+)")
- step_pattern = re.compile(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}:")
-
- # Initialize data structures
- iteration_data = {}
- current_iteration = None
- id_to_obj_previous = {}
- current_step = 0
-
- with open(file, "r") as file:
- for line in file:
- # Check for iteration number
- iter_match = iter_pattern.search(line)
- if iter_match:
- current_iteration = int(iter_match.group(1))
- current_step = 0 # Reset step counter for new iteration
- continue # Proceed to next line
-
- # Check for step line (assumed to start with timestamp)
- if step_pattern.match(line):
- current_step += 1
-
- # Extract IDs
- id_match = id_pattern.search(line)
- if id_match:
- ids = list(map(int, id_match.group(1).split(",")))
- continue # IDs are followed by Objs, proceed to next line
-
- # Extract Objs
- obj_match = obj_pattern.search(line)
- if obj_match:
- objs = list(map(int, obj_match.group(1).split(",")))
-
- # Ensure current_iteration is set
- if current_iteration is None:
- continue # Skip if iteration is not identified yet
-
- # Store IDs and Objs for this iteration and step
- if current_iteration not in iteration_data:
- iteration_data[current_iteration] = []
- iteration_data[current_iteration].append((current_step, ids, objs))
-
- # Now, process the data to find shifts with detailed information
- shifts = defaultdict(list) # Key: iteration, Value: list of shift details
- id_to_obj_current = {}
-
- sorted_iterations = sorted(iteration_data.keys())
-
- for idx, iteration in enumerate(sorted_iterations):
- steps = iteration_data[iteration]
- # For each step in the iteration
- for step in steps:
- step_num, ids, objs = step
- # For each ID in the batch
- for batch_idx, (id_, obj) in enumerate(zip(ids, objs)):
- key = (batch_idx, id_) # Identify by batch index and ID
- if key in id_to_obj_previous:
- prev_info = id_to_obj_previous[key]
- prev_obj = prev_info["obj"]
- if obj != prev_obj:
- # Environment has changed for this batch member
- shifts[iteration].append(
- {
- "batch_idx": batch_idx,
- "id": id_,
- "prev_obj": prev_obj,
- "new_obj": obj,
- "prev_iteration": prev_info["iteration"],
- "prev_step": prev_info["step"],
- "current_iteration": iteration,
- "current_step": step_num,
- }
- )
- # Update current mapping
- id_to_obj_current[key] = {"obj": obj, "iteration": iteration, "step": step_num}
- # After processing all steps in the iteration, update previous mapping
- id_to_obj_previous = id_to_obj_current.copy()
- id_to_obj_current.clear()
-
- # Output the iterations where shifts occurred with detailed information
- print("Environment shifts detected with detailed information:")
- with open("shifts_output.txt", "w") as output_file:
- for iteration in sorted(shifts.keys()):
- shift_list = shifts[iteration]
- if shift_list:
- output_file.write(f"\nIteration {iteration}: number of shifts = {len(shift_list)}\n")
- for shift in shift_list:
- output_file.write(
- f" Batch index {shift['batch_idx']}, ID {shift['id']} changed from "
- f"object {shift['prev_obj']} (Iteration {shift['prev_iteration']}, Step {shift['prev_step']}) "
- f"to object {shift['new_obj']} (Iteration {shift['current_iteration']},\
- Step {shift['current_step']})\n"
- )
-
-
-def plot_loss_with_switches(log_file_path, output_file_path, large_switch_threshold):
- # Initialize lists to store data
- iterations = []
- losses = []
- large_switch_iterations = []
- switch_counts = {}
-
- # Regular expressions to match lines in the log
- loss_pattern = re.compile(r"Loss: ([\d\.]+)")
- iteration_pattern = re.compile(r"Finished backprop iter (\d+)")
- # For the output file with switches
- switch_iteration_pattern = re.compile(r"Iteration (\d+): number of shifts = (\d+)")
-
- # Parse the training log file
- with open(log_file_path, "r") as log_file:
- current_iteration = None
- for line in log_file:
- # Check for iteration number
- iteration_match = iteration_pattern.search(line)
- if iteration_match:
- current_iteration = int(iteration_match.group(1))
- iterations.append(current_iteration)
- continue # Move to the next line
-
- # Check for loss value
- loss_match = loss_pattern.search(line)
- if loss_match and current_iteration is not None:
- loss = float(loss_match.group(1)[:-1])
- losses.append(loss)
- continue # Move to the next line
-
- # Parse the output file to get switch information
- with open(output_file_path, "r") as output_file:
- for line in output_file:
- # Check for switch iteration
- switch_iter_match = switch_iteration_pattern.match(line)
- if switch_iter_match:
- iteration = int(switch_iter_match.group(1))
- num_shifts = int(switch_iter_match.group(2))
- # Record iterations with shifts exceeding the threshold
- if num_shifts >= large_switch_threshold:
- large_switch_iterations.append(iteration)
- switch_counts[iteration] = num_shifts
-
- # Ensure the lists are aligned
- iterations = iterations[: len(losses)]
-
- # Plotting the loss over iterations
- plt.figure(figsize=(12, 6))
- plt.plot(iterations, losses, label="Training Loss", color="blue")
-
- # Add markers for iterations with large switches
- for switch_iter in large_switch_iterations:
- if switch_iter in iterations:
- idx = iterations.index(switch_iter)
- plt.axvline(x=switch_iter, color="red", linestyle="--", alpha=0.5)
- # Optionally, add a text annotation for the number of shifts
- plt.text(
- switch_iter,
- losses[idx],
- f"{switch_counts[switch_iter]} shifts",
- rotation=90,
- va="bottom",
- ha="center",
- color="red",
- fontsize=8,
- )
-
- plt.title("Training Loss over Iterations with Large Batch Index Switches")
- plt.xlabel("Iteration")
- plt.ylabel("Loss")
- plt.legend()
- plt.grid(True)
- plt.show()
-
-
-parse_and_plot_run_log(
- "/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples/agent_examples/begging_full/run.log"
-)
-# analyse_log_file('/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples/agent_examples
-# /test/run.log')
-# plot_loss_with_switches('/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples
-# /agent_examples/test/run.log',
-# '/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples/agent_examples/
-# test/output.txt', 50)
diff --git a/examples/agent_examples/whittington_2020_debug.py b/examples/agent_examples/whittington_2020_debug.py
deleted file mode 100644
index e63bbde..0000000
--- a/examples/agent_examples/whittington_2020_debug.py
+++ /dev/null
@@ -1,56 +0,0 @@
-import importlib
-import os
-
-import numpy as np
-import pandas as pd
-
-from neuralplayground.plotting import PlotSim
-
-# simulation_id = "examples/agent_examples/TEM_test_with_break"
-simulation_id = "TEM_test_witch_break"
-save_path = simulation_id + "/"
-plotting_loop_params = {"n_walk": 200}
-
-training_dict = pd.read_pickle(os.path.join(os.getcwd(), save_path, "params.dict"))
-model_weights = pd.read_pickle(os.path.join(save_path, "agent"))
-model_spec = importlib.util.spec_from_file_location("model", save_path + "whittington_2020_model.py")
-model = importlib.util.module_from_spec(model_spec)
-model_spec.loader.exec_module(model)
-params = pd.read_pickle(os.path.join(save_path, "agent_hyper"))
-tem = model.Model(params)
-tem.load_state_dict(model_weights)
-
-sim = PlotSim(
- simulation_id=simulation_id,
- agent_class=training_dict["agent_class"],
- agent_params=training_dict["agent_params"],
- env_class=training_dict["env_class"],
- env_params=training_dict["env_params"],
- plotting_loop_params=plotting_loop_params,
-)
-
-trained_agent, trained_env = sim.plot_sim(save_path, random_state=False, custom_state=[0.0, 0.0])
-# trained_env.plot_trajectories();
-
-max_steps_per_env = np.random.randint(4000, 5000, size=params["batch_size"])
-current_steps = np.zeros(params["batch_size"], dtype=int)
-
-obs, state = trained_env.reset(random_state=False, custom_state=[0.0, 0.0])
-for i in range(200):
- while trained_agent.n_walk < params["n_rollout"]:
- actions = trained_agent.batch_act(obs)
- obs, state, reward = trained_env.step(actions, normalize_step=True)
- trained_agent.update()
-
- current_steps += params["n_rollout"]
- finished_walks = current_steps >= max_steps_per_env
- if any(finished_walks):
- for env_i in np.where(finished_walks)[0]:
- trained_env.reset_env(env_i)
- trained_agent.prev_iter[0].a[env_i] = None
-
- max_steps_per_env[env_i] = params["n_rollout"] * np.random.randint(
- trained_agent.walk_length_center - params["walk_it_window"] * 0.5,
- trained_agent.walk_length_center + params["walk_it_window"] * 0.5,
- )
- current_steps[env_i] = 0
diff --git a/examples/agent_examples/whittington_2020_plot.py b/examples/agent_examples/whittington_2020_plot.py
deleted file mode 100644
index 4c1c56f..0000000
--- a/examples/agent_examples/whittington_2020_plot.py
+++ /dev/null
@@ -1,183 +0,0 @@
-# Standard Imports
-import importlib.util
-import pickle
-
-import matplotlib.pyplot as plt
-import numpy as np
-import torch
-
-import neuralplayground.agents.whittington_2020_extras.whittington_2020_analyse as analyse
-from neuralplayground.agents.whittington_2020 import Whittington2020
-from neuralplayground.arenas.batch_environment import BatchEnvironment
-
-# NeuralPlayground Imports
-from neuralplayground.arenas.discritized_objects import DiscreteObjectEnvironment
-
-# NeuralPlayground Experiment Imports
-from neuralplayground.experiments import Sargolini2006Data
-
-# Select trained model
-date = "2023-08-07"
-run = "3"
-index = "19999"
-save_path = "NeuralPlayground/examples/"
-# Load the model: use import library to import module from specified path
-model_spec = importlib.util.spec_from_file_location(
- "model", save_path + "/Summaries2/" + date + "/torch_run" + run + "/script/whittington_2020_model.py"
-)
-model = importlib.util.module_from_spec(model_spec)
-model_spec.loader.exec_module(model)
-
-# Load the parameters of the model
-params = torch.load(save_path + "/Summaries2/" + date + "/torch_run" + run + "/model/params_" + index + ".pt")
-# Create a new tem model with the loaded parameters
-tem = model.Model(params)
-# Load the model weights after training
-model_weights = torch.load(save_path + "/Summaries2/" + date + "/torch_run" + run + "/model/tem_" + index + ".pt")
-# Set the model weights to the loaded trained model weights
-tem.load_state_dict(model_weights)
-# Make sure model is in evaluate mode (not crucial because it doesn't currently use dropout or batchnorm layers)
-tem.eval()
-
-# Initialise environment parameters
-batch_size = 16
-arena_x_limits = [
- [-5, 5],
- [-4, 4],
- [-5, 5],
- [-6, 6],
- [-4, 4],
- [-5, 5],
- [-6, 6],
- [-5, 5],
- [-4, 4],
- [-5, 5],
- [-6, 6],
- [-5, 5],
- [-4, 4],
- [-5, 5],
- [-6, 6],
- [-5, 5],
-]
-arena_y_limits = [
- [-5, 5],
- [-4, 4],
- [-5, 5],
- [-6, 6],
- [-4, 4],
- [-5, 5],
- [-6, 6],
- [-5, 5],
- [-4, 4],
- [-5, 5],
- [-6, 6],
- [-5, 5],
- [-4, 4],
- [-5, 5],
- [-6, 6],
- [-5, 5],
-]
-# arena_x_limits = [[-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10],
-# [-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10]]
-# arena_y_limits = [[-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1],
-# [-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1]]
-env_name = "env_example"
-mod_name = "SimpleTEM"
-time_step_size = 1
-state_density = 1
-agent_step_size = 1 / state_density
-n_objects = 45
-
-# Init simple 2D environment with discrtised objects
-env_class = DiscreteObjectEnvironment
-env = BatchEnvironment(
- environment_name=env_name,
- env_class=DiscreteObjectEnvironment,
- batch_size=batch_size,
- arena_x_limits=arena_x_limits,
- arena_y_limits=arena_y_limits,
- state_density=state_density,
- n_objects=n_objects,
- agent_step_size=agent_step_size,
- use_behavioural_data=False,
- data_path=None,
- experiment_class=Sargolini2006Data,
-)
-
-# Init TEM agent
-agent = Whittington2020(
- model_name=mod_name,
- params=params,
- batch_size=batch_size,
- room_widths=env.room_widths,
- room_depths=env.room_depths,
- state_densities=env.state_densities,
- use_behavioural_data=False,
-)
-
-# # Run around environment
-# observation, state = env.reset(random_state=True, custom_state=None)
-# while agent.n_walk < 5000:
-# if agent.n_walk % 100 == 0:
-# print(agent.n_walk)
-# action = agent.batch_act(observation)
-# observation, state = env.step(action, normalize_step=True)
-# model_input, history, environments = agent.collect_final_trajectory()
-# environments = [env.collect_environment_info(model_input, history, environments)]
-
-# # Save environments and model_input using pickle
-# with open('NPG_environments.pkl', 'wb') as f:
-# pickle.dump(environments, f)
-# with open('NPG_model_input.pkl', 'wb') as f:
-# pickle.dump(model_input, f)
-
-# Load environments and model_input using pickle
-with open("NPG_environments.pkl", "rb") as f:
- environments = pickle.load(f)
-with open("NPG_model_input.pkl", "rb") as f:
- model_input = pickle.load(f)
-
-with torch.no_grad():
- forward = tem(model_input, prev_iter=None)
-include_stay_still = False
-shiny_envs = [False, False, False, False]
-env_to_plot = 0
-envs_to_avg = shiny_envs if shiny_envs[env_to_plot] else [not shiny_env for shiny_env in shiny_envs]
-
-correct_model, correct_node, correct_edge = analyse.compare_to_agents(
- forward, tem, environments, include_stay_still=include_stay_still
-)
-zero_shot = analyse.zero_shot(forward, tem, environments, include_stay_still=include_stay_still)
-occupation = analyse.location_occupation(forward, tem, environments)
-g, p = analyse.rate_map(forward, tem, environments)
-from_acc, to_acc = analyse.location_accuracy(forward, tem, environments)
-
-# Plot rate maps for grid or place cells
-agent.plot_rate_map(g)
-
-# Plot results of agent comparison and zero-shot inference analysis
-filt_size = 41
-plt.figure()
-plt.plot(
- analyse.smooth(
- np.mean(np.array([env for env_i, env in enumerate(correct_model) if envs_to_avg[env_i]]), 0)[1:], filt_size
- ),
- label="tem",
-)
-plt.plot(
- analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_node) if envs_to_avg[env_i]]), 0)[1:], filt_size),
- label="node",
-)
-plt.plot(
- analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_edge) if envs_to_avg[env_i]]), 0)[1:], filt_size),
- label="edge",
-)
-plt.ylim(0, 1)
-plt.legend()
-plt.title(
- "Zero-shot inference: "
- + str(np.mean([np.mean(env) for env_i, env in enumerate(zero_shot) if envs_to_avg[env_i]]) * 100)
- + "%"
-)
-
-# plt.show()
diff --git a/examples/agent_examples/whittington_2020_plot_test.py b/examples/agent_examples/whittington_2020_plot_test.py
deleted file mode 100644
index 28d3a5c..0000000
--- a/examples/agent_examples/whittington_2020_plot_test.py
+++ /dev/null
@@ -1,47 +0,0 @@
-import importlib
-import os
-import pickle
-
-import matplotlib.pyplot as plt
-import pandas as pd
-
-from neuralplayground.plotting import PlotSim
-
-simulation_id = "TEM_custom_plot_sim"
-save_path = os.path.join(os.getcwd(), "examples", "agent_examples", "results_sim")
-training_dict = pd.read_pickle(os.path.join(save_path, "params.dict"))
-model_weights = pd.read_pickle(os.path.join(save_path, "agent"))
-model_spec = importlib.util.spec_from_file_location("model", save_path + "/whittington_2020_model.py")
-model = importlib.util.module_from_spec(model_spec)
-model_spec.loader.exec_module(model)
-params = pd.read_pickle(os.path.join(save_path, "agent_hyper"))
-tem = model.Model(params)
-tem.load_state_dict(model_weights)
-tem.eval()
-
-plotting_loop_params = {"n_episode": 50}
-sim = PlotSim(
- simulation_id=simulation_id,
- agent_class=training_dict["agent_class"],
- agent_params=training_dict["agent_params"],
- env_class=training_dict["env_class"],
- env_params=training_dict["env_params"],
- plotting_loop_params=plotting_loop_params,
-)
-print(sim)
-sim.plot_sim(save_path)
-
-# Load environments and model_input using pickle
-with open(os.path.join(save_path, "NPG_environments.pkl"), "rb") as f:
- environments = pickle.load(f)
-with open(os.path.join(save_path, "NPG_model_input.pkl"), "rb") as f:
- model_input = pickle.load(f)
-
-training_dict["params"] = training_dict["agent_params"]
-del training_dict["agent_params"]
-agent = training_dict["agent_class"](**training_dict["params"])
-agent.plot_run(tem, model_input, environments)
-
-# Plot rate maps for grid or place cells
-agent.plot_rate_map(rate_map_type="g")
-plt.show()
diff --git a/examples/agent_examples/whittington_slurm.sh b/examples/agent_examples/whittington_slurm.sh
deleted file mode 100644
index 55415c1..0000000
--- a/examples/agent_examples/whittington_slurm.sh
+++ /dev/null
@@ -1,19 +0,0 @@
-#!/bin/bash
-
-#SBATCH -J TEM_beg # job name
-#SBATCH -p gpu # partition (queue)
-#SBATCH -N 1 # number of nodes
-#SBATCH --mem 50G # memory pool for all cores
-#SBATCH -n 4 # number of cores
-#SBATCH -t 0-72:00 # time (D-HH:MM)
-#SBATCH --gres gpu:1 # request 1 GPU (of any kind)
-#SBATCH -o TEM_beg.%x.%N.%j.out # STDOUT
-#SBATCH -e TEM_beg.%x.%N.%j.err # STDERR
-
-source ~/.bashrc
-
-conda activate NPG-env
-
-python whittington_2020_run.py
-
-exit
diff --git a/neuralplayground/arenas/batch_environment.py b/neuralplayground/arenas/batch_environment.py
index 81d6262..897e1d4 100644
--- a/neuralplayground/arenas/batch_environment.py
+++ b/neuralplayground/arenas/batch_environment.py
@@ -7,26 +7,25 @@
class BatchEnvironment(Environment):
+ """
+ Class to handle a batch of environments, where each environment is an instance of the same class. This is useful for training a single agent on multiple environments simultaneously.
+ ----------
+ environment_name: str
+ Name of the environment
+ env_class: object
+ Class of the environment
+ batch_size: int
+ Number of environments in the batch
+ **env_kwargs: dict
+ Keyword arguments for the environment
+ """
def __init__(
self,
environment_name: str = "BatchEnv",
env_class: object = DiscreteObjectEnvironment,
batch_size: int = 1,
**env_kwargs,
- ):
- """
- Initialise a batch of environments. This is useful for training a single agent on multiple environments simultaneously.
- Parameters
- ----------
- environment_name: str
- Name of the environment
- env_class: object
- Class of the environment
- batch_size: int
- Number of environments in the batch
- **env_kwargs: dict
- Keyword arguments for the environment
- """
+ ):
super().__init__(environment_name, **env_kwargs)
self.env_kwargs = env_kwargs.copy()
arg_env_params = env_kwargs["arg_env_params"]