Commit
fixbugs
ClementineDomine committed Nov 2, 2023
1 parent 4c5a9fa commit 3a99009
Showing 5 changed files with 181 additions and 110 deletions.
122 changes: 74 additions & 48 deletions neuralplayground/agents/domine_2023.py
@@ -26,10 +26,12 @@
plot_input_target_output,
plot_message_passing_layers,
plot_curves,

)
from neuralplayground.agents.domine_2023_extras.class_utils import (
rng_sequence_from_rng,
set_device,
update_outputs_test,
)
from sklearn.metrics import matthews_corrcoef, roc_auc_score

@@ -64,11 +66,10 @@ def __init__( # autogenerated
self.grad_history = []
self.train_on_shortest_path = train_on_shortest_path
self.experiment_name = experiment_name
self.train_on_shortest_path = train_on_shortest_path
self.resample = resample
self.wandb_on = wandb_on
self.seed = seed

self.seed = seed
self.feature_position = feature_position
self.weighted = weighted

@@ -79,10 +80,6 @@ def __init__( # autogenerated
self.num_training_steps = num_training_steps
# config.num_training_steps  # @param

self.batch_size = batch_size
self.nx_min = nx_min
self.nx_max = nx_max

# This can be thought of as the brain making representations at different granularities
# Could be explained during sleep
self.batch_size_test = batch_size_test
@@ -98,6 +95,8 @@ def __init__( # autogenerated
self.room_depth = np.diff(self.arena_y_limits)[0]
self.agent_step_size = 0



self.log_every = num_training_steps // 10
if self.weighted:
self.edge_lables = True
@@ -116,6 +115,31 @@ def __init__( # autogenerated
os.mkdir(os.path.join(save_path, "results"))
self.save_path = os.path.join(save_path, "results")

wandb_logs = {
"train_on_shortest_path": train_on_shortest_path,
"resample": resample,

"batch_size_test": batch_size_test,
"nx_min_test": nx_min_test, # This is thought of the state density
"nx_max_test": nx_max_test, # This is thought of the state density
"batch_size": batch_size,
"nx_min": nx_min, # This is thought of the state density
"nx_max": nx_max,

"seed": seed,
"feature_position": feature_position,
"weighted": weighted,

"num_hidden": num_hidden,
"num_layers": num_layers,
"num_message_passing_steps": num_message_passing_steps,
"learning_rate": learning_rate,
"num_training_steps": num_training_steps,
}

if self.wandb_on:
wandb.log(wandb_logs)

else:
dateTimeObj = datetime.now()
save_path = os.path.join(Path(os.getcwd()).resolve(), "results")
@@ -166,10 +190,16 @@ def __init__( # autogenerated

def compute_loss(params, inputs, targets):
outputs = net_hk.apply(params, inputs)
return jnp.mean((outputs[0].nodes - targets) ** 2) # using MSE
return jnp.mean((outputs[0].nodes - targets) ** 2)

self._compute_loss = jax.jit(compute_loss)

def compute_loss_per_graph(params, inputs, targets):
outputs = net_hk.apply(params, inputs)
# for each graph: per-node squared error, without the mean reduction
return (outputs[0].nodes - targets) ** 2
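A minimal sketch, not part of this commit, of how the unreduced per-node errors from compute_loss_per_graph could be turned into one loss per graph, assuming jraph-style GraphsTuple batches whose n_node field holds each graph's node count:

import jax
import jax.numpy as jnp

def per_graph_mse(node_sq_err, n_node):
    # node_sq_err: (num_nodes, feat) squared errors from compute_loss_per_graph
    # n_node: (num_graphs,) node count of each padded graph in the batch
    # (assumes every graph, including padding, has at least one node)
    num_graphs = n_node.shape[0]
    # map every node to the index of the graph it belongs to
    segment_ids = jnp.repeat(jnp.arange(num_graphs), n_node,
                             total_repeat_length=node_sq_err.shape[0])
    per_node = node_sq_err.mean(axis=-1)
    totals = jax.ops.segment_sum(per_node, segment_ids, num_segments=num_graphs)
    return totals / n_node  # per-graph mean squared error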


def update_step(params, opt_state):
loss, grads = jax.value_and_grad(compute_loss)(
params, self.graph, self.targets
@@ -193,11 +223,7 @@ def evaluate(params, inputs, target,wse_value=True,indices=None):
output = outputs[0].nodes
for ind in indices:
output = output.at[ind].set(0)
#lst = list(outputs)
#lst[0].nodes.at[ind].set(0)
#output = tuple(lst)
#outputs = outputs[0].nodes.replace(0, 0)
#outputs = outputs.replace(0,0)

MCC = matthews_corrcoef(
np.squeeze(target), round(np.squeeze(output))
)
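In effect, evaluate zeroes the outputs at the excluded start/end indices before scoring. A standalone sketch of that metric computation (illustrative only; the commit's own code operates on the GraphsTuple outputs directly):

import numpy as np
from sklearn.metrics import matthews_corrcoef, roc_auc_score

def masked_scores(output, target, indices):
    # zero the predictions at excluded node indices before scoring
    output = np.squeeze(np.asarray(output, dtype=float)).copy()
    output[np.asarray(indices, dtype=int)] = 0.0
    target = np.squeeze(np.asarray(target))
    mcc = matthews_corrcoef(target, np.round(output))  # hard labels for MCC
    auc = roc_auc_score(target, output)                # soft scores for ROC-AUC
    return mcc, auc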
@@ -289,7 +315,6 @@ def update(self):

rng = next(self.rng_seq)
# Sample
# target_test_wse = target_test - graph_test.nodes[:, 0]
if self.resample:
self.graph, self.targets = sample_padded_grid_batch_shortest_path(
rng,
@@ -352,17 +377,17 @@ def update(self):
# Test without end start in the target
loss_test_wse = self._compute_loss(self.params, graph_test, target_test_wse)
self.losses_test_wse.append(loss_test_wse)
outputs_test_wse_list, roc_auc_test_wse, MCC_test_wse = self._evaluate(
outputs_test_wse_wrong, roc_auc_test_wse, MCC_test_wse = self._evaluate(
self.params, graph_test, target_test_wse, False, indices_test
)
self.MCCs_test_wse.append(MCC_test_wse)

# Log
wandb_logs = {
"loss_test": loss_test,
"loss_test_wse": loss_test_wse,
"loss_train": loss,
"loss_train_wse": loss_wse,
# "loss_test": loss_test,
# "loss_test_wse": loss_test_wse,
# "loss_train": loss,
# "loss_train_wse": loss_wse,

"log_loss_test": np.log(loss_test),
"log_loss_test_wse": np.log(loss_test_wse),
@@ -386,7 +411,7 @@ def update(self):
self.global_steps = self.global_steps + 1
if self.global_steps % self.log_every == 0:
print(
f"Training step {self.global_steps}: loss = {loss} , loss_test = {loss_test}, roc_auc_test = {roc_auc_test}, roc_auc_train = {roc_auc_train}"
f"Training step {self.global_steps}: log_loss = {np.log(loss)} , log_loss_test = {np.log(loss_test)}, roc_auc_test = {roc_auc_test}, roc_auc_train = {roc_auc_train}"
)
return

@@ -431,27 +456,25 @@ def print_and_plot(self):
target_test_wse = target_test - graph_test.nodes[:]
target_wse = self.targets - self.graph.nodes[:]



outputs_test, roc_auc_test_wse, MCC_test_wse = self._evaluate(
self.params, graph_test, target_test_wse, False,indices_test
)
outputs_test_wse = update_outputs_test(outputs_test, indices_test)

outputs_test, roc_auc_test, MCC_test = self._evaluate(
self.params, graph_test, target_test, True
)
outputs_test_wse, roc_auc_test_wse, MCC_test_wse = self._evaluate(
self.params, graph_test, target_test_wse, False,indices_test
)

outputs_test_wse_list = outputs_test_wse[0].nodes
for ind in indices_test:
outputs_test_wse_list = outputs_test_wse_list.at[ind].set(0)
outputs, roc_auc_wse, MCC_wse = self._evaluate(
self.params, self.graph, target_wse, False, indices_train
)
outputs_train_wse = update_outputs_test(outputs, indices_train)

outputs, roc_auc, MCC = self._evaluate(
self.params, self.graph, self.targets, True
)
outputs_wse, roc_auc_wse, MCC_wse = self._evaluate(
self.params, self.graph, target_wse, False, indices_train
)

outputs_wse_list = outputs_wse[0].nodes
for ind in indices_train:
outputs_wse_list = outputs_wse_list.at[ind].set(0)
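update_outputs_test is imported from domine_2023_extras.class_utils and is not shown in this diff. Judging by the inline loops it replaces here, it plausibly behaves like the following sketch (an assumption about the helper, not its actual source):

def update_outputs_test(outputs, indices):
    # zero the predicted node values at the given indices (start/end nodes),
    # mirroring the inline `.at[ind].set(0)` loops this commit removes
    nodes = outputs[0].nodes
    for ind in indices:
        nodes = nodes.at[ind].set(0)
    return nodes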

# SAVE PARAMETERS (NOTE: we save the files, so they should be there as well)
if self.wandb_on:
@@ -556,9 +579,9 @@ def print_and_plot(self):
plot_input_target_output(
list(graph_test.nodes.sum(-1)),
target_test.sum(-1),
outputs_test[0].nodes.tolist(),
np.squeeze(outputs_test[0].nodes).tolist(),
graph_test,
6,
2,
self.edge_lables,
os.path.join(self.save_path, "in_out_targ_test.pdf"),
"in_out_targ_test",
@@ -570,7 +593,7 @@ def print_and_plot(self):
target_test.sum(-1),
new_vector,
graph_test,
6,
2,
self.edge_lables,
os.path.join(self.save_path, "in_out_targ_test_threshold.pdf"),
"in_out_targ_test",
@@ -580,9 +603,10 @@ def print_and_plot(self):
list(graph_test.nodes.sum(-1)),
outputs_test[1],
target_test.sum(-1),
outputs_test[0].nodes.tolist(),
np.squeeze(
outputs_test[0].nodes).tolist(),
graph_test,
6,
2,
self.num_message_passing_steps,
self.edge_lables,
os.path.join(
@@ -595,9 +619,9 @@ def print_and_plot(self):
plot_input_target_output(
list(graph_test.nodes.sum(-1)),
target_test_wse.sum(-1),
outputs_test_wse_list,
np.squeeze(outputs_test_wse).tolist(),
graph_test,
6,
2,
self.edge_lables,
os.path.join(self.save_path, "in_out_targ_test_wse.pdf"),
"in_out_targ_test_wse",
@@ -612,7 +636,7 @@ def print_and_plot(self):
self.targets.sum(-1),
new_vector,
self.graph,
6,
2,
self.edge_lables,
os.path.join(self.save_path, "in_out_targ_train_threshol.pdf"),
"in_out_targ_train",
@@ -621,9 +645,9 @@ def print_and_plot(self):
plot_input_target_output(
list(self.graph.nodes.sum(-1)),
target_wse.sum(-1),
outputs_wse_list,
np.squeeze(outputs_train_wse).tolist(),
self.graph,
6,
2,
self.edge_lables,
os.path.join(self.save_path, "in_out_targ_train_wse.pdf"),
"in_out_targ_train_wse",
@@ -662,9 +686,10 @@ def print_and_plot(self):
plot_input_target_output(
list(self.graph.nodes.sum(-1)),
self.targets.sum(-1),
outputs[0].nodes.tolist(),
np.squeeze(
outputs[0].nodes).tolist(),
self.graph,
6,
2,
self.edge_lables,
os.path.join(self.save_path, "in_out_targ_train.pdf"),
"in_out_targ_train",
@@ -674,9 +699,10 @@ def print_and_plot(self):
list(self.graph.nodes.sum(-1)),
outputs[1],
self.targets.sum(-1),
outputs[0].nodes.tolist(),
np.squeeze(
outputs[0].nodes).tolist(),
self.graph,
6,
2,
self.num_message_passing_steps,
self.edge_lables,
os.path.join(self.save_path, "message_passing_graph_train.pdf"),
@@ -750,14 +776,14 @@ def print_and_plot(self):
nx_max=config.nx_max,
batch_size_test=config.batch_size_test,
nx_min_test=config.nx_min_test,
nx_max_test=7,
nx_max_test=config.nx_max_test,
arena_y_limits=arena_y_limits,
arena_x_limits=arena_x_limits,
)
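The config object is parsed elsewhere in this script. A hypothetical stand-in showing the fields consumed here and in wandb_logs (the field names mirror the accesses above; every default value is illustrative only):

from dataclasses import dataclass

@dataclass
class Config:  # hypothetical stand-in for the parsed config file
    experiment_name: str = "experiment"
    train_on_shortest_path: bool = True
    resample: bool = True
    wandb_on: bool = False
    seed: int = 42
    feature_position: bool = False
    weighted: bool = False
    num_hidden: int = 32
    num_layers: int = 2
    num_message_passing_steps: int = 3
    learning_rate: float = 1e-3
    num_training_steps: int = 1000
    batch_size: int = 16
    nx_min: int = 4            # training graph-size range ("state density")
    nx_max: int = 7
    batch_size_test: int = 16
    nx_min_test: int = 4       # held-out graph-size range
    nx_max_test: int = 7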

for n in range(config.num_training_steps):
agent.update()
agent.print_and_plot()
agent.print_and_plot()

# TODO: Run manager (not possible for now). To get separated code we would just need to change the paths and config; this would mean getting rid of the config.
# The other alternative is to have multiple environments that we resample every time.