Commit: letting it go

ClementineDomine committed Oct 3, 2023
1 parent 312ab59 commit 9660e05
Showing 1 changed file with 54 additions and 37 deletions.

neuralplayground/agents/domine_2023.py (91 changes: 54 additions & 37 deletions)
@@ -1,3 +1,6 @@
+# TODO: NOTE to self: this is a work in progress and has not been tested; I don't think JAX is a good fit for object-oriented code.
+# If I want to implement it here, I should use NeuralPlayground, which would mean PyTorch.
+
 import argparse
 import os
 import shutil
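
The TODO above questions whether JAX suits object-oriented code. The pattern this commit moves toward, keeping the transformed network, parameters, and optimizer state as attributes while the forward function stays pure, is workable in plain JAX. A minimal standalone sketch of that pattern (forward_fn, class name, and shapes are illustrative, not the repository's actual graph network):

    # Sketch: Haiku/optax state held on a class; the forward function stays pure.
    # All names and shapes here are hypothetical.
    import haiku as hk
    import jax
    import jax.numpy as jnp
    import optax

    def forward_fn(x):
        return hk.nets.MLP([32, 1])(x)  # stand-in for the graph network

    class SketchAgent:
        def __init__(self, seed=0, learning_rate=3e-4):
            # transform the pure function once, then store params/opt_state
            self.net = hk.without_apply_rng(hk.transform(forward_fn))
            self.params = self.net.init(jax.random.PRNGKey(seed), jnp.zeros((1, 4)))
            self.optimizer = optax.adam(learning_rate)
            self.opt_state = self.optimizer.init(self.params)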
@@ -125,8 +128,27 @@ def __init__ (
         os.mkdir(os.path.join(save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M")))
         self.save_path = os.path.join(
             os.path.join(save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M")))
+        self.reset()
+        self.saving_run_parameters()
+        forward = get_forward_function(self.num_hidden, self.num_layers, self.num_message_passing_steps)
+        self.net_hk = hk.without_apply_rng(hk.transform(forward))
+        self.rng = jax.random.PRNGKey(self.seed)
+        self.rng_seq = rng_sequence_from_rng(self.rng)
+
+        if self.train_on_shortest_path:
+            self.graph, self.targets = sample_padded_grid_batch_shortest_path(
+                self.rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max
+            )
+        else:
+            self.graph, self.targets = sample_padded_grid_batch_shortest_path(
+                self.rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max
+            )
+        self.params = self.net_hk.init(self.rng, self.graph)
+        self.optimizer = optax.adam(self.learning_rate)
+        self.opt_state = self.optimizer.init(self.params)
+
     def saving_run_parameters(self):
         # SAVING Training Files
         path = os.path.join(self.save_path, "run.py")
         HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023.py")
         shutil.copyfile(HERE, path)
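
Note that the two branches of the if/else added above call sample_padded_grid_batch_shortest_path with identical arguments, so train_on_shortest_path currently has no effect at initialisation. Separately, rng_seq comes from rng_sequence_from_rng; judging by the next(self.rng_seq) call in update below, it yields a fresh PRNG subkey per iteration. A guess at a minimal equivalent, in case the helper's behaviour is unclear (assumed implementation, not the repository's):

    # Assumed behaviour of rng_sequence_from_rng: an infinite stream of subkeys
    # split off a root PRNGKey, so each training step gets fresh randomness.
    import jax

    def rng_sequence_from_rng(rng):
        while True:
            rng, subkey = jax.random.split(rng)
            yield subkey

    rng_seq = rng_sequence_from_rng(jax.random.PRNGKey(0))
    key = next(rng_seq)  # fresh key, e.g. for sampling a new graph batch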
@@ -148,7 +170,7 @@ def __init__ (
         shutil.copyfile(HERE, path)
 
     # This is the function that does the forward pass of the model
-        self.reset()
+
 
     def evaluate(self, model, params, inputs, target):
         outputs = model.apply(params, inputs)
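
evaluate returns the model outputs plus a ROC AUC and a Matthews correlation coefficient, which are appended to the lists set up in reset below. A hedged sketch of how such metrics are typically computed, assuming scikit-learn; the repository's actual implementation may differ:

    # Illustrative metric computation on flattened node scores vs. binary
    # targets (hypothetical data; assumes scikit-learn is available).
    import numpy as np
    from sklearn.metrics import matthews_corrcoef, roc_auc_score

    targets = np.array([0, 1, 1, 0, 1])
    outputs = np.array([0.2, 0.9, 0.6, 0.4, 0.7])  # model score per node
    roc_auc = roc_auc_score(targets, outputs)               # threshold-free
    mcc = matthews_corrcoef(targets, (outputs > 0.5).astype(int))  # thresholded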
@@ -167,40 +189,35 @@ def reset(self,a=1):
         self.MCCs_train = []
         self.MCCs_test = []
         self.roc_aucs_test = []
+        return
 
 
-        return
+    def compute_loss(self, params, inputs, targets):
+        # not jitted because it will get jitted in jax.value_and_grad
+        outputs = self.net_hk.apply(params, inputs)
+        return jnp.mean((outputs[0].nodes - targets) ** 2)  # using MSE
+
+    # @jax.jit
+    def update_step(self, grads, opt_state, params):
+        updates, opt_state = self.optimizer.update(grads, opt_state, params)
+        params = optax.apply_updates(params, updates)
+        return params
 
     def update(self):
-        forward = get_forward_function(self.num_hidden, self.num_layers, self.num_message_passing_steps)
-        net_hk = hk.without_apply_rng(hk.transform(forward))
-        self.rng = jax.random.PRNGKey(self.seed)
-        self.rng_seq = rng_sequence_from_rng(self.rng)
-
-        if self.train_on_shortest_path:
-            graph, targets = sample_padded_grid_batch_shortest_path(
-                self.rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max
-            )
-        else:
-            graph, targets = sample_padded_grid_batch_shortest_path(
-                self.rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max
-            )
-        self.params = net_hk.init(self.rng, graph)
-        optimizer = optax.adam(self.learning_rate)
-        self.opt_state = optimizer.init(self.params)
-
-        @jax.jit
-        def compute_loss(params, inputs, targets):
-            # not jitted because it will get jitted in jax.value_and_grad
-            outputs = net_hk.apply(params, inputs)
-            return jnp.mean((outputs[0].nodes - targets) ** 2)  # using MSE
+        # @jax.jit
+        # def compute_loss(params, inputs, targets):
+        #     # not jitted because it will get jitted in jax.value_and_grad
+        #     outputs = net_hk.apply(params, inputs)
+        #     return jnp.mean((outputs[0].nodes - targets) ** 2)  # using MSE
 
-        @jax.jit
-        def update_step(grads, opt_state, params, ):
-            updates, opt_state = optimizer.update(grads, opt_state, params)
-            params = optax.apply_updates(params, updates)
-            return params
+        # @jax.jit
+        # def update_step(grads, opt_state, params, ):
+        #     updates, opt_state = optimizer.update(grads, opt_state, params)
+        #     params = optax.apply_updates(params, updates)
+        #     return params
 
         rng = next(self.rng_seq)
         graph_test, target_test = sample_padded_grid_batch_shortest_path(
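
One thing worth flagging in the new update_step method above: self.optimizer.update returns a fresh optimizer state, but only params is returned and self.opt_state is never reassigned, so Adam's moment estimates stay frozen at their initial values. A minimal standalone sketch of the usual optax step (hypothetical loss function and shapes), threading both params and opt_state forward:

    # Sketch: a complete optax training step must carry opt_state forward.
    import jax
    import jax.numpy as jnp
    import optax

    def loss_fn(params, x, y):
        pred = x @ params["w"]            # stand-in for net_hk.apply
        return jnp.mean((pred - y) ** 2)  # same MSE form as compute_loss

    optimizer = optax.adam(1e-3)
    params = {"w": jnp.zeros((3, 1))}
    opt_state = optimizer.init(params)
    x, y = jnp.ones((4, 3)), jnp.ones((4, 1))

    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)  # both params AND opt_state advance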
@@ -210,30 +227,30 @@ def update_step(grads, opt_state, params, ):
         # Sample a new batch of graphs every iteration
         if self.resample:
             if self.train_on_shortest_path:
-                graph, targets = sample_padded_grid_batch_shortest_path(
+                self.graph, self.targets = sample_padded_grid_batch_shortest_path(
                     rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max
                 )
             else:
-                graph, targets = sample_padded_grid_batch_shortest_path(
+                self.graph, self.targets = sample_padded_grid_batch_shortest_path(
                     rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max
                 )
-                targets = graph.nodes
+                self.targets = self.graph.nodes
         # Train
-        loss, grads = jax.value_and_grad(compute_loss)(
-            self.params, graph, targets
+        loss, grads = jax.value_and_grad(self.compute_loss)(
+            self.params, self.graph, self.targets
         )  # jits inside of value_and_grad
-        self.params = update_step(grads, self.opt_state, self.params)
+        self.params = self.update_step(grads, self.opt_state, self.params)
         self.losses.append(loss)
-        outputs_train, roc_auc_train, MCC_train = self.evaluate(net_hk, self.params, graph, targets)
+        outputs_train, roc_auc_train, MCC_train = self.evaluate(self.net_hk, self.params, self.graph, self.targets)
         self.roc_aucs_train.append(roc_auc_train)
         self.MCCs_train.append(MCC_train)  # Matthews correlation coefficient
         # Test # model should basically learn to do nothing from this
-        loss_test = compute_loss(self.params, graph_test, target_test)
+        loss_test = self.compute_loss(self.params, graph_test, target_test)
         self.losses_test.append(loss_test)
-        outputs_test, roc_auc_test, MCC_test = self.evaluate(net_hk, self.params, graph_test, target_test)
+        outputs_test, roc_auc_test, MCC_test = self.evaluate(self.net_hk, self.params, graph_test, target_test)
         self.roc_aucs_test.append(roc_auc_test)
         self.MCCs_test.append(MCC_test)
-        self.net_hk = net_hk
+        self.net_hk = self.net_hk
 
         # Log
         wandb_logs = {"loss": loss, "losses_test": loss_test, "roc_auc_test": roc_auc_test, "roc_auc": roc_auc_train}
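
Two notes on the tail of this hunk: self.net_hk = self.net_hk is now a no-op, since the network is already built and stored in __init__, so the line could simply be dropped; and the wandb_logs dictionary is presumably passed to wandb.log in the collapsed lines below. For reference, logging a dictionary of scalars looks like this (assuming wandb.init was called during agent setup; the values here are placeholders):

    import wandb

    # Placeholder values; in the agent these come from the training step above.
    wandb.log({"loss": 0.42, "losses_test": 0.55, "roc_auc_test": 0.71, "roc_auc": 0.78})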
