Commit 1fe31ab: recent TEM updates

LukeHollingsworth committed Oct 7, 2024
1 parent 0ead088 · commit 1fe31ab
Showing 5 changed files with 40 additions and 28 deletions.
54 changes: 32 additions & 22 deletions examples/agent_examples/whittington_2020_example.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions examples/agent_examples/whittington_2020_run.py
@@ -14,7 +14,7 @@
 from neuralplayground.experiments import Sargolini2006Data

 # Set the location for saving the results of the simulation
-simulation_id = "TEM_test_5x5_small_walk"
+simulation_id = "TEM_test_with_break"
 save_path = os.path.join(os.getcwd(), simulation_id)
 # save_path = os.path.join(os.getcwd(), "examples", "agent_examples", "trained_results")
 agent_class = Whittington2020
@@ -106,7 +106,7 @@
 }

 # Full model training consists of 20000 episodes
-training_loop_params = {"n_episode": 20000, "params": full_agent_params, "random_state": False, "custom_state": [0.0, 0.0]}
+training_loop_params = {"n_episode": 3000, "params": full_agent_params, "random_state": False, "custom_state": [0.0, 0.0]}

 # Create the training simulation object
 sim = SingleSim(
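Note: the two edits above switch the run script from the full 20000-episode training to a shorter 3000-episode run under a new simulation_id. A minimal sketch of keeping both configurations side by side; the QUICK_TEST flag and make_training_params helper are illustrative, not part of the repository, while the dictionary keys and values come from the diff above.

# Hypothetical helper for toggling between a quick test run and full TEM training.
QUICK_TEST = True

def make_training_params(full_agent_params):
    n_episode = 3000 if QUICK_TEST else 20000  # full model training uses 20000 episodes
    return {
        "n_episode": n_episode,
        "params": full_agent_params,
        "random_state": False,
        "custom_state": [0.0, 0.0],
    }
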
7 changes: 4 additions & 3 deletions neuralplayground/agents/agent_core.py
@@ -79,9 +79,10 @@ def act(self, obs, policy_func=None):

         self.obs_history.append(obs)
         if len(self.obs_history) >= 1000:  # reset every 1000
-            self.obs_history = [
-                obs,
-            ]
+            # self.obs_history = [
+            #     obs,
+            # ]
+            self.obs_history.pop(0)
         if policy_func is not None:
             return policy_func(obs)

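Note: this hunk changes the observation history from a hard reset (keeping only the latest observation once 1000 are stored) to a sliding window that drops the oldest entry. A standalone sketch of the two behaviours, independent of the Agent class; only the names obs_history and the 1000-element threshold mirror the diff, the helper functions are illustrative.

obs_history = []

def record_reset(obs):
    # old behaviour: once the buffer reaches 1000, discard it and keep only the latest obs
    obs_history.append(obs)
    if len(obs_history) >= 1000:
        obs_history[:] = [obs]

def record_window(obs):
    # new behaviour: keep a rolling window of the most recent ~1000 observations
    obs_history.append(obs)
    if len(obs_history) >= 1000:
        obs_history.pop(0)

An equivalent, arguably more idiomatic option would be collections.deque(maxlen=1000), which drops the oldest item automatically on append.
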
1 change: 1 addition & 0 deletions neuralplayground/agents/whittington_2020.py
@@ -311,6 +311,7 @@ def update(self):
         # Compute model accuracies
         acc_p, acc_g, acc_gt = np.mean([[np.mean(a) for a in step.correct()] for step in forward], axis=0)
         acc_p, acc_g, acc_gt = [a * 100 for a in (acc_p, acc_g, acc_gt)]
+        self.accuracies = (acc_p + acc_g + acc_gt) / 3
         # Log progress
         if self.iter % 1 == 0:
             # Write series of messages to logger from this backprop iteration
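Note: the new self.accuracies attribute simply stores the mean of the three accuracy terms (acc_p, acc_g, acc_gt) computed just above it. A small, self-contained illustration of that aggregation; the forward/step.correct() structure is mocked here with a plain array.

import numpy as np

# Mocked stand-in for [[np.mean(a) for a in step.correct()] for step in forward]:
# one row per forward step, one column per accuracy term (p, g, gt).
per_step_accuracies = np.array([
    [0.80, 0.70, 0.75],
    [0.82, 0.72, 0.77],
])

acc_p, acc_g, acc_gt = np.mean(per_step_accuracies, axis=0)       # average over steps
acc_p, acc_g, acc_gt = [a * 100 for a in (acc_p, acc_g, acc_gt)]  # convert to percentages
accuracies = (acc_p + acc_g + acc_gt) / 3                         # single scalar, as in the diff
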
2 changes: 1 addition & 1 deletion neuralplayground/backend/training_loops.py
@@ -104,7 +104,7 @@ def tem_training_loop(

     training_dict = [agent.mod_kwargs, env.env_kwargs, agent.tem.hyper]

-    max_steps_per_env = np.random.randint(4000, 6000, size=params["batch_size"])
+    max_steps_per_env = np.random.randint(4000, 5000, size=params["batch_size"])
     current_steps = np.zeros(params["batch_size"], dtype=int)

     obs, state = env.reset(random_state=random_state, custom_state=custom_state)
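Note: max_steps_per_env draws, for each batch entry, a step budget for that environment; the commit only tightens the upper bound from 6000 to 5000. A hedged sketch of how such per-environment step caps are typically consumed in a loop; the step_counters helper and reset logic below are illustrative, not copied from tem_training_loop.

import numpy as np

batch_size = 16
max_steps_per_env = np.random.randint(4000, 5000, size=batch_size)  # per-environment walk length
current_steps = np.zeros(batch_size, dtype=int)

def step_counters(current_steps, max_steps_per_env):
    # Illustrative: advance all counters and flag which environments exceeded their cap.
    current_steps = current_steps + 1
    finished = current_steps >= max_steps_per_env
    # Environments that finished would be reset and given a fresh cap.
    current_steps[finished] = 0
    new_caps = np.where(finished, np.random.randint(4000, 5000, size=batch_size), max_steps_per_env)
    return current_steps, new_caps
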
