New branch for last-minute commenting & cleaning (#90)
* notebooks finished, TEM set to simple run & logger path fixed

* Simple2D and DiscreteObject added as examples for BatchEnvironment

* batch environment test fixed

* trained TEM models added to GIN

---------

Co-authored-by: LukeHollingsworth <[email protected]>
LukeHollingsworth authored Aug 21, 2023
1 parent b570f6e commit b1d9a66
Showing 8 changed files with 162 additions and 368 deletions.
373 changes: 27 additions & 346 deletions examples/agent_examples/whittington_2020_example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/agent_examples/whittington_2020_run.py
@@ -14,7 +14,7 @@
 from neuralplayground.experiments import Sargolini2006Data

 simulation_id = "TEM_custom_sim"
-save_path = os.path.join(os.getcwd(), "examples", "agent_examples", "results_sim")
+save_path = os.path.join(os.getcwd(), "results_sim")
 # save_path = os.path.join(os.getcwd(), "examples", "agent_examples", "trained_results")
 agent_class = Whittington2020
 env_class = BatchEnvironment
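The new save_path is resolved against the current working directory rather than a hard-coded examples/agent_examples prefix. A minimal sketch of how it resolves (paths are illustrative):

    import os

    # Illustrative: run from examples/agent_examples, this yields
    # .../examples/agent_examples/results_sim; run from the repository root
    # it would instead create results_sim at the top level.
    save_path = os.path.join(os.getcwd(), "results_sim")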
132 changes: 123 additions & 9 deletions examples/arena_examples/arena_examples.ipynb

Large diffs are not rendered by default.

12 changes: 5 additions & 7 deletions neuralplayground/agents/whittington_2020.py
@@ -12,6 +12,7 @@
 import neuralplayground.agents.whittington_2020_extras.whittington_2020_analyse as analyse
 import neuralplayground.agents.whittington_2020_extras.whittington_2020_model as model
 import neuralplayground.agents.whittington_2020_extras.whittington_2020_parameters as parameters
+import neuralplayground.agents.whittington_2020_extras.whittington_2020_utils as utils

 # Custom modules
 from neuralplayground.plotting.plot_utils import make_plot_rate_map
@@ -335,13 +336,10 @@ def initialise(self):
         # # Create a tensor board to stay updated on training progress. Start tensorboard with tensorboard --logdir=runs
         # self.writer = SummaryWriter(self.train_path)
         # Create a logger to write log output to file
-        # current_dir = os.path.dirname(os.getcwd())
-        # while os.path.basename(current_dir) != "examples":
-        #     current_dir = os.path.dirname(current_dir)
-        # relative_path = "agent_examples/results_sim"
-        # run_path = os.path.join(current_dir, relative_path)
-        # run_path = os.path.normpath(run_path)
-        # self.logger = utils.make_logger(run_path)
+        current_dir = os.path.dirname(os.getcwd())
+        run_path = os.path.join(current_dir, "agent_examples", "results_sim")
+        run_path = os.path.normpath(run_path)
+        self.logger = utils.make_logger(run_path)
         # Make an ADAM optimizer for TEM
         self.adam = torch.optim.Adam(self.tem.parameters(), lr=self.pars["lr_max"])
         # Initialise whether a state has been visited for each world
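The logger setup, previously commented out, is now active, with the run path derived from the parent of the working directory. A hedged illustration (assuming the script is launched from examples/agent_examples, as in whittington_2020_run.py):

    import os

    current_dir = os.path.dirname(os.getcwd())  # .../examples when run from agent_examples
    run_path = os.path.normpath(os.path.join(current_dir, "agent_examples", "results_sim"))
    # -> .../examples/agent_examples/results_sim, matching save_path in whittington_2020_run.py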
1 change: 1 addition & 0 deletions neuralplayground/agents/whittington_2020_extras/whittington_2020_utils.py
@@ -201,6 +201,7 @@ def make_logger(run_path):
     # Remove any existing handlers so you don't output to old files, or to new files twice
     logger.handlers = []
     # Create a file handler, but only if the handler does
+    os.makedirs(run_path, exist_ok=True)
     log_file_path = os.path.join(run_path, "run.log")
     handler = logging.FileHandler(log_file_path)
     handler.setLevel(logging.INFO)
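The added os.makedirs call matters because logging.FileHandler raises FileNotFoundError when the parent directory of the log file does not exist. A minimal self-contained sketch of the fixed behavior (make_logger_sketch is illustrative, not the package's exact implementation):

    import logging
    import os

    def make_logger_sketch(run_path):
        # Illustrative re-implementation, not the package's exact make_logger
        logger = logging.getLogger("run")
        logger.setLevel(logging.INFO)
        logger.handlers = []  # drop stale handlers so output is not duplicated
        os.makedirs(run_path, exist_ok=True)  # the fix: create the directory first
        handler = logging.FileHandler(os.path.join(run_path, "run.log"))
        handler.setLevel(logging.INFO)
        logger.addHandler(handler)
        return logger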
6 changes: 3 additions & 3 deletions neuralplayground/arenas/discritized_objects.py
@@ -225,12 +225,12 @@ def step(self, action: np.ndarray, normalize_step: bool = False, skip_every: int
         self.state = observation
         self.transition = {
             "action": action,
-            "state": self.old_state,
-            "next_state": self.state,
+            "state": self.old_state[-1],
+            "next_state": self.state[-1],
             "reward": reward,
             "step": self.global_steps,
         }
-        # self.history.append(transition)
+        self.history.append(self.transition)
         self._increase_global_step()
         return observation, self.state, reward

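A hedged reading of this hunk: in the batched arena, self.old_state and self.state appear to hold one entry per environment in the batch, so indexing with [-1] stores only the most recent environment's observation in the logged transition, and the history append is reinstated. A toy illustration (values and shapes assumed):

    old_state = [[0.0, 0.0], [1.0, 2.0]]   # toy batch: one entry per environment
    state = [[0.5, 0.5], [1.5, 2.5]]
    transition = {"state": old_state[-1],   # [1.0, 2.0]
                  "next_state": state[-1]}  # [1.5, 2.5]
    history = [transition]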
2 changes: 1 addition & 1 deletion neuralplayground/saved_models.py
@@ -26,7 +26,7 @@
"stachenfeld_2018_in_hafting2008.zip": "960cdc8d4fa9ef86ed1d5ef144fe6949d227c081b837ae24e49335bdaf971899", # noqa: E501
"weber_2018_in_wernle.zip": "51f701966229ba8a70aab7b7ce79f4965e80904661eb6cdad85d03b0ddb7f8ff", # noqa: E501
"weber_2018_in_merging_room.zip": "10c537bc1d410de1bba18fe36624501bc4caddc0a032f3889a39435256a0205c", # noqa: E501
"tem_in_2D.zip": "9f82bb8e231e6e38526deb1ea4a1be6bee95f54dba364cf5a129fbf8b3f191eb", # noqa: E501
"whittington_2020_in_discritized_objects.zip": "3b527b03cd011b5e71ff66304f25d2406acddcbd3f770139ca8d8edc71cf1703", # noqa: E501
},
)

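The registry maps archive names to SHA-256 checksums, so renaming the trained-TEM archive requires a new hash entry. A generic verification sketch with hashlib (not necessarily how the package validates downloads):

    import hashlib

    def sha256_of(path: str) -> str:
        # Generic checksum, not the package's exact validation path
        h = hashlib.sha256()
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                h.update(chunk)
        return h.hexdigest()

    expected = "3b527b03cd011b5e71ff66304f25d2406acddcbd3f770139ca8d8edc71cf1703"
    # assert sha256_of("whittington_2020_in_discritized_objects.zip") == expected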
2 changes: 1 addition & 1 deletion tests/arena_exp_test.py
@@ -302,7 +302,7 @@ def test_agent_interaction(self):
             action = agent.batch_act(obs)
             # Run environment for given action
             obs, state, reward = env.step(action, normalize_step=True)
-        env.plot_trajectory()
+        env.plot_trajectories()

     def test_init_env(self, init_env):
         assert isinstance(init_env[0], BatchEnvironment)
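The test exercises the batched interaction pattern, where BatchEnvironment simulates several arenas at once; the plural plot_trajectories matches that. A hedged sketch of the loop (construction of agent and env omitted; reset signature assumed):

    def run_episode(agent, env, n_steps=10):
        obs, state = env.reset()  # assumed reset signature
        for _ in range(n_steps):
            action = agent.batch_act(obs)  # one action per environment in the batch
            obs, state, reward = env.step(action, normalize_step=True)
        env.plot_trajectories()  # renamed from plot_trajectory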
