Skip to content

Commit

Permalink
medium size run added
Browse files Browse the repository at this point in the history
  • Loading branch information
LukeHollingsworth committed Jul 15, 2024
1 parent 36f5da1 commit 18c4abb
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 59 deletions.
74 changes: 23 additions & 51 deletions examples/agent_examples/whittington_2020_example.ipynb

Large diffs are not rendered by default.

10 changes: 5 additions & 5 deletions examples/agent_examples/whittington_2020_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from neuralplayground.experiments import Sargolini2006Data

# Set the location for saving the results of the simulation
simulation_id = "TEM_results_new"
simulation_id = "TEM_results_med"
save_path = os.path.join(os.getcwd(), simulation_id)
# save_path = os.path.join(os.getcwd(), "examples", "agent_examples", "trained_results")
agent_class = Whittington2020
Expand All @@ -29,7 +29,7 @@
[-3, 3],
[-2, 2],
[-3, 3],
[-3, 3],
[-4, 4],
[-2, 2],
[-3, 3],
[-4, 4],
Expand Down Expand Up @@ -65,9 +65,9 @@
# Set parameters for the environment that generates observations
discrete_env_params = {
"environment_name": "DiscreteObject",
"state_density": 1,
"state_density": 2,
"n_objects": params["n_x"],
"agent_step_size": 1, # Note: this must be 1 / state density
"agent_step_size": 1/2, # Note: this must be 1 / state density
"use_behavioural_data": False,
"data_path": None,
"experiment_class": Sargolini2006Data,
Expand Down Expand Up @@ -106,7 +106,7 @@
}

# Full model training consists of 20000 episodes
training_loop_params = {"n_episode": 5000, "params": full_agent_params,'random_state': False, 'custom_state': [0.0, 0.0]}
training_loop_params = {"n_episode": 20000, "params": full_agent_params,'random_state': False, 'custom_state': [0.0, 0.0]}

# Create the training simulation object
sim = SingleSim(
Expand Down
6 changes: 3 additions & 3 deletions examples/agent_examples/whittington_slurm.sh
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
#!/bin/bash
# Set the job name variable
#SBATCH --job-name=TEM_small_2
#SBATCH --job-name=TEM_med
#SBATCH --mem=50000 # memory pool for all cores
#SBATCH --time=72:00:00 # time
#SBATCH -o TEM_logs/TEM_small_2.%N.%j.out # STDOUT
#SBATCH -e TEM_logs/TEM_small_2.%N.%j.err # STDERR
#SBATCH -o TEM_logs/TEM_med.%N.%j.out # STDOUT
#SBATCH -e TEM_logs/TEM_med.%N.%j.err # STDERR
#SBATCH -p gpu
#SBATCH --gres=gpu:1
#SBATCH --ntasks=1
Expand Down

0 comments on commit 18c4abb

Please sign in to comment.