Skip to content

Commit

Permalink
Whittington 2020 (#70)
Browse files Browse the repository at this point in the history
* new TEM run file

* TEM saving variables and running in varied environments

* example plots added to notebook

* tests updated

* only model parameters

* variable saving added

* updated run and notebook files

* update model files

* updated arena files

* Delete whittington_2020_examples.ipynb

* Delete whittington_2020_test.py

* label readme

* includes torch dependencies

* includes torch

* update environment classes

* updated agent functions

* updated training and plotting scripts

* requirements updated

* updated plotting notebook

* added pytorch dependency

* torch=1.12.1 not found

* added pytorch dependency

* added pytorch dependency

* TEM tests added

* improved plotting

* summary files for example TEM results

* example results for plotting

* default parameters

* Update whittington_2020_example.ipynb

* beg

* environment duplication bug fix

* config classes done

* zero shot accuracy fixed

* tensorboard added to requirements and model saving added

* state (x,y) causing grid plotting issues

* bug in position rounding and centering found

* position centering and rounding added

* plotting error in envrionment variable

* plotting finctions working - NPG data added under /torch_run3

* gridscore metric

* add metric

* modified agent and exp for the metric

* updated metric- allows for 2 D

* Creado mediante Colaboratory

* Creado mediante Colaboratory

* colab example

* fixing colab

* Creado mediante Colaboratory

* colab on readme

* open in colab

* fixing path to colab

* setting colab env

* colab

* 2023-05-17 contains both NPG and original models

* updating colab installation

* colab example skeleton done, need markdowns and explanation

* adding use of behavioural data to batched environment

* adding option to use behavioural data

* colab example from main

* adding behavioural trajectory

* config file running properly

* adding behavioural trajectories

* commenting config module, need to config other modules

* data path setup for behavioural trajectory use

* behaviorual traj

* TEM running on behavioural data

* added robust data path creation and access

* fixing action-transition discrepencies

* fixing action-transition discrepency

* action generation added

* agent generation added

* config file comment

* action-transition discrepency fixed

* backend for automatic simulation

* single sim manager running properly

* generating dir when runing sim

* generalised plotting function

* wernle_2018

* forgot the pre-commit

* generalised rate map for experiments

* generalised rate plotting

* make plotting a module

* update metric

* testing sim manager

* metric that works + get rid of developement plot+ same fontsize

* pre commit

* saving entire object with pickle

* config input

* dict to json

* saving params as dict

* comparison from the run manadger

* status checker and load simulation

* status checker and load simulation

* merge corrections

* fix hafting

* get ratemap matrix done for all agents, documentation of simulation manager done

* merging room

* weber plot_rates.py function

* random action policy working in updated branch

* plot rates

* nice plotting

* update rates function + test comparison figure on other agents

* comparison sargo

* update

* fixing TEM plotting

* wenrnel tetrode, get_grid score, title figure , table figure and config update

* environment variables fixed

* new saved TEM models added

* update the jupyter notebooks

* grid scorere

* New plotting Functions

* Really cool jupyter

* update

* width

* score

* the score

* juypter

* jupyter

* simulation manager notebook

* simulation manager notebook

* back to the previous setting

* Update metrics.py

Changed gridscore calc from using np.min and np.max to np.nanmin and np.nanmax

* standardised plotting of TEM results added

* passing test

* backend run example

* simulation manager example almost done

* changes

* cleaning up plotting code

* simulation manager example done

* plot_utils

* merge changes from comparison_board

* removed additional summaries

* commenting added to new files

* pre-commit changes made

* for test to pass

* fix manifest

---------

Co-authored-by: LukeHollingsworth <[email protected]>
Co-authored-by: Luke Hollingsworth <[email protected]>
Co-authored-by: rodrigcd <[email protected]>
Co-authored-by: Luke Hollingsworth <[email protected]>
Co-authored-by: rodrigcd <[email protected]>
Co-authored-by: rhayman <[email protected]>
  • Loading branch information
7 people authored Jul 28, 2023
1 parent 2234c7d commit 159ccbd
Show file tree
Hide file tree
Showing 12 changed files with 4,235 additions and 2 deletions.
305 changes: 305 additions & 0 deletions examples/agent_examples/whittington_2020_example.ipynb

Large diffs are not rendered by default.

186 changes: 186 additions & 0 deletions examples/agent_examples/whittington_2020_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Standard Imports
import importlib.util
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch

import neuralplayground.agents.whittington_2020_extras.whittington_2020_analyse as analyse
from neuralplayground.agents.whittington_2020 import Whittington2020
from neuralplayground.arenas.batch_environment import BatchEnvironment

# NeuralPlayground Imports
from neuralplayground.arenas.discritized_objects import DiscreteObjectEnvironment

# NeuralPlayground Experiment Imports
from neuralplayground.experiments import Sargolini2006Data

# Select trained model
date = "2023-05-17"
run = "0"
index = "19999"
base_path = "/nfs/nhome/live/lhollingsworth/Documents/NeuralPlayground/NPG/EHC_model_comparison"
npg_path = "/nfs/nhome/live/lhollingsworth/Documents/NeuralPlayground/NPG/EHC_model_comparison/examples"
base_win_path = "H:/Documents/PhD/NeuralPlayground"
win_path = "H:/Documents/PhD/NeuralPlayground/NPG/NeuralPlayground/examples"
# Load the model: use import library to import module from specified path
model_spec = importlib.util.spec_from_file_location(
"model", win_path + "/Summaries/" + date + "/torch_run" + run + "/script/whittington_2020_model.py"
)
model = importlib.util.module_from_spec(model_spec)
model_spec.loader.exec_module(model)

# Load the parameters of the model
params = torch.load(win_path + "/Summaries/" + date + "/torch_run" + run + "/model/params_" + index + ".pt")
# Create a new tem model with the loaded parameters
tem = model.Model(params)
# Load the model weights after training
model_weights = torch.load(win_path + "/Summaries/" + date + "/torch_run" + run + "/model/tem_" + index + ".pt")
# Set the model weights to the loaded trained model weights
tem.load_state_dict(model_weights)
# Make sure model is in evaluate mode (not crucial because it doesn't currently use dropout or batchnorm layers)
tem.eval()

# Initialise environment parameters
batch_size = 16
arena_x_limits = [
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
]
arena_y_limits = [
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
]
# arena_x_limits = [[-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10],
# [-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10]]
# arena_y_limits = [[-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1],
# [-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1]]
env_name = "env_example"
mod_name = "SimpleTEM"
time_step_size = 1
state_density = 1
agent_step_size = 1 / state_density
n_objects = 45

# Init simple 2D environment with discrtised objects
env_class = DiscreteObjectEnvironment
env = BatchEnvironment(
environment_name=env_name,
env_class=DiscreteObjectEnvironment,
batch_size=batch_size,
arena_x_limits=arena_x_limits,
arena_y_limits=arena_y_limits,
state_density=state_density,
n_objects=n_objects,
agent_step_size=agent_step_size,
use_behavioural_data=False,
data_path=None,
experiment_class=Sargolini2006Data,
)

# Init TEM agent
agent = Whittington2020(
model_name=mod_name,
params=params,
batch_size=batch_size,
room_widths=env.room_widths,
room_depths=env.room_depths,
state_densities=env.state_densities,
use_behavioural_data=False,
)

# # Run around environment
# observation, state = env.reset(random_state=True, custom_state=None)
# while agent.n_walk < 5000:
# if agent.n_walk % 100 == 0:
# print(agent.n_walk)
# action = agent.batch_act(observation)
# observation, state = env.step(action, normalize_step=True)
# model_input, history, environments = agent.collect_final_trajectory()
# environments = [env.collect_environment_info(model_input, history, environments)]

# # Save environments and model_input using pickle
# with open('NPG_environments.pkl', 'wb') as f:
# pickle.dump(environments, f)
# with open('NPG_model_input.pkl', 'wb') as f:
# pickle.dump(model_input, f)

# Load environments and model_input using pickle
with open("NPG_environments.pkl", "rb") as f:
environments = pickle.load(f)
with open("NPG_model_input.pkl", "rb") as f:
model_input = pickle.load(f)

with torch.no_grad():
forward = tem(model_input, prev_iter=None)
include_stay_still = False
shiny_envs = [False, False, False, False]
env_to_plot = 0
envs_to_avg = shiny_envs if shiny_envs[env_to_plot] else [not shiny_env for shiny_env in shiny_envs]

correct_model, correct_node, correct_edge = analyse.compare_to_agents(
forward, tem, environments, include_stay_still=include_stay_still
)
zero_shot = analyse.zero_shot(forward, tem, environments, include_stay_still=include_stay_still)
occupation = analyse.location_occupation(forward, tem, environments)
g, p = analyse.rate_map(forward, tem, environments)
from_acc, to_acc = analyse.location_accuracy(forward, tem, environments)

# Plot rate maps for grid or place cells
agent.plot_rate_map(g)

# Plot results of agent comparison and zero-shot inference analysis
filt_size = 41
plt.figure()
plt.plot(
analyse.smooth(
np.mean(np.array([env for env_i, env in enumerate(correct_model) if envs_to_avg[env_i]]), 0)[1:], filt_size
),
label="tem",
)
plt.plot(
analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_node) if envs_to_avg[env_i]]), 0)[1:], filt_size),
label="node",
)
plt.plot(
analyse.smooth(np.mean(np.array([env for env_i, env in enumerate(correct_edge) if envs_to_avg[env_i]]), 0)[1:], filt_size),
label="edge",
)
plt.ylim(0, 1)
plt.legend()
plt.title(
"Zero-shot inference: "
+ str(np.mean([np.mean(env) for env_i, env in enumerate(zero_shot) if envs_to_avg[env_i]]) * 100)
+ "%"
)

# plt.show()
124 changes: 124 additions & 0 deletions examples/agent_examples/whittington_2020_run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
Run file for the Tolman-Eichenbaum Machine (TEM) model from Whittington et al. 2020. An example setup is provided, with
TEM learning to predict upcoming sensory stimulus in a range of 16 square environments of varying sizes.
"""

# Standard Imports

import matplotlib.pyplot as plt

# NeuralPlayground Agent Imports
import neuralplayground.agents.whittington_2020_extras.whittington_2020_parameters as parameters
from neuralplayground.agents.whittington_2020 import Whittington2020
from neuralplayground.arenas.batch_environment import BatchEnvironment

# NeuralPlayground Arena Imports
from neuralplayground.arenas.discritized_objects import DiscreteObjectEnvironment

# NeuralPlayground Experiment Imports
from neuralplayground.experiments import Sargolini2006Data

# Initialise TEM Parameters
pars_orig = parameters.parameters()
params = pars_orig.copy()

# Initialise environment parameters
batch_size = 16
arena_x_limits = [
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
]
arena_y_limits = [
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
[-4, 4],
[-5, 5],
[-6, 6],
[-5, 5],
]
# arena_x_limits = [[-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10],
# [-20,20], [-20,20], [-15,15], [-10,10], [-20,20], [-20,20], [-15,15], [-10,10]]
# arena_y_limits = [[-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1],
# [-4,4], [-2,2], [-2,2], [-1,1], [-4,4], [-2,2], [-2,2], [-1,1]]
env_name = "Sargolini2006"
mod_name = "SimpleTEM"
time_step_size = 1
state_density = 1
agent_step_size = 1 / state_density
n_objects = 45

# # Init environment from Hafting 2008 (optional, if chosen, comment out the )
# env = Hafting2008(agent_step_size=agent_step_size,
# time_step_size=time_step_size,
# use_behavioral_data=False)

# # Init simple 2D (batched) environment with discrtised objects
# env_class = DiscreteObjectEnvironment

# Init environment from Sargolini, with behavioural data instead of random walk
env = BatchEnvironment(
environment_name=env_name,
env_class=DiscreteObjectEnvironment,
batch_size=batch_size,
arena_x_limits=arena_x_limits,
arena_y_limits=arena_y_limits,
state_density=state_density,
n_objects=n_objects,
agent_step_size=agent_step_size,
use_behavioural_data=False,
data_path=None,
experiment_class=Sargolini2006Data,
)

# Init TEM agent
agent = Whittington2020(
model_name=mod_name,
params=params,
batch_size=batch_size,
room_widths=env.room_widths,
room_depths=env.room_depths,
state_densities=env.state_densities,
use_behavioural_data=False,
)

# Reset environment and begin training (random_state=True is currently necessary)
observation, state = env.reset(random_state=True, custom_state=None)
for i in range(3):
print("Iteration: ", i)
while agent.n_walk < params["n_rollout"]:
actions = agent.batch_act(observation)
observation, state = env.step(actions, normalize_step=True)
agent.update()

# Plot most recent trajectory of the first environment in batch
ax = env.plot_trajectory()
fontsize = 18
ax.grid()
ax.set_xlabel("width", fontsize=fontsize)
ax.set_ylabel("depth", fontsize=fontsize)
plt.savefig("trajectory.png")
plt.show()
Loading

0 comments on commit 159ccbd

Please sign in to comment.