Skip to content

Commit

Permalink
fixing bug with the wse
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementineDomine committed Oct 26, 2023
1 parent 623462d commit 4c5a9fa
Show file tree
Hide file tree
Showing 9 changed files with 287 additions and 110 deletions.
211 changes: 125 additions & 86 deletions neuralplayground/agents/domine_2023.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
experiment_name: 'Learn identity no resample'
train_on_shortest_path: False # TODO: confirm training still works when this is set to True
resample: False # @param
wandb_on: True
seed: 42

feature_position: False # TODO: confirm training still works when this is set to True
weighted: False

num_hidden: 500 # @param
num_layers: 2 # @param
num_message_passing_steps: 3 # @param
learning_rate: 0.001 # @param
num_training_steps: 10 # @param


# Env Stuff
batch_size: 6
nx_min: 2
nx_max: 5

batch_size_test: 6
nx_min_test: 2
nx_max_test: 5
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
experiment_name: 'Learn identity no resample'
train_on_shortest_path: False # TODO: confirm training still works when this is set to True
resample: False # @param
wandb_on: True
seed: 42

feature_position: False # TODO: confirm training still works when this is set to True
weighted: False

num_hidden: 500 # @param
num_layers: 2 # @param
num_message_passing_steps: 3 # @param
learning_rate: 0.001 # @param
num_training_steps: 10 # @param


# Env Stuff
batch_size: 6
nx_min: 2
nx_max: 5

batch_size_test: 6
nx_min_test: 2
nx_max_test: 5
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
experiment_name: 'Learn identity no resample'
train_on_shortest_path: False # TODO: confirm training still works when this is set to True
resample: False # @param
wandb_on: True
seed: 43

feature_position: False # TODO: confirm training still works when this is set to True
weighted: False

num_hidden: 500 # @param
num_layers: 2 # @param
num_message_passing_steps: 3 # @param
learning_rate: 0.001 # @param
num_training_steps: 10 # @param


# Env Stuff
batch_size: 6
nx_min: 2
nx_max: 5

batch_size_test: 6
nx_min_test: 2
nx_max_test: 5
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
experiment_name: 'Learn identity no resample'
train_on_shortest_path: False # TODO: confirm training still works when this is set to True
resample: False # @param
wandb_on: True
seed: 4

feature_position: False # TODO: confirm training still works when this is set to True
weighted: False

num_hidden: 500 # @param
num_layers: 2 # @param
num_message_passing_steps: 3 # @param
learning_rate: 0.001 # @param
num_training_steps: 10 # @param


# Env Stuff
batch_size: 6
nx_min: 2
nx_max: 5

batch_size_test: 6
nx_min_test: 2
nx_max_test: 5
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
experiment_name: 'Learn identity no resample'
train_on_shortest_path: False # TODO: confirm training still works when this is set to True
resample: False # @param
wandb_on: True
seed: 45

feature_position: False # TODO: confirm training still works when this is set to True
weighted: False

num_hidden: 500 # @param
num_layers: 2 # @param
num_message_passing_steps: 3 # @param
learning_rate: 0.001 # @param
num_training_steps: 10 # @param


# Env Stuff
batch_size: 6
nx_min: 2
nx_max: 5

batch_size_test: 6
nx_min_test: 2
nx_max_test: 5
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
experiment_name: 'Learn identity no resample'
train_on_shortest_path: False # TODO: confirm training still works when this is set to True
resample: False # @param
wandb_on: True
seed: 46

feature_position: False # TODO: confirm training still works when this is set to True
weighted: False

num_hidden: 500 # @param
num_layers: 2 # @param
num_message_passing_steps: 3 # @param
learning_rate: 0.001 # @param
num_training_steps: 10 # @param


# Env Stuff
batch_size: 6
nx_min: 2
nx_max: 5

batch_size_test: 6
nx_min_test: 2
nx_max_test: 5
23 changes: 11 additions & 12 deletions neuralplayground/agents/domine_2023_extras/class_config.yaml
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
experiment_name: 'train_on_shortest_path'
experiment_name: 'Learn identity no resample with Feature position '
train_on_shortest_path: True # make sure it works when this is the case
resample: True # @param
resample: False # @param
wandb_on: False
seed: 42

feature_position: False # make sure it works when this is the case
feature_position: False # make sure it works when this is the case
weighted: False

num_hidden: 500 # @param
num_layers: 2 # @param
num_message_passing_steps: 4 # @param
num_message_passing_steps: 3 # @param
learning_rate: 0.001 # @param
num_training_steps: 10 # @param

num_training_steps: 200 # @param

# Env Stuff
batch_size: 4
nx_min: 3
nx_max: 5
batch_size: 6
nx_min: 2
nx_max: 6

batch_size_test: 4
nx_min_test: 3
nx_max_test: 5
batch_size_test: 6
nx_min_test: 2
nx_max_test: 6
19 changes: 7 additions & 12 deletions neuralplayground/agents/domine_2023_extras/class_plotting_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# @title Make rng sequence generator
import matplotlib.pyplot as plt

import networkx as nx
from neuralplayground.agents.domine_2023_extras.class_utils import (
convert_jraph_to_networkx_graph,
Expand All @@ -14,7 +15,7 @@ def plot_input_target_output(
# minim 2 otherwise it breaks
rows = ["{}".format(row) for row in ["Input", "Target", "Outputs"]]
fig, axes = plt.subplots(3, n)
fig.set_size_inches(8, 8)
fig.set_size_inches(15, 15)
for i in range(n):
nx_graph = convert_jraph_to_networkx_graph(graph, i)
pos = nx.spring_layout(nx_graph, iterations=100, seed=39775)
Expand Down Expand Up @@ -74,6 +75,7 @@ def plot_input_target_output(
axes.set_ylabel(row, rotation=0, size="large")
plt.suptitle(title)
plt.savefig(save_path)
plt.close()


def plot_message_passing_layers(
Expand All @@ -90,7 +92,7 @@ def plot_message_passing_layers(
):
# minim 2 otherwise it breaks
fig, axes = plt.subplots(n_message_passing + 3, n)
fig.set_size_inches(8, 8)
fig.set_size_inches(15, 15)
for j in range(n):
nx_graph = convert_jraph_to_networkx_graph(graph, j)
pos = nx.spring_layout(nx_graph, iterations=100, seed=39775)
Expand Down Expand Up @@ -161,6 +163,7 @@ def plot_message_passing_layers(
)
plt.suptitle(title)
plt.savefig(save_path)
plt.close()


def plot_graph_grid_activations(
Expand Down Expand Up @@ -210,7 +213,7 @@ def plot_message_passing_layers_units(
):
# minim 2 otherwise it breaks
fig, axes = plt.subplots(n_message_passing, number_hidden)
fig.set_size_inches(12, 12)
fig.set_size_inches(15, 15)
nx_graph = convert_jraph_to_networkx_graph(graph, 0)
pos = nx.spring_layout(nx_graph, iterations=100, seed=39775)
if edege_lables:
Expand Down Expand Up @@ -244,15 +247,6 @@ def plot_message_passing_layers_units(
plt.savefig(save_path)


def plot_xy(auc_roc, path, title):
fig = plt.figure(figsize=(5, 3))
ax = fig.add_subplot(111)
ax.title.set_text(title)
ax.plot(auc_roc)
plt.savefig(path)


import matplotlib.pyplot as plt


def plot_curves(curves, path, title, legend_labels=None, x_label=None, y_label=None):
Expand All @@ -276,3 +270,4 @@ def plot_curves(curves, path, title, legend_labels=None, x_label=None, y_label=N

plt.savefig(path)
plt.show()
plt.close()

0 comments on commit 4c5a9fa

Please sign in to comment.