Commit 7de0832

updated test
LukeHollingsworth committed Jul 16, 2024
1 parent 94b8ac8 commit 7de0832
Showing 4 changed files with 46 additions and 7 deletions.
4 changes: 2 additions & 2 deletions examples/agent_examples/whittington_2020_run.py
@@ -14,7 +14,7 @@
 from neuralplayground.experiments import Sargolini2006Data

 # Set the location for saving the results of the simulation
-simulation_id = "TEM_results_new_states"
+simulation_id = "TEM_results_update_test"
 save_path = os.path.join(os.getcwd(), simulation_id)
 # save_path = os.path.join(os.getcwd(), "examples", "agent_examples", "trained_results")
 agent_class = Whittington2020
@@ -106,7 +106,7 @@
 }

 # Full model training consists of 20000 episodes
-training_loop_params = {"n_episode": 20000, "params": full_agent_params, "random_state": False, "custom_state": [0.0, 0.0]}
+training_loop_params = {"n_episode": 10000, "params": full_agent_params, "random_state": False, "custom_state": [0.0, 0.0]}

 # Create the training simulation object
 sim = SingleSim(
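Note that training_loop_params is unpacked as keyword arguments into the training loop (the simulation_manager.py hunk below calls self.training_loop(agent, env, **self.training_loop_params)), so the random_state and custom_state keys must match that loop's signature; with random_state=False, custom_state=[0.0, 0.0] presumably pins the agent's reset position to (0.0, 0.0) instead of sampling it randomly.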
6 changes: 3 additions & 3 deletions examples/agent_examples/whittington_slurm.sh
@@ -1,10 +1,10 @@
 #!/bin/bash
 # Set the job name variable
-#SBATCH --job-name=TEM2_new_states
+#SBATCH --job-name=TEM_update_test
 #SBATCH --mem=20000       # memory pool for all cores
 #SBATCH --time=72:00:00   # time
-#SBATCH -o TEM_logs/TEM2_new_states.%N.%j.out   # STDOUT
-#SBATCH -e TEM_logs/TEM2_new_states.%N.%j.err   # STDERR
+#SBATCH -o TEM_logs/TEM_update_test.%N.%j.out   # STDOUT
+#SBATCH -e TEM_logs/TEM_update_test.%N.%j.err   # STDERR
 #SBATCH -p gpu
 #SBATCH --gres=gpu:1
 #SBATCH --ntasks=1
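In the -o/-e paths, %N and %j are standard SLURM filename patterns that expand to the node name and job ID, so each submission writes its own log files; SLURM does not create the TEM_logs/ directory itself, so it must exist before the job is submitted or the logs will not be written.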
40 changes: 38 additions & 2 deletions neuralplayground/arenas/discritized_objects.py
@@ -2,6 +2,7 @@

 import cv2
 import matplotlib.pyplot as plt
+from matplotlib.colors import LogNorm
 import numpy as np
 from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

@@ -118,10 +119,9 @@ def __init__(
                                    self.resolution_d)
         self.mesh = np.meshgrid(self.x_array, self.y_array)
         self.xy_combination = np.column_stack([self.mesh[0].ravel(), self.mesh[1].ravel()])
-        self.ws = int(self.room_width * self.state_density)
-        self.hs = int(self.room_depth * self.state_density)
         self.n_states = self.resolution_w * self.resolution_d
         self.objects = self.generate_objects()
+        self.occupancy_grid = np.zeros((self.resolution_d, self.resolution_w))

     def reset(self, random_state=True, custom_state=None):
         """
@@ -169,6 +169,7 @@
             custom_state = np.concatenate([pos, head_dir])

         self.objects = self.generate_objects()
+        self.occupancy_grid = np.zeros((self.resolution_d, self.resolution_w))

         # Fully observable environment, make_observation returns the state
         observation = self.make_object_observation(pos)
@@ -224,6 +225,8 @@ def step(self, action: np.ndarray, normalize_step: bool = True, skip_every: int
         )
         reward = self.reward_function(action, self.state[-1])  # If you get reward, it should be coded here
         observation = self.make_object_observation(new_pos_state)
+        state_index = self.pos_to_state(new_pos_state)
+        self.occupancy_grid[state_index // self.resolution_w, state_index % self.resolution_w] += 1
         self.state = observation
         self.transition = {
             "action": action,
@@ -467,3 +470,36 @@ def visualize_environment(self):
         print("\nObject distribution:")
         for obj, count in zip(unique, counts):
             print(f"Object {obj}: {count} states ({count/(self.resolution_w * self.resolution_d):.2%})")
+
+    def visualize_occupancy(self, log_scale=True):
+        fig, ax = plt.subplots(figsize=(12, 10))
+        cmap = plt.cm.YlOrRd
+
+        if log_scale:
+            # Use log scale for better visualization of differences
+            im = ax.imshow(self.occupancy_grid, cmap=cmap, norm=LogNorm(), extent=[*self.arena_x_limits, *self.arena_y_limits], origin='lower')
+        else:
+            im = ax.imshow(self.occupancy_grid, cmap=cmap, extent=[*self.arena_x_limits, *self.arena_y_limits], origin='lower')
+
+        plt.colorbar(im, ax=ax, label='Number of visits (log scale)' if log_scale else 'Number of visits')
+
+        ax.set_title('Agent Occupancy Heatmap')
+        ax.set_xlabel('X')
+        ax.set_ylabel('Y')
+
+        # Add grid lines
+        for x in np.arange(self.arena_x_limits[0], self.arena_x_limits[1] + self.state_size, self.state_size):
+            ax.axvline(x, color='gray', linestyle='--', linewidth=0.5)
+        for y in np.arange(self.arena_y_limits[0], self.arena_y_limits[1] + self.state_size, self.state_size):
+            ax.axhline(y, color='gray', linestyle='--', linewidth=0.5)
+
+        # Add text annotations for each cell
+        for i in range(self.resolution_d):
+            for j in range(self.resolution_w):
+                value = self.occupancy_grid[i, j]
+                text_color = 'white' if value > np.mean(self.occupancy_grid) else 'black'
+                ax.text(self.x_array[j], self.y_array[i], f'{int(value)}',
+                        ha='center', va='center', color=text_color, fontweight='bold')
+
+        plt.tight_layout()
+        plt.show()
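For readers skimming the diff, here is a minimal standalone sketch of the occupancy bookkeeping this commit threads through __init__, reset, and step. It assumes pos_to_state() returns a row-major index over the (resolution_d, resolution_w) grid, which is what the // self.resolution_w and % self.resolution_w decomposition in step() implies; all names below are illustrative rather than library code.

    import numpy as np

    # Illustrative sketch (not library code): visit counting over a
    # row-major discretized arena, mirroring the indexing used in step().
    resolution_w, resolution_d = 5, 4  # grid width and depth
    occupancy_grid = np.zeros((resolution_d, resolution_w))

    def record_visit(state_index: int) -> None:
        # Row-major decomposition: row = index // width, col = index % width.
        occupancy_grid[state_index // resolution_w, state_index % resolution_w] += 1

    for idx in [0, 7, 7, 19]:  # e.g. two visits to state 7
        record_visit(idx)

    print(occupancy_grid)  # cell (1, 2) reads 2; cells (0, 0) and (3, 4) read 1

One caveat: LogNorm() masks non-positive values, so states with zero visits render as blank cells in the log-scale heatmap rather than as the lowest colour.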
3 changes: 3 additions & 0 deletions neuralplayground/backend/simulation_manager.py
@@ -318,6 +318,9 @@ def run_sim(self, save_path: str = None):
         print("---> Training loop")
         trained_agent, trained_env, training_hist = self.training_loop(agent, env, **self.training_loop_params)

+        # for i in range(16):
+        #     trained_env.environments[i].visualize_occupancy()
+
         # Saving models
         print("---> Saving models")
         self._save_models(save_path, trained_agent, trained_env, training_hist)
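The visualization hook is left commented out in this commit. If enabled, it would plausibly be called as sketched below; this assumes, as the commented lines suggest, that trained_env is a batched environment exposing its sub-environments as a trained_env.environments list, so treat it as a sketch rather than confirmed API.

    # Hypothetical usage after training (mirrors the commented-out block above):
    for env in trained_env.environments:
        env.visualize_occupancy(log_scale=True)  # per-environment visit heatmap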
