Layer Norm+residual+Plotting
ClementineDomine committed Nov 8, 2023
1 parent 6c113ac commit 8614466
Showing 11 changed files with 257 additions and 263 deletions.
264 changes: 109 additions & 155 deletions neuralplayground/agents/domine_2023.py

Large diffs are not rendered by default.

@@ -1,24 +1,25 @@
-experiment_name: 'Learn identity no resample'
-train_on_shortest_path: False # make sure it works when this is the case
-resample: False # @param
+experiment_name: 'Large_sp_re_ '
+train_on_shortest_path: True # make sure it works when this is the case
+resample: True # @param
 wandb_on: True
-seed: 42
+seed: 46
 
-feature_position: False # make sure it works when this is the case
+feature_position: False # make sure it works when this is the case
 weighted: False
 
-num_hidden: 500 # @param
-num_layers: 2 # @param
+num_hidden: 160 # @param
+num_layers: 1 # @param
 num_message_passing_steps: 3 # @param
-learning_rate: 0.001 # @param
-num_training_steps: 10 # @param
+learning_rate: 0.00001 # @param
+num_training_steps: 20 # @param
+residual: False
+layer_norm: False
 
 # Env Stuff
-batch_size: 6
-nx_min: 2
-nx_max: 5
+batch_size: 2
+nx_min: 5
+nx_max: 6
 
-batch_size_test: 6
-nx_min_test: 2
-nx_max_test: 5
+batch_size_test: 2
+nx_min_test: 5
+nx_max_test: 6
@@ -1,24 +1,25 @@
-experiment_name: 'Learn identity no resample'
-train_on_shortest_path: False # make sure it works when this is the case
-resample: False # @param
+experiment_name: 'Large_sp_re_ '
+train_on_shortest_path: True # make sure it works when this is the case
+resample: True # @param
 wandb_on: True
-seed: 42
+seed: 45
 
-feature_position: False # make sure it works when this is the case
+feature_position: False # make sure it works when this is the case
 weighted: False
 
-num_hidden: 500 # @param
-num_layers: 2 # @param
+num_hidden: 160 # @param
+num_layers: 1 # @param
 num_message_passing_steps: 3 # @param
-learning_rate: 0.001 # @param
-num_training_steps: 10 # @param
+learning_rate: 0.00001 # @param
+num_training_steps: 20 # @param
+residual: False
+layer_norm: False
 
 # Env Stuff
-batch_size: 6
-nx_min: 2
-nx_max: 5
+batch_size: 2
+nx_min: 5
+nx_max: 6
 
-batch_size_test: 6
-nx_min_test: 2
-nx_max_test: 5
+batch_size_test: 2
+nx_min_test: 5
+nx_max_test: 6
@@ -1,24 +1,25 @@
-experiment_name: 'Learn identity no resample'
-train_on_shortest_path: False # make sure it works when this is the case
-resample: False # @param
+experiment_name: 'Large_sp_re_ '
+train_on_shortest_path: True # make sure it works when this is the case
+resample: True # @param
 wandb_on: True
-seed: 43
+seed: 44
 
-feature_position: False # make sure it works when this is the case
+feature_position: False # make sure it works when this is the case
 weighted: False
 
-num_hidden: 500 # @param
-num_layers: 2 # @param
+num_hidden: 160 # @param
+num_layers: 1 # @param
 num_message_passing_steps: 3 # @param
-learning_rate: 0.001 # @param
-num_training_steps: 10 # @param
+learning_rate: 0.00001 # @param
+num_training_steps: 20 # @param
+residual: False
+layer_norm: False
 
 # Env Stuff
-batch_size: 6
-nx_min: 2
-nx_max: 5
+batch_size: 2
+nx_min: 5
+nx_max: 6
 
-batch_size_test: 6
-nx_min_test: 2
-nx_max_test: 5
+batch_size_test: 2
+nx_min_test: 5
+nx_max_test: 6
@@ -1,24 +1,25 @@
-experiment_name: 'Learn identity no resample'
-train_on_shortest_path: False # make sure it works when this is the case
-resample: False # @param
+experiment_name: 'Large_sp_re_ '
+train_on_shortest_path: True # make sure it works when this is the case
+resample: True # @param
 wandb_on: True
-seed: 4
+seed: 43
 
-feature_position: False # make sure it works when this is the case
+feature_position: False # make sure it works when this is the case
 weighted: False
 
-num_hidden: 500 # @param
-num_layers: 2 # @param
+num_hidden: 160 # @param
+num_layers: 1 # @param
 num_message_passing_steps: 3 # @param
-learning_rate: 0.001 # @param
-num_training_steps: 10 # @param
+learning_rate: 0.00001 # @param
+num_training_steps: 20 # @param
+residual: False
+layer_norm: False
 
 # Env Stuff
-batch_size: 6
-nx_min: 2
-nx_max: 5
+batch_size: 2
+nx_min: 5
+nx_max: 6
 
-batch_size_test: 6
-nx_min_test: 2
-nx_max_test: 5
+batch_size_test: 2
+nx_min_test: 5
+nx_max_test: 6
@@ -1,24 +1,25 @@
-experiment_name: 'Learn identity no resample'
-train_on_shortest_path: False # make sure it works when this is the case
-resample: False # @param
+experiment_name: 'Large_sp_re_ '
+train_on_shortest_path: True # make sure it works when this is the case
+resample: True # @param
 wandb_on: True
-seed: 45
+seed: 42
 
-feature_position: False # make sure it works when this is the case
+feature_position: False # make sure it works when this is the case
 weighted: False
 
-num_hidden: 500 # @param
-num_layers: 2 # @param
+num_hidden: 160 # @param
+num_layers: 1 # @param
 num_message_passing_steps: 3 # @param
-learning_rate: 0.001 # @param
-num_training_steps: 10 # @param
+learning_rate: 0.00001 # @param
+num_training_steps: 20 # @param
+residual: False
+layer_norm: False
 
 # Env Stuff
-batch_size: 6
-nx_min: 2
-nx_max: 5
+batch_size: 2
+nx_min: 5
+nx_max: 6
 
-batch_size_test: 6
-nx_min_test: 2
-nx_max_test: 5
+batch_size_test: 2
+nx_min_test: 5
+nx_max_test: 6
@@ -1,24 +1,25 @@
-experiment_name: 'Learn identity no resample'
-train_on_shortest_path: False # make sure it works when this is the case
-resample: False # @param
+experiment_name: 'Large_sp_re_ '
+train_on_shortest_path: True # make sure it works when this is the case
+resample: True # @param
 wandb_on: True
-seed: 46
+seed: 41
 
-feature_position: False # make sure it works when this is the case
+feature_position: False # make sure it works when this is the case
 weighted: False
 
-num_hidden: 500 # @param
-num_layers: 2 # @param
+num_hidden: 160 # @param
+num_layers: 1 # @param
 num_message_passing_steps: 3 # @param
-learning_rate: 0.001 # @param
-num_training_steps: 10 # @param
+learning_rate: 0.00001 # @param
+num_training_steps: 20 # @param
+residual: False
+layer_norm: False
 
 # Env Stuff
-batch_size: 6
-nx_min: 2
-nx_max: 5
+batch_size: 2
+nx_min: 5
+nx_max: 6
 
-batch_size_test: 6
-nx_min_test: 2
-nx_max_test: 5
+batch_size_test: 2
+nx_min_test: 5
+nx_max_test: 6
10 changes: 6 additions & 4 deletions neuralplayground/agents/domine_2023_extras/class_config.yaml
@@ -1,4 +1,4 @@
-experiment_name: 'Learn identity no resample with Feature position '
+experiment_name: 'Large_sp_re_ '
 train_on_shortest_path: True # make sure it works when this is the case
 resample: True # @param
 wandb_on: False
@@ -9,9 +9,11 @@ weighted: False
 
 num_hidden: 160 # @param
 num_layers: 1 # @param
-num_message_passing_steps: 2 # @param
-learning_rate: 0.0001 # @param
-num_training_steps: 200 # @param
+num_message_passing_steps: 3 # @param
+learning_rate: 0.00001 # @param
+num_training_steps: 10 # @param
+residual: False
+layer_norm: False
 
 # Env Stuff
 batch_size: 2
@@ -8,6 +8,10 @@ class ConfigTemplate:
                 name="resample",
                 types=[bool],
             ),
+            config_field.Field(
+                name="residual",
+                types=[bool],
+            ),
             config_field.Field(
                 name="train_on_shortest_path",
                 types=[bool],
@@ -148,5 +152,13 @@ class ConfigTemplate:
                 name="num_training_steps",
                 types=[float, int],
             ),
+            config_field.Field(
+                name="residual",
+                types=[bool],
+            ),
+            config_field.Field(
+                name="layer_norm",
+                types=[bool],
+            ),
         ],
     )
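The configuration changes above all follow the same pattern: the six per-seed YAML files (identical except for the seed value) move to a smaller, shallower network (num_hidden 500 -> 160, num_layers 2 -> 1), a lower learning rate, and smaller batches of larger graphs, and, like class_config.yaml, gain two new boolean keys, residual and layer_norm, which class_config_template.py now registers as bool fields. A minimal, hedged sketch of how such a config might be read and handed to the model builder; the filename and the wiring are illustrative assumptions, not repository code:

# Hedged sketch (not repository code): read one of the YAML configs above and
# pass the new flags to the model builder. The filename is hypothetical.
import yaml  # PyYAML

from neuralplayground.agents.domine_2023_extras.class_models import get_forward_function

with open("class_config.yaml") as f:  # or one of the per-seed configs
    cfg = yaml.safe_load(f)

forward_fn = get_forward_function(
    num_hidden=cfg["num_hidden"],
    num_layers=cfg["num_layers"],
    num_message_passing_steps=cfg["num_message_passing_steps"],
    add_residual=cfg["residual"],      # new key added in this commit
    use_layer_norm=cfg["layer_norm"],  # new key added in this commit
)
# forward_fn would then typically be wrapped with hk.transform before init/apply.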
35 changes: 28 additions & 7 deletions neuralplayground/agents/domine_2023_extras/class_models.py
@@ -4,7 +4,7 @@
 # TODO(clementine): set up object oriented GNN classes (eventually)
 
 
-def get_forward_function(num_hidden, num_layers, num_message_passing_steps):
+def get_forward_function(num_hidden, num_layers, num_message_passing_steps, add_residual=True, use_layer_norm=False):
     """Get function that performs a forward call on a simple GNN."""
 
     def _forward(x):
@@ -24,25 +24,46 @@ def _forward(x):
 
         # Apply rounds of message passing.
         message_passing = []
+        layer_output = []
        for n in range(num_message_passing_steps):
-            x = message_passing_layer(x, edge_mlp_sizes, node_mlp_sizes)
+            if add_residual:
+                previous_x = x  # Store the current state for the residual connection
+                x = message_passing_layer(x, edge_mlp_sizes, node_mlp_sizes, use_layer_norm)
+                x = x._replace(nodes=x.nodes + previous_x.nodes, edges=x.edges + previous_x.edges)
+                # layer_output += message_passing_layer(x, edge_mlp_sizes, node_mlp_sizes, use_layer_norm)
+            else:
+                x = message_passing_layer(x, edge_mlp_sizes, node_mlp_sizes, use_layer_norm)
             message_passing.append(x)
 
+        # x = message_passing_layer(x, edge_mlp_sizes, node_mlp_sizes)
+        # message_passing.append(x)
         # Map features to desired feature size.
         x = jraph.GraphMapFeatures(
             embed_edge_fn=hk.Linear(output_size=edge_output_size),
             embed_node_fn=hk.Linear(output_size=node_output_size),
         )(x)
 
         return x, message_passing
 
     return _forward
 
+def mlp(edge_mlp_sizes, use_layer_norm):
+    sequential_modules = [hk.nets.MLP(output_sizes=edge_mlp_sizes)]
+    if use_layer_norm:
+        # TODO: Clementine Domine check if this is the right axis
+        sequential_modules.append(hk.LayerNorm(axis=-1, param_axis=-1, create_scale=True, create_offset=True))
+    return hk.Sequential(sequential_modules)
+
-def message_passing_layer(x, edge_mlp_sizes, node_mlp_sizes):
-    update_edge_fn = jraph.concatenated_args(hk.nets.MLP(output_sizes=edge_mlp_sizes))
-    update_node_fn = jraph.concatenated_args(hk.nets.MLP(output_sizes=node_mlp_sizes))
+def message_passing_layer(x, edge_mlp_sizes, node_mlp_sizes, use_layer_norm):
+    update_edge_fn = jraph.concatenated_args(mlp(edge_mlp_sizes, use_layer_norm))
+    update_node_fn = jraph.concatenated_args(mlp(node_mlp_sizes, use_layer_norm))
     x = jraph.GraphNetwork(
         update_edge_fn=update_edge_fn, update_node_fn=update_node_fn
     )(x)
     return x
+
+# def message_passing_layer(x, edge_mlp_sizes, node_mlp_sizes):
+#     update_edge_fn = jraph.concatenated_args(hk.nets.MLP(output_sizes=edge_mlp_sizes))
+#     update_node_fn = jraph.concatenated_args(hk.nets.MLP(output_sizes=node_mlp_sizes))
+#     x = jraph.GraphNetwork(
+#         update_edge_fn=update_edge_fn, update_node_fn=update_node_fn
+#     )(x)
+#     return x
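The substantive change in class_models.py is that each message-passing step can now apply LayerNorm after its edge/node MLPs and add a residual (skip) connection from the step's input to its output. Below is a self-contained sketch of that pattern, not the repository's exact code: the toy graph, hidden size, and the initial embedding are illustrative assumptions (the residual addition only type-checks once node and edge features share a common hidden size, so the sketch embeds them first).

# Hedged sketch of residual + LayerNorm message passing with haiku/jraph.
import haiku as hk
import jax
import jax.numpy as jnp
import jraph


def mlp_block(output_sizes, use_layer_norm):
    # MLP optionally followed by LayerNorm, mirroring the mlp() helper in the diff above.
    modules = [hk.nets.MLP(output_sizes=output_sizes)]
    if use_layer_norm:
        modules.append(hk.LayerNorm(axis=-1, param_axis=-1, create_scale=True, create_offset=True))
    return hk.Sequential(modules)


def message_passing_step(graph, mlp_sizes, use_layer_norm):
    # One jraph GraphNetwork step; concatenated_args concatenates the per-edge /
    # per-node inputs before feeding them to the MLP.
    return jraph.GraphNetwork(
        update_edge_fn=jraph.concatenated_args(mlp_block(mlp_sizes, use_layer_norm)),
        update_node_fn=jraph.concatenated_args(mlp_block(mlp_sizes, use_layer_norm)),
    )(graph)


def forward(graph, hidden=16, num_steps=3, add_residual=True, use_layer_norm=True):
    # Embed node/edge features to a shared hidden size so the residual addition
    # below is shape-compatible (an assumption of this sketch).
    graph = jraph.GraphMapFeatures(
        embed_node_fn=hk.Linear(hidden), embed_edge_fn=hk.Linear(hidden)
    )(graph)
    for _ in range(num_steps):
        if add_residual:
            previous = graph  # step input, reused as the skip connection
            graph = message_passing_step(graph, [hidden, hidden], use_layer_norm)
            graph = graph._replace(
                nodes=graph.nodes + previous.nodes,
                edges=graph.edges + previous.edges,
            )
        else:
            graph = message_passing_step(graph, [hidden, hidden], use_layer_norm)
    return graph


# Toy usage on a made-up 4-node, 5-edge graph.
toy = jraph.GraphsTuple(
    nodes=jnp.ones((4, 3)),
    edges=jnp.ones((5, 2)),
    senders=jnp.array([0, 1, 2, 3, 0]),
    receivers=jnp.array([1, 2, 3, 0, 2]),
    n_node=jnp.array([4]),
    n_edge=jnp.array([5]),
    globals=None,
)
net = hk.without_apply_rng(hk.transform(forward))
params = net.init(jax.random.PRNGKey(0), toy)
out = net.apply(params, toy)
print(out.nodes.shape)  # (4, 16)

As in the diff, the skip connection wraps the whole GraphNetwork block (MLP plus optional LayerNorm), and LayerNorm is applied over the last feature axis of each node or edge vector.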
@@ -282,5 +282,4 @@ def plot_curves(curves, path, title, legend_labels=None, x_label=None, y_label=N
     if legend_labels:
         ax.legend()
     plt.savefig(path)
-    plt.show()
     plt.close()
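Finally, plot_curves no longer calls plt.show(); figures are only written to disk with plt.savefig and then closed, which suits headless or batch runs. A tiny hedged sketch of the same save-and-close pattern (the Agg backend, helper name, and filename are illustrative, not from the repository):

# Minimal stand-in for the post-commit behaviour: save the figure, never show it.
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; plots only go to disk
import matplotlib.pyplot as plt


def save_curve(values, path, title):
    fig, ax = plt.subplots()
    ax.plot(values)
    ax.set_title(title)
    plt.savefig(path)
    plt.close()


save_curve([3.0, 2.1, 1.4, 1.1], "loss_curve.png", "training loss")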
