Skip to content

Commit

Permalink
black precommit changes
Browse files Browse the repository at this point in the history
  • Loading branch information
LukeHollingsworth committed Oct 29, 2024
1 parent d2bdd14 commit 853d185
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
2 changes: 1 addition & 1 deletion examples/agent_examples/whittington_2020_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from neuralplayground.experiments import Sargolini2006Data

# Set the location for saving the results of the simulation
simulation_id = "TEM_test_with_break"
simulation_id = "TEM_test"
save_path = os.path.join(os.getcwd(), simulation_id)
# save_path = os.path.join(os.getcwd(), "examples", "agent_examples", "trained_results")
agent_class = Whittington2020
Expand Down
12 changes: 9 additions & 3 deletions neuralplayground/agents/whittington_2020.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,15 @@ def update(self):
self.final_model_input = model_input

forward = self.tem(model_input, self.prev_iter)
chunk = [[step[0][0], np.argmax(step[1][0]), step[2][0]] for step in model_input]
for i in range(len(chunk)):
self.logger.info(chunk[i])
# chunk = [[step[0][0], np.argmax(step[1][0]), step[2][0]] for step in model_input]
for i in range(len(model_input)):
# self.logger.info(chunk[i])
ids = [step["id"] for step in model_input[i][0]]
objs = [int(np.argmax(step)) for step in model_input[i][1]]
actions = model_input[i][2]
self.logger.info("IDs: " + str(ids))
self.logger.info("Objs: " + str(objs))
self.logger.info("Actions: " + str(actions))
# if self.prev_iter is None:
# with open('OG_log.txt', 'a') as f:
# f.write('Walk number: ' + str(self.global_steps) + '\n')
Expand Down

0 comments on commit 853d185

Please sign in to comment.