Skip to content

Commit

Permalink
No resampling of the test set
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementineDomine committed Nov 12, 2023
1 parent e05067d commit 9a6a7d9
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 68 deletions.
183 changes: 119 additions & 64 deletions neuralplayground/agents/domine_2023.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
rng_sequence_from_rng,
set_device,
update_outputs_test,
get_activations_graph_n
)
from sklearn.metrics import matthews_corrcoef, roc_auc_score

Expand Down Expand Up @@ -62,8 +63,6 @@ def __init__( # autogenerated
batch_size_test: int = 4,
nx_min_test: int = 4,
nx_max_test: int = 7,


**mod_kwargs,
):
self.plot=True
Expand Down Expand Up @@ -133,7 +132,29 @@ def __init__( # autogenerated
self.nx_min,
self.nx_max,
)
rng = next(self.rng_seq)
self.graph_test, self.target_test = sample_padded_grid_batch_shortest_path(
rng,
self.batch_size,
self.feature_position,
self.weighted,
self.nx_min,
self.nx_max,
)

else:
self.graph_test, self.target_test = sample_padded_grid_batch_shortest_path(
rng,
self.batch_size_test,
self.feature_position,
self.weighted,
self.nx_min_test,
self.nx_max_test,
)
self.target_test = np.reshape(
self.graph_test.nodes[:, 0], (self.graph_test.nodes[:, 0].shape[0], -1)
)
rng = next(self.rng_seq)
self.graph, self.targets = sample_padded_grid_batch_shortest_path(
rng,
self.batch_size,
Expand All @@ -142,9 +163,26 @@ def __init__( # autogenerated
self.nx_min,
self.nx_max,
)

if self.feature_position:
self.indices_test = np.where(self.graph_test.nodes[:, 0] == 1)[0]
self.target_test_wse = self.target_test - np.reshape(
self.graph_test.nodes[:, 0], (self.graph_test.nodes[:, 0].shape[0], -1)
)
self.target_wse = self.targets - np.reshape(
self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1)
)
else:
self.indices_train = np.where(self.graph.nodes[:] == 1)[0]
self.indices_test = np.where(self.graph_test.nodes[:] == 1)[0]
self.target_test_wse = self.target_test - self.graph_test.nodes[:]
self.target_wse = self.targets - self.graph.nodes[:]


forward = get_forward_function(
self.num_hidden, self.num_layers, self.num_message_passing_steps, self.residuals,self.layer_norm
)

net_hk = hk.without_apply_rng(hk.transform(forward))
params = net_hk.init(rng, self.graph)
param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
Expand All @@ -154,18 +192,46 @@ def __init__( # autogenerated
opt_state = optimizer.init(self.params)
self.opt_state = opt_state


def compute_loss(params, inputs, targets):
outputs = net_hk.apply(params, inputs)
def compute_loss(params, graph, targets):
    """Mean squared error between the network's node outputs and targets.

    Closes over `net_hk` (the transformed Haiku network). `outputs[0]`
    is the graph returned by the network; its `.nodes` are compared
    elementwise against `targets` and averaged over all nodes.
    """
    outputs = net_hk.apply(params, graph)
    return jnp.mean((outputs[0].nodes - targets) ** 2)

# JIT-compiled once here; reused for train/test/wse losses in update().
self._compute_loss = jax.jit(compute_loss)

def compute_loss_per_graph(params, inputs, targets):
outputs = net_hk.apply(params, inputs)
#for each graph:
def compute_loss_per_node(params, graph, targets):
    """Per-node squared error (no reduction).

    Same computation as `compute_loss` but without the mean, so the
    error distribution over individual nodes can be logged/inspected.
    """
    outputs = net_hk.apply(params, graph)
    return (outputs[0].nodes - targets) ** 2

self._compute_loss_per_node = jax.jit(compute_loss_per_node)

def compute_loss_per_graph(params, graph, targets):
    """Mean squared error computed separately for each graph in the batch.

    Uses `get_activations_graph_n` to slice the flat, padded node array
    back into per-graph node sets, then averages the squared error
    within each graph.

    Returns a two-element list:
      [per-graph losses as a plain list, node counts of those graphs].

    Not jitted: the Python loop and list-building return value are not
    JAX-traceable as written.

    NOTE(review): the loop runs over `self.batch_size` (the *training*
    batch size), but in update() this function is applied to
    `self.graph_test`, which may contain `batch_size_test` graphs —
    confirm the two sizes are meant to agree.
    """
    outputs = net_hk.apply(params, graph)
    loss_per_graph=[]
    for i in range(self.batch_size):
        graph_outputs= get_activations_graph_n(np.squeeze(outputs[0].nodes), graph, i)
        # targets.sum(-1) collapses the trailing feature axis to one value per node.
        graph_target = get_activations_graph_n((targets.sum(-1)), graph,i)
        loss_per_graph.append(jnp.mean((graph_outputs - graph_target) ** 2))

    return [np.squeeze(loss_per_graph).tolist() , graph.n_node[0:self.batch_size].tolist()]

self._compute_loss_per_graph = compute_loss_per_graph

def compute_loss_nodes_shortest_path(params, graph, targets):
    """Per-graph loss restricted to the shortest-path nodes.

    For each graph in the batch, selects the nodes whose target value
    is 1 (i.e. nodes on the shortest path) and computes the squared
    error only over those nodes.

    Returns a 1-D array: the first `batch_size` entries are the
    per-graph losses, the remaining `batch_size` entries are the
    corresponding shortest-path lengths.

    NOTE(review): `jnp.mean(...)` already averages over the selected
    nodes, so the extra `/len(indices_train)` divides by the path
    length a second time — confirm this double normalization is
    intended. Also raises ZeroDivisionError if a graph has no target-1
    nodes, and iterates `self.batch_size` graphs even when applied to
    the test batch (see compute_loss_per_graph) — verify.
    """
    outputs = net_hk.apply(params, graph)
    loss_per_graph=[]
    len_shortest_path=[]
    for i in range(self.batch_size):
        graph_outputs= get_activations_graph_n(np.squeeze(outputs[0].nodes), graph, i)
        graph_target = get_activations_graph_n((targets.sum(-1)), graph,i)
        # Indices of nodes on the shortest path (target value == 1).
        indices_train = np.where(graph_target == 1)[0]
        len_shortest_path.append(len(indices_train))
        loss_per_graph.append(jnp.mean((graph_outputs[indices_train] - graph_target[indices_train]) ** 2)/len(indices_train))
    return np.concatenate((np.squeeze(loss_per_graph),np.asarray(len_shortest_path)),axis=0)

self._compute_loss_nodes_shortest_path = compute_loss_nodes_shortest_path

# def compute_loss_per_node(params, inputs, targets):
# outputs = net_hk.apply(params, inputs)
# return (outputs[0].nodes - targets) ** 2

def update_step(params, opt_state):
loss, grads = jax.value_and_grad(compute_loss)(
Expand Down Expand Up @@ -203,7 +269,6 @@ def evaluate(params, inputs, target,wse_value=True,indices=None):
wandb_logs = {
"train_on_shortest_path": train_on_shortest_path,
"resample": resample,

"batch_size_test": batch_size_test,
"nx_min_test": nx_min_test, # This is thought of the state density
"nx_max_test": nx_max_test, # This is thought of the state density
Expand Down Expand Up @@ -278,6 +343,9 @@ def reset(self, a=1):
self.global_steps = 0
self.losses_train = []
self.losses_test = []
self.losses_per_node_test = []
self.losses_per_graph_test = []
self.losses_per_shortest_path_test=[]
self.losses_train_wse = []
self.losses_test_wse = []
self.roc_aucs_train = []
Expand All @@ -290,18 +358,8 @@ def reset(self, a=1):

def update(self):
rng = next(self.rng_seq)
if self.train_on_shortest_path:
self.graph_test, self.target_test = sample_padded_grid_batch_shortest_path(
rng,
self.batch_size,
self.feature_position,
self.weighted,
self.nx_min,
self.nx_max,
)
rng = next(self.rng_seq)

if self.resample:
if self.resample:
if self.train_on_shortest_path:
self.graph, self.targets = sample_padded_grid_batch_shortest_path(
rng,
self.batch_size,
Expand All @@ -310,48 +368,29 @@ def update(self):
self.nx_min,
self.nx_max,
)
else:
self.graph_test, self.target_test = sample_padded_grid_batch_shortest_path(
rng,
self.batch_size_test,
self.feature_position,
self.weighted,
self.nx_min_test,
self.nx_max_test,
)
self.target_test = np.reshape(
self.graph_test.nodes[:, 0], (self.graph_test.nodes[:, 0].shape[0], -1)
)
else:
rng = next(self.rng_seq)
# Sample

rng = next(self.rng_seq)
# Sample
if self.resample:
self.graph, self.targets = sample_padded_grid_batch_shortest_path(
rng,
self.batch_size,
self.feature_position,
self.weighted,
self.nx_min,
self.nx_max,
rng,
self.batch_size,
self.feature_position,
self.weighted,
self.nx_min,
self.nx_max,
)
self.targets = np.reshape(
self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1)
)
self.targets = np.reshape(
self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1)
)

if self.feature_position:
indices_train = np.where(self.graph.nodes[:, 0] == 1)[0]
indices_test = np.where(self.graph_test.nodes[:, 0] == 1)[0]
self.target_test_wse = self.target_test - np.reshape(
self.graph_test.nodes[:, 0], (self.graph_test.nodes[:, 0].shape[0], -1)
)
self.target_wse = self.targets - np.reshape(
if self.feature_position:
self.target_wse = self.targets - np.reshape(
self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1)
)
else:
indices_train = np.where( self.graph.nodes[:]== 1)[0]
indices_test = np.where(self.graph_test.nodes[:] == 1)[0]
self.target_test_wse = self.target_test - self.graph_test.nodes[:]
self.target_wse = self.targets - self.graph.nodes[:]
)
else:
self.indices_train = np.where( self.graph.nodes[:]== 1)[0]
self.target_wse = self.targets - self.graph.nodes[:]

# Train
self.params, self.opt_state, loss = self._update_step(
Expand All @@ -368,12 +407,19 @@ def update(self):
loss_wse = self._compute_loss(self.params, self.graph, self.target_wse)
self.losses_train_wse.append(loss_wse)
outputs_train_wse_wrong, roc_auc_train_wse, MCC_train_wse = self._evaluate(
self.params, self.graph, self.target_wse, False, indices_train
self.params, self.graph, self.target_wse, False, self.indices_train
)
self.outputs_train_wse = update_outputs_test(outputs_train_wse_wrong, indices_train)
self.outputs_train_wse = update_outputs_test(outputs_train_wse_wrong, self.indices_train)
self.MCCs_train_wse.append(MCC_train_wse)

# Test
loss_test_per_node = self._compute_loss_per_node(self.params, self.graph_test, self.target_test)
loss_test_per_graph = self._compute_loss_per_graph(self.params, self.graph_test, self.target_test)
loss_nodes_shortest_path = self._compute_loss_nodes_shortest_path(self.params, self.graph_test, self.target_test)
self.losses_per_node_test.append(np.squeeze(loss_test_per_node))
self.losses_per_graph_test.append((loss_test_per_graph))
self.losses_per_shortest_path_test.append(loss_nodes_shortest_path)

loss_test = self._compute_loss(self.params, self.graph_test, self.target_test)
self.losses_test.append(loss_test)
self.outputs_test, roc_auc_test, MCC_test = self._evaluate(
Expand All @@ -386,9 +432,9 @@ def update(self):
loss_test_wse = self._compute_loss(self.params, self.graph_test, self.target_test_wse)
self.losses_test_wse.append(loss_test_wse)
outputs_test_wse_wrong, roc_auc_test_wse, MCC_test_wse = self._evaluate(
self.params, self.graph_test, self.target_test_wse, False, indices_test
self.params, self.graph_test, self.target_test_wse, False, self.indices_test
)
self.outputs_test_wse = update_outputs_test(outputs_test_wse_wrong, indices_test)
self.outputs_test_wse = update_outputs_test(outputs_test_wse_wrong, self.indices_test)
self.MCCs_test_wse.append(MCC_test_wse)

# Log
Expand Down Expand Up @@ -420,7 +466,6 @@ def update(self):
f"Training step {self.global_steps}: log_loss = {np.log(loss)} , log_loss_test = {np.log(loss_test)}, roc_auc_test = {roc_auc_test}, roc_auc_train = {roc_auc_train}"
)


if self.global_steps == self.num_training_steps:
if self.wandb_on:
with open("readme.txt", "w") as f:
Expand Down Expand Up @@ -476,6 +521,16 @@ def plot_learning_curves(self,trainning_step):
os.path.join(self.save_path, "losses_test_"+trainning_step+".pdf"),
"losses_test",
)

plot_curves([self.losses_per_node_test[0]], os.path.join(self.save_path, "losses_per_node_test_" + trainning_step + ".pdf"),
"Losses")
plot_curves([self.losses_per_graph_test[0]],
os.path.join(self.save_path, "Losses_per_graph_test_" + trainning_step + ".pdf"),
"Losses",[self.losses_per_graph_test[1]])
plot_curves([self.losses_per_shortest_path_test[0]],
os.path.join(self.save_path, "Losses_per_shortest_path_test_" + trainning_step + ".pdf"),
"Losses")

plot_curves(
[self.losses_train_wse],
os.path.join(self.save_path, "Losses_wse_"+trainning_step+".pdf"),
Expand Down
8 changes: 4 additions & 4 deletions neuralplayground/agents/domine_2023_seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,7 @@



import statistics
print(statistics.mean([agent_1.roc_aucs_train,agent_2.roc_aucs_train]))
print(statistics.stdev([agent_1.roc_aucs_train,agent_2.roc_aucs_train]))


parser = argparse.ArgumentParser()
parser.add_argument(
Expand Down Expand Up @@ -214,7 +212,9 @@
agent_5.update()



import statistics
print(statistics.mean([agent_1.roc_aucs_train[-1],agent_2.roc_aucs_train[-1]]))
print(statistics.stdev([agent_1.roc_aucs_train[-1],agent_2.roc_aucs_train[-1]]))
import statistics

# Summarize the train ROC-AUC across the five seeded agents.
# BUG FIX: each agent's `roc_aucs_train` is a list (one entry per
# training step), and statistics.mean/stdev require numbers — passing
# the lists themselves raises TypeError. Take the final recorded value
# from each seed instead, matching the two-agent summary printed above.
final_train_roc_aucs = [
    agent_1.roc_aucs_train[-1],
    agent_2.roc_aucs_train[-1],
    agent_3.roc_aucs_train[-1],
    agent_4.roc_aucs_train[-1],
    agent_5.roc_aucs_train[-1],
]
print(statistics.mean(final_train_roc_aucs))
print(statistics.stdev(final_train_roc_aucs))

0 comments on commit 9a6a7d9

Please sign in to comment.