Skip to content

Commit

Permalink
precommit black
Browse files Browse the repository at this point in the history
  • Loading branch information
LukeHollingsworth committed Nov 6, 2024
1 parent 853d185 commit 4afe24c
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 44 deletions.
69 changes: 33 additions & 36 deletions examples/agent_examples/whittington_2020_example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/agent_examples/whittington_2020_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from neuralplayground.experiments import Sargolini2006Data

# Set the location for saving the results of the simulation
simulation_id = "TEM_test"
simulation_id = "TEM_begging"
save_path = os.path.join(os.getcwd(), simulation_id)
# save_path = os.path.join(os.getcwd(), "examples", "agent_examples", "trained_results")
agent_class = Whittington2020
Expand Down
6 changes: 3 additions & 3 deletions examples/agent_examples/whittington_slurm.sh
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
#!/bin/bash

#SBATCH -J TEM_5x5 # job name
#SBATCH -J TEM_beg # job name
#SBATCH -p gpu # partition (queue)
#SBATCH -N 1 # number of nodes
#SBATCH --mem 50G # memory pool for all cores
#SBATCH -n 4 # number of cores
#SBATCH -t 0-72:00 # time (D-HH:MM)
#SBATCH --gres gpu:1 # request 1 GPU (of any kind)
#SBATCH -o TEM_5x5.%x.%N.%j.out # STDOUT
#SBATCH -e TEM_5x5.%x.%N.%j.err # STDERR
#SBATCH -o TEM_beg.%x.%N.%j.out # STDOUT
#SBATCH -e TEM_beg.%x.%N.%j.err # STDERR

source ~/.bashrc

Expand Down
8 changes: 4 additions & 4 deletions neuralplayground/arenas/discritized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,10 +244,10 @@ def step(self, action: np.ndarray, normalize_step: bool = True, skip_every: int
}
self.history.append(self.transition)
self._increase_global_step()
self.steps_in_curr_env += 1
if self.steps_in_curr_env >= self.max_steps_per_env:
self.steps_in_curr_env = 0
self.reset_objects()
# self.steps_in_curr_env += 1
# if self.steps_in_curr_env >= self.max_steps_per_env:
# self.steps_in_curr_env = 0
# self.reset_objects()
return observation, self.state, reward

def generate_objects(self):
Expand Down

0 comments on commit 4afe24c

Please sign in to comment.