diff --git a/examples/agent_examples/plotting_metrics.py b/examples/agent_examples/plotting_metrics.py deleted file mode 100644 index 98b7144..0000000 --- a/examples/agent_examples/plotting_metrics.py +++ /dev/null @@ -1,247 +0,0 @@ -import re -from collections import defaultdict - -import matplotlib.pyplot as plt - - -def parse_and_plot_run_log(file_path): - iterations = [] - losses = [] - accuracies_p = [] - accuracies_g = [] - accuracies_gt = [] - new_walks = [] - - iter_pattern = r"Finished backprop iter (\d+)" - loss_pattern = r"Loss: ([\d.]+)\." # Note the added \. to catch the trailing period - accuracy_pattern = r"Accuracy:

([\d.]+)% ([\d.]+)% ([\d.]+)%" - new_walk_pattern = r"Iteration (\d+): new walk" - - with open(file_path, "r") as file: - for line in file: - iter_match = re.search(iter_pattern, line) - if iter_match: - iterations.append(int(iter_match.group(1))) - - loss_match = re.search(loss_pattern, line) - if loss_match: - losses.append(float(loss_match.group(1))) - - accuracy_match = re.search(accuracy_pattern, line) - if accuracy_match: - accuracies_p.append(float(accuracy_match.group(1))) - accuracies_g.append(float(accuracy_match.group(2))) - accuracies_gt.append(float(accuracy_match.group(3))) - - new_walk_match = re.search(new_walk_pattern, line) - if new_walk_match: - new_walks.append(int(new_walk_match.group(1))) - - fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 12)) - - ax1.plot(iterations, losses) - ax1.set_xlabel("Iteration") - ax1.set_ylabel("Loss") - ax1.set_title("Loss over Iterations") - - ax2.plot(iterations, accuracies_p, label="p accuracy") - ax2.plot(iterations, accuracies_g, label="g accuracy") - ax2.plot(iterations, accuracies_gt, label="gt accuracy") - ax2.set_xlabel("Iteration") - ax2.set_ylabel("Accuracy (%)") - ax2.set_title("Accuracies over Iterations") - ax2.legend() - - # Add vertical lines for new walks - # for walk in new_walks: - # ax1.axvline(x=walk, color='r', linestyle='--', alpha=0.5) - # ax2.axvline(x=walk, color='r', linestyle='--', alpha=0.5) - - plt.tight_layout() - plt.show() - - -def analyse_log_file(file): - # Regular expressions to match IDs and Objs lines - id_pattern = re.compile(r"IDs: \[([^\]]+)\]") - obj_pattern = re.compile(r"Objs: \[([^\]]+)\]") - iter_pattern = re.compile(r"Finished backprop iter (\d+)") - step_pattern = re.compile(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}:") - - # Initialize data structures - iteration_data = {} - current_iteration = None - id_to_obj_previous = {} - current_step = 0 - - with open(file, "r") as file: - for line in file: - # Check for iteration number - iter_match = 
iter_pattern.search(line) - if iter_match: - current_iteration = int(iter_match.group(1)) - current_step = 0 # Reset step counter for new iteration - continue # Proceed to next line - - # Check for step line (assumed to start with timestamp) - if step_pattern.match(line): - current_step += 1 - - # Extract IDs - id_match = id_pattern.search(line) - if id_match: - ids = list(map(int, id_match.group(1).split(","))) - continue # IDs are followed by Objs, proceed to next line - - # Extract Objs - obj_match = obj_pattern.search(line) - if obj_match: - objs = list(map(int, obj_match.group(1).split(","))) - - # Ensure current_iteration is set - if current_iteration is None: - continue # Skip if iteration is not identified yet - - # Store IDs and Objs for this iteration and step - if current_iteration not in iteration_data: - iteration_data[current_iteration] = [] - iteration_data[current_iteration].append((current_step, ids, objs)) - - # Now, process the data to find shifts with detailed information - shifts = defaultdict(list) # Key: iteration, Value: list of shift details - id_to_obj_current = {} - - sorted_iterations = sorted(iteration_data.keys()) - - for idx, iteration in enumerate(sorted_iterations): - steps = iteration_data[iteration] - # For each step in the iteration - for step in steps: - step_num, ids, objs = step - # For each ID in the batch - for batch_idx, (id_, obj) in enumerate(zip(ids, objs)): - key = (batch_idx, id_) # Identify by batch index and ID - if key in id_to_obj_previous: - prev_info = id_to_obj_previous[key] - prev_obj = prev_info["obj"] - if obj != prev_obj: - # Environment has changed for this batch member - shifts[iteration].append( - { - "batch_idx": batch_idx, - "id": id_, - "prev_obj": prev_obj, - "new_obj": obj, - "prev_iteration": prev_info["iteration"], - "prev_step": prev_info["step"], - "current_iteration": iteration, - "current_step": step_num, - } - ) - # Update current mapping - id_to_obj_current[key] = {"obj": obj, "iteration": 
iteration, "step": step_num} - # After processing all steps in the iteration, update previous mapping - id_to_obj_previous = id_to_obj_current.copy() - id_to_obj_current.clear() - - # Output the iterations where shifts occurred with detailed information - print("Environment shifts detected with detailed information:") - with open("shifts_output.txt", "w") as output_file: - for iteration in sorted(shifts.keys()): - shift_list = shifts[iteration] - if shift_list: - output_file.write(f"\nIteration {iteration}: number of shifts = {len(shift_list)}\n") - for shift in shift_list: - output_file.write( - f" Batch index {shift['batch_idx']}, ID {shift['id']} changed from " - f"object {shift['prev_obj']} (Iteration {shift['prev_iteration']}, Step {shift['prev_step']}) " - f"to object {shift['new_obj']} (Iteration {shift['current_iteration']},\ - Step {shift['current_step']})\n" - ) - - -def plot_loss_with_switches(log_file_path, output_file_path, large_switch_threshold): - # Initialize lists to store data - iterations = [] - losses = [] - large_switch_iterations = [] - switch_counts = {} - - # Regular expressions to match lines in the log - loss_pattern = re.compile(r"Loss: ([\d\.]+)") - iteration_pattern = re.compile(r"Finished backprop iter (\d+)") - # For the output file with switches - switch_iteration_pattern = re.compile(r"Iteration (\d+): number of shifts = (\d+)") - - # Parse the training log file - with open(log_file_path, "r") as log_file: - current_iteration = None - for line in log_file: - # Check for iteration number - iteration_match = iteration_pattern.search(line) - if iteration_match: - current_iteration = int(iteration_match.group(1)) - iterations.append(current_iteration) - continue # Move to the next line - - # Check for loss value - loss_match = loss_pattern.search(line) - if loss_match and current_iteration is not None: - loss = float(loss_match.group(1)[:-1]) - losses.append(loss) - continue # Move to the next line - - # Parse the output file to get 
switch information - with open(output_file_path, "r") as output_file: - for line in output_file: - # Check for switch iteration - switch_iter_match = switch_iteration_pattern.match(line) - if switch_iter_match: - iteration = int(switch_iter_match.group(1)) - num_shifts = int(switch_iter_match.group(2)) - # Record iterations with shifts exceeding the threshold - if num_shifts >= large_switch_threshold: - large_switch_iterations.append(iteration) - switch_counts[iteration] = num_shifts - - # Ensure the lists are aligned - iterations = iterations[: len(losses)] - - # Plotting the loss over iterations - plt.figure(figsize=(12, 6)) - plt.plot(iterations, losses, label="Training Loss", color="blue") - - # Add markers for iterations with large switches - for switch_iter in large_switch_iterations: - if switch_iter in iterations: - idx = iterations.index(switch_iter) - plt.axvline(x=switch_iter, color="red", linestyle="--", alpha=0.5) - # Optionally, add a text annotation for the number of shifts - plt.text( - switch_iter, - losses[idx], - f"{switch_counts[switch_iter]} shifts", - rotation=90, - va="bottom", - ha="center", - color="red", - fontsize=8, - ) - - plt.title("Training Loss over Iterations with Large Batch Index Switches") - plt.xlabel("Iteration") - plt.ylabel("Loss") - plt.legend() - plt.grid(True) - plt.show() - - -parse_and_plot_run_log( - "/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples/agent_examples/begging_full/run.log" -) -# analyse_log_file('/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples/agent_examples -# /test/run.log') -# plot_loss_with_switches('/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples -# /agent_examples/test/run.log', -# '/Users/lukehollingsworth/Documents/PhD/SaxeLab/NeuralPlayground/NeuralPlayground/examples/agent_examples/ -# test/output.txt', 50) diff --git a/examples/agent_examples/whittington_2020_debug.py 
b/examples/agent_examples/whittington_2020_debug.py deleted file mode 100644 index e63bbde..0000000 --- a/examples/agent_examples/whittington_2020_debug.py +++ /dev/null @@ -1,56 +0,0 @@ -import importlib -import os - -import numpy as np -import pandas as pd - -from neuralplayground.plotting import PlotSim - -# simulation_id = "examples/agent_examples/TEM_test_with_break" -simulation_id = "TEM_test_witch_break" -save_path = simulation_id + "/" -plotting_loop_params = {"n_walk": 200} - -training_dict = pd.read_pickle(os.path.join(os.getcwd(), save_path, "params.dict")) -model_weights = pd.read_pickle(os.path.join(save_path, "agent")) -model_spec = importlib.util.spec_from_file_location("model", save_path + "whittington_2020_model.py") -model = importlib.util.module_from_spec(model_spec) -model_spec.loader.exec_module(model) -params = pd.read_pickle(os.path.join(save_path, "agent_hyper")) -tem = model.Model(params) -tem.load_state_dict(model_weights) - -sim = PlotSim( - simulation_id=simulation_id, - agent_class=training_dict["agent_class"], - agent_params=training_dict["agent_params"], - env_class=training_dict["env_class"], - env_params=training_dict["env_params"], - plotting_loop_params=plotting_loop_params, -) - -trained_agent, trained_env = sim.plot_sim(save_path, random_state=False, custom_state=[0.0, 0.0]) -# trained_env.plot_trajectories(); - -max_steps_per_env = np.random.randint(4000, 5000, size=params["batch_size"]) -current_steps = np.zeros(params["batch_size"], dtype=int) - -obs, state = trained_env.reset(random_state=False, custom_state=[0.0, 0.0]) -for i in range(200): - while trained_agent.n_walk < params["n_rollout"]: - actions = trained_agent.batch_act(obs) - obs, state, reward = trained_env.step(actions, normalize_step=True) - trained_agent.update() - - current_steps += params["n_rollout"] - finished_walks = current_steps >= max_steps_per_env - if any(finished_walks): - for env_i in np.where(finished_walks)[0]: - trained_env.reset_env(env_i) - 
trained_agent.prev_iter[0].a[env_i] = None - - max_steps_per_env[env_i] = params["n_rollout"] * np.random.randint( - trained_agent.walk_length_center - params["walk_it_window"] * 0.5, - trained_agent.walk_length_center + params["walk_it_window"] * 0.5, - ) - current_steps[env_i] = 0 diff --git a/examples/agent_examples/whittington_2020_plot.py b/examples/agent_examples/whittington_2020_plot.py deleted file mode 100644 index 4c1c56f..0000000 --- a/examples/agent_examples/whittington_2020_plot.py +++ /dev/null @@ -1,183 +0,0 @@ -# Standard Imports -import importlib.util -import pickle - -import matplotlib.pyplot as plt -import numpy as np -import torch - -import neuralplayground.agents.whittington_2020_extras.whittington_2020_analyse as analyse -from neuralplayground.agents.whittington_2020 import Whittington2020 -from neuralplayground.arenas.batch_environment import BatchEnvironment - -# NeuralPlayground Imports -from neuralplayground.arenas.discritized_objects import DiscreteObjectEnvironment - -# NeuralPlayground Experiment Imports -from neuralplayground.experiments import Sargolini2006Data - -# Select trained model -date = "2023-08-07" -run = "3" -index = "19999" -save_path = "NeuralPlayground/examples/" -# Load the model: use import library to import module from specified path -model_spec = importlib.util.spec_from_file_location( - "model", save_path + "/Summaries2/" + date + "/torch_run" + run + "/script/whittington_2020_model.py" -) -model = importlib.util.module_from_spec(model_spec) -model_spec.loader.exec_module(model) - -# Load the parameters of the model -params = torch.load(save_path + "/Summaries2/" + date + "/torch_run" + run + "/model/params_" + index + ".pt") -# Create a new tem model with the loaded parameters -tem = model.Model(params) -# Load the model weights after training -model_weights = torch.load(save_path + "/Summaries2/" + date + "/torch_run" + run + "/model/tem_" + index + ".pt") -# Set the model weights to the loaded trained model 
weights -tem.load_state_dict(model_weights) -# Make sure model is in evaluate mode (not crucial because it doesn't currently use dropout or batchnorm layers) -tem.eval() - -# Initialise environment parameters -batch_size = 16 -arena_x_limits = [ - [-5, 5], - [-4, 4], - [-5, 5], - [-6, 6], - [-4, 4], - [-5, 5], - [-6, 6], - [-5, 5], - [-4, 4], - [-5, 5], - [-6, 6], - [-5, 5], - [-4, 4], - [-5, 5], - [-6, 6], - [-5, 5], -] -arena_y_limits = [ - [-5, 5], - [-4, 4], - [-5, 5], - [-6, 6], - [-4, 4], - [-5, 5], - [-6, 6], - [-5, 5], - [-4, 4], - [-5, 5], - [-6, 6], - [-5, 5], - [-4, 4], - [-5, 5], - [-6, 6], - [-5, 5], -] -# arena_x_limits = [[-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10], -# [-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10]] -# arena_y_limits = [[-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1], -# [-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1]] -env_name = "env_example" -mod_name = "SimpleTEM" -time_step_size = 1 -state_density = 1 -agent_step_size = 1 / state_density -n_objects = 45 - -# Init simple 2D environment with discrtised objects -env_class = DiscreteObjectEnvironment -env = BatchEnvironment( - environment_name=env_name, - env_class=DiscreteObjectEnvironment, - batch_size=batch_size, - arena_x_limits=arena_x_limits, - arena_y_limits=arena_y_limits, - state_density=state_density, - n_objects=n_objects, - agent_step_size=agent_step_size, - use_behavioural_data=False, - data_path=None, - experiment_class=Sargolini2006Data, -) - -# Init TEM agent -agent = Whittington2020( - model_name=mod_name, - params=params, - batch_size=batch_size, - room_widths=env.room_widths, - room_depths=env.room_depths, - state_densities=env.state_densities, - use_behavioural_data=False, -) - -# # Run around environment -# observation, state = env.reset(random_state=True, custom_state=None) -# while agent.n_walk < 5000: -# if agent.n_walk % 100 == 0: -# print(agent.n_walk) -# action 
= agent.batch_act(observation) -# observation, state = env.step(action, normalize_step=True) -# model_input, history, environments = agent.collect_final_trajectory() -# environments = [env.collect_environment_info(model_input, history, environments)] - -# # Save environments and model_input using pickle -# with open('NPG_environments.pkl', 'wb') as f: -# pickle.dump(environments, f) -# with open('NPG_model_input.pkl', 'wb') as f: -# pickle.dump(model_input, f) - -# Load environments and model_input using pickle -with open("NPG_environments.pkl", "rb") as f: - environments = pickle.load(f) -with open("NPG_model_input.pkl", "rb") as f: - model_input = pickle.load(f) - -with torch.no_grad(): - forward = tem(model_input, prev_iter=None) -include_stay_still = False -shiny_envs = [False, False, False, False] -env_to_plot = 0 -envs_to_avg = shiny_envs if shiny_envs[env_to_plot] else [not shiny_env for shiny_env in shiny_envs] - -correct_model, correct_node, correct_edge = analyse.compare_to_agents( - forward, tem, environments, include_stay_still=include_stay_still -) -zero_shot = analyse.zero_shot(forward, tem, environments, include_stay_still=include_stay_still) -occupation = analyse.location_occupation(forward, tem, environments) -g, p = analyse.rate_map(forward, tem, environments) -from_acc, to_acc = analyse.location_accuracy(forward, tem, environments) - -# Plot rate maps for grid or place cells -agent.plot_rate_map(g) - -# Plot results of agent comparison and zero-shot inference analysis -filt_size = 41 -plt.figure() -plt.plot( - analyse.smooth( - np.mean(np.array([env for env_i, env in enumerate(correct_model) if envs_to_avg[env_i]]), 0)[1:], filt_size - ), - label="tem", -) -plt.plot( - analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_node) if envs_to_avg[env_i]]), 0)[1:], filt_size), - label="node", -) -plt.plot( - analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_edge) if envs_to_avg[env_i]]), 0)[1:], filt_size), - 
label="edge", -) -plt.ylim(0, 1) -plt.legend() -plt.title( - "Zero-shot inference: " - + str(np.mean([np.mean(env) for env_i, env in enumerate(zero_shot) if envs_to_avg[env_i]]) * 100) - + "%" -) - -# plt.show() diff --git a/examples/agent_examples/whittington_2020_plot_test.py b/examples/agent_examples/whittington_2020_plot_test.py deleted file mode 100644 index 28d3a5c..0000000 --- a/examples/agent_examples/whittington_2020_plot_test.py +++ /dev/null @@ -1,47 +0,0 @@ -import importlib -import os -import pickle - -import matplotlib.pyplot as plt -import pandas as pd - -from neuralplayground.plotting import PlotSim - -simulation_id = "TEM_custom_plot_sim" -save_path = os.path.join(os.getcwd(), "examples", "agent_examples", "results_sim") -training_dict = pd.read_pickle(os.path.join(save_path, "params.dict")) -model_weights = pd.read_pickle(os.path.join(save_path, "agent")) -model_spec = importlib.util.spec_from_file_location("model", save_path + "/whittington_2020_model.py") -model = importlib.util.module_from_spec(model_spec) -model_spec.loader.exec_module(model) -params = pd.read_pickle(os.path.join(save_path, "agent_hyper")) -tem = model.Model(params) -tem.load_state_dict(model_weights) -tem.eval() - -plotting_loop_params = {"n_episode": 50} -sim = PlotSim( - simulation_id=simulation_id, - agent_class=training_dict["agent_class"], - agent_params=training_dict["agent_params"], - env_class=training_dict["env_class"], - env_params=training_dict["env_params"], - plotting_loop_params=plotting_loop_params, -) -print(sim) -sim.plot_sim(save_path) - -# Load environments and model_input using pickle -with open(os.path.join(save_path, "NPG_environments.pkl"), "rb") as f: - environments = pickle.load(f) -with open(os.path.join(save_path, "NPG_model_input.pkl"), "rb") as f: - model_input = pickle.load(f) - -training_dict["params"] = training_dict["agent_params"] -del training_dict["agent_params"] -agent = training_dict["agent_class"](**training_dict["params"]) 
-agent.plot_run(tem, model_input, environments) - - # Plot rate maps for grid or place cells -agent.plot_rate_map(rate_map_type="g") -plt.show() diff --git a/examples/agent_examples/whittington_slurm.sh b/examples/agent_examples/whittington_slurm.sh deleted file mode 100644 index 55415c1..0000000 --- a/examples/agent_examples/whittington_slurm.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -#SBATCH -J TEM_beg # job name -#SBATCH -p gpu # partition (queue) -#SBATCH -N 1 # number of nodes -#SBATCH --mem 50G # memory pool for all cores -#SBATCH -n 4 # number of cores -#SBATCH -t 0-72:00 # time (D-HH:MM) -#SBATCH --gres gpu:1 # request 1 GPU (of any kind) -#SBATCH -o TEM_beg.%x.%N.%j.out # STDOUT -#SBATCH -e TEM_beg.%x.%N.%j.err # STDERR - -source ~/.bashrc - -conda activate NPG-env - -python whittington_2020_run.py - -exit diff --git a/neuralplayground/arenas/batch_environment.py b/neuralplayground/arenas/batch_environment.py index 81d6262..897e1d4 100644 --- a/neuralplayground/arenas/batch_environment.py +++ b/neuralplayground/arenas/batch_environment.py @@ -7,26 +7,26 @@ class BatchEnvironment(Environment): + """ + Class to handle a batch of environments, where each environment is an instance of the same class. This is useful for training a single agent on multiple environments simultaneously. + Parameters + ---------- + environment_name: str + Name of the environment + env_class: object + Class of the environment + batch_size: int + Number of environments in the batch + **env_kwargs: dict + Keyword arguments for the environment + """ def __init__( self, environment_name: str = "BatchEnv", env_class: object = DiscreteObjectEnvironment, batch_size: int = 1, **env_kwargs, - ): - """ - Initialise a batch of environments. This is useful for training a single agent on multiple environments simultaneously. 
- Parameters - ---------- - environment_name: str - Name of the environment - env_class: object - Class of the environment - batch_size: int - Number of environments in the batch - **env_kwargs: dict - Keyword arguments for the environment - """ + ): super().__init__(environment_name, **env_kwargs) self.env_kwargs = env_kwargs.copy() arg_env_params = env_kwargs["arg_env_params"]