
pre-commit run on all files
LukeHollingsworth committed Jul 16, 2024
1 parent 7de0832 commit 5d33231
Showing 8 changed files with 123 additions and 84 deletions.
2 changes: 1 addition & 1 deletion examples/agent_examples/whittington_2020_plot.py
@@ -180,4 +180,4 @@
+ "%"
)

# plt.show()
# plt.show()
2 changes: 1 addition & 1 deletion examples/agent_examples/whittington_2020_plot_test.py
@@ -44,4 +44,4 @@

# Plot rate maps for grid or place cells
agent.plot_rate_map(rate_map_type="g")
plt.show()
plt.show()
4 changes: 2 additions & 2 deletions examples/agent_examples/whittington_2020_run.py
@@ -67,7 +67,7 @@
"environment_name": "DiscreteObject",
"state_density": 1,
"n_objects": params["n_x"],
"agent_step_size": 1, Note: this must be 1 / state density
"agent_step_size": 1, # Note: this must be 1 / state density
"use_behavioural_data": False,
"data_path": None,
"experiment_class": Sargolini2006Data,
@@ -106,7 +106,7 @@
}

# Full model training consists of 20000 episodes
training_loop_params = {"n_episode": 10000, "params": full_agent_params,'random_state': False, 'custom_state': [0.0, 0.0]}
training_loop_params = {"n_episode": 10000, "params": full_agent_params, "random_state": False, "custom_state": [0.0, 0.0]}

# Create the training simulation object
sim = SingleSim(
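For context (not part of the commit): a minimal, self-contained sketch of how the reformatted training_loop_params dictionary above might be assembled before being handed to the simulation object. The full_agent_params values below are placeholders, not the real settings built earlier in the script.

# Illustrative placeholders for the agent parameters assembled earlier in the run script.
full_agent_params = {"batch_size": 16, "n_x": 45}

# Mirrors the dictionary in the diff: 10000 training episodes with a fixed start state.
training_loop_params = {
    "n_episode": 10000,
    "params": full_agent_params,
    "random_state": False,
    "custom_state": [0.0, 0.0],
}

print(training_loop_params["n_episode"], training_loop_params["custom_state"])

The keys match the keyword arguments of tem_training_loop (see the training_loops.py hunk at the end of this commit), which is presumably how the simulation object forwards them.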
12 changes: 6 additions & 6 deletions examples/agent_examples/whittington_slurm.sh
@@ -1,10 +1,10 @@
#!/bin/bash
# Set the job name variable

#SBATCH --job-name=TEM_update_test
#SBATCH --mem=20000 # memory pool for all cores
#SBATCH --time=72:00:00 # time
#SBATCH -o TEM_logs/TEM_update_test.%N.%j.out # STDOUT
#SBATCH -e TEM_logs/TEM_update_test.%N.%j.err # STDERR
#SBATCH --mem=20000
#SBATCH --time=72:00:00
#SBATCH -o TEM_logs/TEM_update_test.%N.%j.out
#SBATCH -e TEM_logs/TEM_update_test.%N.%j.err
#SBATCH -p gpu
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
@@ -15,4 +15,4 @@ conda activate NPG-env

python whittington_2020_run.py

exit
exit
38 changes: 17 additions & 21 deletions neuralplayground/agents/whittington_2020.py
@@ -309,10 +309,10 @@ def update(self):
self.logger.info("Weights:" + str([w for w in loss_weights.numpy()]))
self.logger.info(" ")

self.accuracy_history['iter'].append(self.iter)
self.accuracy_history['p_accuracy'].append(acc_p)
self.accuracy_history['g_accuracy'].append(acc_g)
self.accuracy_history['gt_accuracy'].append(acc_gt)
self.accuracy_history["iter"].append(self.iter)
self.accuracy_history["p_accuracy"].append(acc_p)
self.accuracy_history["g_accuracy"].append(acc_g)
self.accuracy_history["gt_accuracy"].append(acc_gt)

# Save accuracies periodically (e.g., every 100 iterations)
if (self.iter) % 100 == 0:
@@ -356,11 +356,7 @@ def initialise(self):
# Initialise whether a state has been visited for each world
self.visited = [[False for _ in range(self.n_states[env])] for env in range(self.pars["batch_size"])]
self.prev_iter = None
self.accuracy_history = {
'iter': [],
'p_accuracy': [],
'g_accuracy': [],
'gt_accuracy': []}
self.accuracy_history = {"iter": [], "p_accuracy": [], "g_accuracy": [], "gt_accuracy": []}

def save_agent(self, save_path: str):
"""Save current state and information in general to re-instantiate the agent
@@ -409,13 +405,13 @@ def save_files(self):
os.path.join(self.script_path, "whittington_2020_utils.py"),
)
return

def save_accuracies(self):
current_dir = os.path.dirname(os.getcwd())
run_path = os.path.join(current_dir, "agent_examples", self.save_name)
run_path = os.path.normpath(run_path)
accuracy_file = os.path.join(run_path, f"accuracies.pkl")
with open(accuracy_file, 'wb') as f:
accuracy_file = os.path.join(run_path, "accuracies.pkl")
with open(accuracy_file, "wb") as f:
pickle.dump(self.accuracy_history, f)

def action_policy(self):
@@ -595,22 +591,22 @@ def get_rate_map_matrix(self, rate_maps, i, j):
int(self.room_depths[0] * self.state_densities[0]),
),
)

def plot_accuracies(self, save_path=None):
accuracy_data = self.accuracy_history
plt.figure(figsize=(10, 6))
plt.plot(accuracy_data['iter'], accuracy_data['p_accuracy'], label='p accuracy')
plt.plot(accuracy_data['iter'], accuracy_data['g_accuracy'], label='g accuracy')
plt.plot(accuracy_data['iter'], accuracy_data['gt_accuracy'], label='gt accuracy')
plt.xlabel('Iteration')
plt.ylabel('Accuracy (%)')
plt.title('TEM Model Accuracies Over Training')
plt.plot(accuracy_data["iter"], accuracy_data["p_accuracy"], label="p accuracy")
plt.plot(accuracy_data["iter"], accuracy_data["g_accuracy"], label="g accuracy")
plt.plot(accuracy_data["iter"], accuracy_data["gt_accuracy"], label="gt accuracy")
plt.xlabel("Iteration")
plt.ylabel("Accuracy (%)")
plt.title("TEM Model Accuracies Over Training")
plt.legend()
plt.grid(True)

if save_path:
plt.savefig(save_path)
else:
plt.show()

plt.close()
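Not part of the commit: a self-contained sketch of the accuracy-tracking pattern these hunks reformat, i.e. a dict of lists appended to during training, pickled periodically, and plotted afterwards. The accuracy values and output path below are placeholders, not results from the agent.

import pickle
import matplotlib.pyplot as plt

# Same structure as the agent's accuracy_history in the diff above.
accuracy_history = {"iter": [], "p_accuracy": [], "g_accuracy": [], "gt_accuracy": []}

for iteration in range(0, 500, 100):
    # Placeholder accuracies; the real values come from the TEM forward pass.
    accuracy_history["iter"].append(iteration)
    accuracy_history["p_accuracy"].append(50 + iteration / 10)
    accuracy_history["g_accuracy"].append(40 + iteration / 10)
    accuracy_history["gt_accuracy"].append(45 + iteration / 10)

# Periodic save, as in save_accuracies().
with open("accuracies.pkl", "wb") as f:
    pickle.dump(accuracy_history, f)

# Plotting, mirroring plot_accuracies().
plt.plot(accuracy_history["iter"], accuracy_history["p_accuracy"], label="p accuracy")
plt.plot(accuracy_history["iter"], accuracy_history["g_accuracy"], label="g accuracy")
plt.plot(accuracy_history["iter"], accuracy_history["gt_accuracy"], label="gt accuracy")
plt.xlabel("Iteration")
plt.ylabel("Accuracy (%)")
plt.legend()
plt.show()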
2 changes: 1 addition & 1 deletion neuralplayground/arenas/batch_environment.py
@@ -208,7 +208,7 @@ def plot_trajectories(self):
for i, environment in enumerate(self.environments):
environment.history = [sublist[i] for sublist in self.history]
axs[i] = environment.plot_trajectory(ax=axs[i])
axs[i].set_aspect('equal')
axs[i].set_aspect("equal")
axs[i].set_title(f"Environment {i+1}")

# Adjust spacing between subplots
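Not part of the commit: a minimal sketch of the per-environment subplot pattern this hunk touches (equal aspect ratio and a numbered title per panel). The trajectory data here is synthetic, standing in for each environment's history.

import numpy as np
import matplotlib.pyplot as plt

n_envs = 3
fig, axs = plt.subplots(1, n_envs, figsize=(4 * n_envs, 4))
rng = np.random.default_rng(0)

for i in range(n_envs):
    # Synthetic random-walk trajectory standing in for environment.history.
    traj = np.cumsum(rng.normal(size=(100, 2)), axis=0)
    axs[i].plot(traj[:, 0], traj[:, 1])
    axs[i].set_aspect("equal")  # the change made in this hunk
    axs[i].set_title(f"Environment {i+1}")

plt.show()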
143 changes: 92 additions & 51 deletions neuralplayground/arenas/discritized_objects.py
@@ -1,10 +1,8 @@
import random

import cv2
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import numpy as np
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.colors import LogNorm

from neuralplayground.arenas.arena_core import Environment
from neuralplayground.plotting.plot_utils import make_plot_trajectories
@@ -111,14 +109,14 @@ def __init__(
self.resolution_d = int(self.room_depth * self.state_density)
self.state_size = 1 / self.state_density

self.x_array = np.linspace(self.arena_x_limits[0] + self.state_size/2,
self.arena_x_limits[1] - self.state_size/2,
self.resolution_w)
self.y_array = np.linspace(self.arena_y_limits[0] + self.state_size/2,
self.arena_y_limits[1] - self.state_size/2,
self.resolution_d)
self.x_array = np.linspace(
self.arena_x_limits[0] + self.state_size / 2, self.arena_x_limits[1] - self.state_size / 2, self.resolution_w
)
self.y_array = np.linspace(
self.arena_y_limits[0] + self.state_size / 2, self.arena_y_limits[1] - self.state_size / 2, self.resolution_d
)
self.mesh = np.meshgrid(self.x_array, self.y_array)
self.xy_combination = np.column_stack([self.mesh[0].ravel(), self.mesh[1].ravel()])
self.xy_combination = np.column_stack([self.mesh[0].ravel(), self.mesh[1].ravel()])
self.n_states = self.resolution_w * self.resolution_d
self.objects = self.generate_objects()
self.occupancy_grid = np.zeros((self.resolution_d, self.resolution_w))
@@ -248,7 +246,6 @@ def generate_objects(self):
objects = np.eye(self.n_objects)[object_indices]
return objects


def make_object_observation(self, pos):
"""
Make an observation of the object in the environment at the current position.
@@ -272,14 +269,14 @@ def pos_to_state(self, pos):
pos = pos[0]
elif self.use_behavioral_data and len(pos) > 2:
pos = pos[:2]

x_index = np.floor((pos[0] - self.arena_x_limits[0]) / self.state_size).astype(int)
y_index = np.floor((pos[1] - self.arena_y_limits[0]) / self.state_size).astype(int)

# Ensure indices are within bounds
x_index = np.clip(x_index, 0, self.resolution_w - 1)
y_index = np.clip(y_index, 0, self.resolution_d - 1)

return y_index * self.resolution_w + x_index

def _create_default_walls(self):
@@ -293,19 +290,38 @@ def _create_default_walls(self):
"""
self.default_walls = []
self.default_walls.append(
np.array([[self.arena_limits[0, 0] - 0.1, self.arena_limits[1, 0] + 0.1], [self.arena_limits[0, 0] - 0.1, self.arena_limits[1, 1] + 0.1]])
np.array(
[
[self.arena_limits[0, 0] - 0.1, self.arena_limits[1, 0] + 0.1],
[self.arena_limits[0, 0] - 0.1, self.arena_limits[1, 1] + 0.1],
]
)
)
self.default_walls.append(
np.array([[self.arena_limits[0, 1] + 0.1, self.arena_limits[1, 0] - 0.1], [self.arena_limits[0, 1] + 0.1, self.arena_limits[1, 1] + 0.1]])
np.array(
[
[self.arena_limits[0, 1] + 0.1, self.arena_limits[1, 0] - 0.1],
[self.arena_limits[0, 1] + 0.1, self.arena_limits[1, 1] + 0.1],
]
)
)
self.default_walls.append(
np.array([[self.arena_limits[0, 0] - 0.1, self.arena_limits[1, 0] - 0.1], [self.arena_limits[0, 1] + 0.1, self.arena_limits[1, 0] - 0.1]])
np.array(
[
[self.arena_limits[0, 0] - 0.1, self.arena_limits[1, 0] - 0.1],
[self.arena_limits[0, 1] + 0.1, self.arena_limits[1, 0] - 0.1],
]
)
)
self.default_walls.append(
np.array([[self.arena_limits[0, 0] - 0.1, self.arena_limits[1, 1] + 0.1], [self.arena_limits[0, 1] + 0.1, self.arena_limits[1, 1] + 0.1]])
np.array(
[
[self.arena_limits[0, 0] - 0.1, self.arena_limits[1, 1] + 0.1],
[self.arena_limits[0, 1] + 0.1, self.arena_limits[1, 1] + 0.1],
]
)
)


def _create_custom_walls(self):
"""Custom walls method. In this case is empty since the environment is a simple square room
Override this method to generate more walls, see jupyter notebook with examples"""
@@ -333,12 +349,12 @@ def validate_action(self, pre_state, action, new_state):
for wall in self.wall_list:
new_state, crossed = check_crossing_wall(pre_state=pre_state, new_state=np.asarray(new_state), wall=wall)
crossed_wall = crossed or crossed_wall

# Snap the new_state back to the nearest discrete state
x_index = np.argmin(np.abs(self.x_array - new_state[0]))
y_index = np.argmin(np.abs(self.y_array - new_state[1]))
new_state = np.array([self.x_array[x_index], self.y_array[y_index]])

return new_state, crossed_wall

def plot_trajectory(
@@ -424,31 +440,43 @@ def visualize_environment(self):
# Visualize discretization
ax1.set_title("Environment Discretization")
for x in np.arange(self.arena_x_limits[0], self.arena_x_limits[1] + self.state_size, self.state_size):
ax1.axvline(x, color='gray', linestyle='-', linewidth=1)
ax1.axvline(x, color="gray", linestyle="-", linewidth=1)
for y in np.arange(self.arena_y_limits[0], self.arena_y_limits[1] + self.state_size, self.state_size):
ax1.axhline(y, color='gray', linestyle='-', linewidth=1)
ax1.scatter(self.xy_combination[:, 0], self.xy_combination[:, 1], color='red', s=20, zorder=2)
ax1.set_aspect('equal')
ax1.axhline(y, color="gray", linestyle="-", linewidth=1)
ax1.scatter(self.xy_combination[:, 0], self.xy_combination[:, 1], color="red", s=20, zorder=2)
ax1.set_aspect("equal")
ax1.set_xlim(self.arena_x_limits)
ax1.set_ylim(self.arena_y_limits)
ax1.set_xlabel("X")
ax1.set_ylabel("Y")

# Visualize object assignment
ax2.set_title("Object Assignment" + f" (n_objects={self.n_objects})," + f" (n_states={self.n_states})," + f" (grid={self.resolution_w}x{self.resolution_d})")
ax2.set_title(
"Object Assignment"
+ f" (n_objects={self.n_objects}),"
+ f" (n_states={self.n_states}),"
+ f" (grid={self.resolution_w}x{self.resolution_d})"
)
object_grid = np.argmax(self.objects, axis=1).reshape((self.resolution_d, self.resolution_w))
im = ax2.imshow(object_grid, cmap='tab20', extent=[*self.arena_x_limits, *self.arena_y_limits], origin='lower')
im = ax2.imshow(object_grid, cmap="tab20", extent=[*self.arena_x_limits, *self.arena_y_limits], origin="lower")
plt.colorbar(im, ax=ax2, label="Object ID")

# Add text labels for object IDs and scatter plot for xy_combination
for i in range(self.resolution_d):
for j in range(self.resolution_w):
ax2.text(self.x_array[j], self.y_array[i], str(object_grid[i, j]),
ha='center', va='center', color='white', fontweight='bold')

ax2.scatter(self.xy_combination[:, 0], self.xy_combination[:, 1], color='red', s=20, zorder=2)

ax2.set_aspect('equal')
ax2.text(
self.x_array[j],
self.y_array[i],
str(object_grid[i, j]),
ha="center",
va="center",
color="white",
fontweight="bold",
)

ax2.scatter(self.xy_combination[:, 0], self.xy_combination[:, 1], color="red", s=20, zorder=2)

ax2.set_aspect("equal")
ax2.set_xlim(self.arena_x_limits)
ax2.set_ylim(self.arena_y_limits)
ax2.set_xlabel("X")
Expand All @@ -464,7 +492,7 @@ def visualize_environment(self):
print(f"Number of discrete states: {self.resolution_w * self.resolution_d}")
print(f"Number of unique objects: {self.n_objects}")
print(f"Grid dimensions: {self.resolution_w} x {self.resolution_d}")

# Print object distribution
unique, counts = np.unique(np.argmax(self.objects, axis=1), return_counts=True)
print("\nObject distribution:")
@@ -474,32 +502,45 @@ def visualize_occupancy(self, log_scale=True):
def visualize_occupancy(self, log_scale=True):
fig, ax = plt.subplots(figsize=(12, 10))
cmap = plt.cm.YlOrRd

if log_scale:
# Use log scale for better visualization of differences
im = ax.imshow(self.occupancy_grid, cmap=cmap, norm=LogNorm(), extent=[*self.arena_x_limits, *self.arena_y_limits], origin='lower')
im = ax.imshow(
self.occupancy_grid,
cmap=cmap,
norm=LogNorm(),
extent=[*self.arena_x_limits, *self.arena_y_limits],
origin="lower",
)
else:
im = ax.imshow(self.occupancy_grid, cmap=cmap, extent=[*self.arena_x_limits, *self.arena_y_limits], origin='lower')
plt.colorbar(im, ax=ax, label='Number of visits (log scale)' if log_scale else 'Number of visits')
ax.set_title('Agent Occupancy Heatmap')
ax.set_xlabel('X')
ax.set_ylabel('Y')
im = ax.imshow(self.occupancy_grid, cmap=cmap, extent=[*self.arena_x_limits, *self.arena_y_limits], origin="lower")

plt.colorbar(im, ax=ax, label="Number of visits (log scale)" if log_scale else "Number of visits")

ax.set_title("Agent Occupancy Heatmap")
ax.set_xlabel("X")
ax.set_ylabel("Y")

# Add grid lines
for x in np.arange(self.arena_x_limits[0], self.arena_x_limits[1] + self.state_size, self.state_size):
ax.axvline(x, color='gray', linestyle='--', linewidth=0.5)
ax.axvline(x, color="gray", linestyle="--", linewidth=0.5)
for y in np.arange(self.arena_y_limits[0], self.arena_y_limits[1] + self.state_size, self.state_size):
ax.axhline(y, color='gray', linestyle='--', linewidth=0.5)
ax.axhline(y, color="gray", linestyle="--", linewidth=0.5)

# Add text annotations for each cell
for i in range(self.resolution_d):
for j in range(self.resolution_w):
value = self.occupancy_grid[i, j]
text_color = 'white' if value > np.mean(self.occupancy_grid) else 'black'
ax.text(self.x_array[j], self.y_array[i], f'{int(value)}',
ha='center', va='center', color=text_color, fontweight='bold')

text_color = "white" if value > np.mean(self.occupancy_grid) else "black"
ax.text(
self.x_array[j],
self.y_array[i],
f"{int(value)}",
ha="center",
va="center",
color=text_color,
fontweight="bold",
)

plt.tight_layout()
plt.show()
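Not part of the commit: a standalone sketch of the cell-centre grid and the position-to-state mapping that several hunks above reformat. The arena limits and state density below are example values, not the defaults of DiscreteObjectEnvironment.

import numpy as np

arena_x_limits, arena_y_limits = (-5.0, 5.0), (-5.0, 5.0)
state_density = 1.0
state_size = 1.0 / state_density
resolution_w = int((arena_x_limits[1] - arena_x_limits[0]) * state_density)
resolution_d = int((arena_y_limits[1] - arena_y_limits[0]) * state_density)

# Cell centres, as in the reformatted __init__: offset by half a cell from each edge.
x_array = np.linspace(arena_x_limits[0] + state_size / 2, arena_x_limits[1] - state_size / 2, resolution_w)
y_array = np.linspace(arena_y_limits[0] + state_size / 2, arena_y_limits[1] - state_size / 2, resolution_d)

def pos_to_state(pos):
    # Floor to a cell index, clip to stay inside the grid, then flatten row-major,
    # mirroring the pos_to_state hunk above.
    x_index = int(np.clip(np.floor((pos[0] - arena_x_limits[0]) / state_size), 0, resolution_w - 1))
    y_index = int(np.clip(np.floor((pos[1] - arena_y_limits[0]) / state_size), 0, resolution_d - 1))
    return y_index * resolution_w + x_index

print(pos_to_state([0.3, -4.9]))  # a position near the bottom edge of the arena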
4 changes: 3 additions & 1 deletion neuralplayground/backend/training_loops.py
@@ -77,7 +77,9 @@ def episode_based_training_loop(agent: AgentCore, env: Environment, t_episode: i
return agent, env, dict_training


def tem_training_loop(agent: AgentCore, env: Environment, n_episode: int, params: dict, random_state: bool = True, custom_state: list = None):
def tem_training_loop(
agent: AgentCore, env: Environment, n_episode: int, params: dict, random_state: bool = True, custom_state: list = None
):
"""Training loop for agents and environments that use a TEM-based update.
Parameters
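Not part of the commit: a sketch of how the random_state / custom_state keyword arguments of tem_training_loop might be consumed when an episode is reset, and how the run-script dictionary could be forwarded. The reset_position helper and the **-unpacking call are illustrative assumptions, not code from this repository.

import numpy as np

def reset_position(x_limits, y_limits, random_state=False, custom_state=None, rng=None):
    """Pick a start position: uniform in the arena if random_state, else the custom state."""
    rng = rng or np.random.default_rng()
    if random_state:
        return np.array([rng.uniform(*x_limits), rng.uniform(*y_limits)])
    # Fall back to the provided custom state (defaults to the arena centre).
    return np.array(custom_state if custom_state is not None else [0.0, 0.0])

# With random_state=False and custom_state=[0.0, 0.0], as in whittington_2020_run.py,
# every episode would start from the arena centre.
print(reset_position((-5, 5), (-5, 5), random_state=False, custom_state=[0.0, 0.0]))

# If the simulation object forwards the run-script dictionary directly, the call
# would presumably look like (assumption):
#   agent, env, training_dict = tem_training_loop(agent, env, **training_loop_params)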
