From 1cd20e2e03f8dd1f1f72ab15508a8cf0f9c4eec3 Mon Sep 17 00:00:00 2001 From: clementine Date: Tue, 5 Nov 2024 19:16:33 -0500 Subject: [PATCH] lengths and mp study --- neuralplayground/agents/debug.py | 58 -- neuralplayground/agents/domine_2023.py | 885 ------------------ neuralplayground/agents/domine_2023_2.py | 284 ++---- neuralplayground/agents/domine_2023_2_mp.py | 117 +++ neuralplayground/agents/domine_2023_2seed.py | 239 +++++ neuralplayground/agents/domine_2023_3.py | 389 -------- neuralplayground/agents/domine_2023_4.py | 736 --------------- .../class_config_template.py | 20 - .../agents/domine_2023_extras_2/config.yaml | 21 +- .../domine_2023_extras_2/models/GCN_model.py | 60 +- .../processing/Graph_generation.py | 9 +- .../utils/plotting_utils.py | 88 +- neuralplayground/agents/lenghts.py | 557 +++++++++++ 13 files changed, 1093 insertions(+), 2370 deletions(-) delete mode 100644 neuralplayground/agents/debug.py delete mode 100644 neuralplayground/agents/domine_2023.py create mode 100644 neuralplayground/agents/domine_2023_2_mp.py create mode 100644 neuralplayground/agents/domine_2023_2seed.py delete mode 100644 neuralplayground/agents/domine_2023_3.py delete mode 100644 neuralplayground/agents/domine_2023_4.py create mode 100644 neuralplayground/agents/lenghts.py diff --git a/neuralplayground/agents/debug.py b/neuralplayground/agents/debug.py deleted file mode 100644 index 7be1ed3d..00000000 --- a/neuralplayground/agents/debug.py +++ /dev/null @@ -1,58 +0,0 @@ -import matplotlib.pyplot as plt -from IPython.display import HTML, display -from tqdm.notebook import tqdm - -from neuralplayground.agents import Stachenfeld2018 -from neuralplayground.arenas import Hafting2008 - -display(HTML("")) -import matplotlib.pyplot as plt -from tqdm.notebook import tqdm - -from neuralplayground.arenas import Hafting2008 - -env = Hafting2008(time_step_size=0.1, agent_step_size=None, use_behavioral_data=True) - -agent_step_size = 10 -discount = 0.9 -threshold = 1e-6 -lr_td = 1e-2 -t_episode = 1000 -n_episode = 100 -state_density = 1 / agent_step_size -twoDvalue = True - -agent = Stachenfeld2018( - discount=discount, - t_episode=t_episode, - n_episode=n_episode, - threshold=threshold, - lr_td=lr_td, - room_width=env.room_width, - room_depth=env.room_depth, - state_density=state_density, - twoD=twoDvalue, -) - - -plot_every = 100000 -total_iters = 0 -obs, state = env.reset() -obs = obs[:2] -for i in tqdm(range(100001)): - # Observe to choose an action - action = agent.act(obs) # the action is link to density of state to make sure we always land in a new - K = agent.update() - obs, state, reward = env.step(action) - obs = obs[:2] - total_iters += 1 - if total_iters % plot_every == 0: - agent.plot_rate_map(sr_matrix=agent.srmat, eigen_vectors=[1, 10, 15, 20], save_path="./sr_Hating.png") -agent.plot_rate_map(sr_matrix=agent.srmat, eigen_vectors=[1, 10, 15, 20], save_path="./sr_Hating.png") -T = agent.get_T_from_M(agent.srmat_ground) -agent.plot_transition() -T = agent.get_T_from_M(agent.srmat_ground) -agent.plot_transition() -ax = env.plot_trajectory(plot_every=100) -plt.show() -print("hello") diff --git a/neuralplayground/agents/domine_2023.py b/neuralplayground/agents/domine_2023.py deleted file mode 100644 index 5374d230..00000000 --- a/neuralplayground/agents/domine_2023.py +++ /dev/null @@ -1,885 +0,0 @@ -import argparse -import os -import shutil -from datetime import datetime -from pathlib import Path -import haiku as hk -import jax -import jax.ops as jop -import jax.numpy as jnp -import optax -import wandb -from sklearn.metrics import matthews_corrcoef, roc_auc_score -from neuralplayground.agents.agent_core import AgentCore - -os.environ["KMP_DUPLICATE_LIB_OK"] = "True" -from neuralplayground.agents.domine_2023_extras.class_Graph_generation import ( - sample_padded_batch_graph, -) -from neuralplayground.agents.domine_2023_extras.class_grid_run_config import GridConfig -from neuralplayground.agents.domine_2023_extras.class_models import get_forward_function -from neuralplayground.agents.domine_2023_extras.class_plotting_utils import ( - plot_input_target_output, - plot_message_passing_layers, - plot_curves, -) -from neuralplayground.agents.domine_2023_extras.class_utils import ( - rng_sequence_from_rng, - set_device, - update_outputs_test, - get_length_shortest_path, -) - - - -#TODO: Implement all in Neuralplayground -class Domine2023( - AgentCore, -): - def __init__( # autogenerated - self, - # agent_name: str = "SR", - experiment_name="smaller size generalisation graph with no position feature", - train_on_shortest_path: bool = True, - resample: bool = True, - wandb_on: bool = False, - seed: int = 41, - feature_position: bool = False, - weighted: bool = True, - num_hidden: int = 100, - num_layers: int = 2, - num_message_passing_steps: int = 3, - learning_rate: float = 0.001, - num_training_steps: int = 10, - residual=True, - layer_norm=True, - batch_size: int = 4, - nx_min: int = 4, - nx_max: int = 7, - batch_size_test: int = 4, - nx_min_test: int = 4, - nx_max_test: int = 7, - grid: bool= True, - plot: bool= True, - dist_cutoff=10, - n_std_dist_cutoff=5, - - **mod_kwargs, - ): - - self.grid = grid - self.plot = plot - self.obs_history = [] - self.grad_history = [] - self.train_on_shortest_path = train_on_shortest_path - self.experiment_name = experiment_name - self.resample = resample - self.wandb_on = wandb_on - self.dist_cutoff = dist_cutoff , - self.n_std_dist_cutoff= n_std_dist_cutoff, - - self.seed = seed - self.feature_position = feature_position - self.weighted = weighted - - self.num_hidden = num_hidden - self.num_layers = num_layers - self.num_message_passing_steps = num_message_passing_steps - self.learning_rate = learning_rate - self.num_training_steps = num_training_steps - - # This can be tought of the brain making different rep of different granularity - # Could be explained during sleep - self.batch_size_test = batch_size_test - self.nx_min_test = nx_min_test # This is thought of the state density - self.nx_max_test = nx_max_test # This is thought of the state density - self.batch_size = batch_size - self.nx_min = nx_min # This is thought of the state density - self.nx_max = nx_max - - self.arena_x_limits = mod_kwargs["arena_y_limits"] - self.arena_y_limits = mod_kwargs["arena_y_limits"] - self.agent_step_size = 0 - self.residuals = residual - self.layer_norm = layer_norm - - self.log_every = num_training_steps // 10 - if self.weighted: - self.edge_lables = True - else: - self.edge_lables = True - - if self.wandb_on: - dateTimeObj = datetime.now() - wandb.init( - project="graph-delaunay_small", - entity="graph-brain", - name=experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"), - ) - self.wandb_logs = {} - save_path = wandb.run.dir - os.mkdir(os.path.join(save_path, "results")) - self.save_path = os.path.join(save_path, "results") - self.reset() - - rng = jax.random.PRNGKey(self.seed) - self.rng_seq = rng_sequence_from_rng(rng) - - if self.train_on_shortest_path: - self.graph, self.targets = sample_padded_batch_graph( - rng, - self.batch_size, - self.feature_position, - self.weighted, - self.nx_min, - self.nx_max, - self.grid, - self.dist_cutoff[0], - self.n_std_dist_cutoff[0], - ) - rng = next(self.rng_seq) - self.graph_test, self.target_test = sample_padded_batch_graph( - rng, - self.batch_size_test, - self.feature_position, - self.weighted, - self.nx_min_test, - self.nx_max_test, - self.grid, - self.dist_cutoff[0], - self.n_std_dist_cutoff[0], - ) - - else: - self.graph_test, self.target_test = sample_padded_batch_graph( - rng, - self.batch_size_test, - self.feature_position, - self.weighted, - self.nx_min_test, - self.nx_max_test, - self.grid, - self.dist_cutoff[0], - self.n_std_dist_cutoff[0], - ) - self.target_test = jnp.reshape( - self.graph_test.nodes[:, 0], (self.graph_test.nodes[:, 0].shape[0], -1) - ) - rng = next(self.rng_seq) - self.graph, self.targets = sample_padded_batch_graph( - rng, - self.batch_size, - self.feature_position, - self.weighted, - self.nx_min, - self.nx_max, - self.grid, - self.dist_cutoff[0], - self.n_std_dist_cutoff[0], - ) - self.targets = jnp.reshape( - self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1) - ) - - if self.feature_position: - self.indices_train = jnp.where(self.graph.nodes[:] == 1)[0] - self.indices_test = jnp.where(self.graph_test.nodes[:, 0] == 1)[0] - - self.target_test_wse = self.target_test - jnp.reshape( - self.graph_test.nodes[:, 0], (self.graph_test.nodes[:, 0].shape[0], -1) - ) - self.target_wse = self.targets - jnp.reshape( - self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1) - ) - else: - self.indices_train = jnp.where(self.graph.nodes[:] == 1)[0] - self.indices_test = jnp.where(self.graph_test.nodes[:] == 1)[0] - self.target_test_wse = self.target_test - self.graph_test.nodes[:] - self.target_wse = self.targets - self.graph.nodes[:] - - forward = get_forward_function( - self.num_hidden, - self.num_layers, - self.num_message_passing_steps, - self.residuals, - self.layer_norm, - ) - - net_hk = hk.without_apply_rng(hk.transform(forward)) - params = net_hk.init(rng, self.graph) - param_count = sum(x.size for x in jax.tree_util.tree_leaves(params)) - print("Total number of parameters: %d" % param_count) - self.params = params - optimizer = optax.adam(self.learning_rate) - opt_state = optimizer.init(self.params) - self.opt_state = opt_state - - def compute_loss(params, graph, targets): - outputs = net_hk.apply(params, graph) - return jnp.mean((outputs[0].nodes - targets) ** 2) - - self._compute_loss = jax.jit(compute_loss) - - def compute_output(params, graph, ): - outputs = net_hk.apply(params, graph) - return outputs - - self._compute_output = jax.jit(compute_output) - - def compute_loss_per_node(params, graph, targets): - outputs = net_hk.apply(params, graph) - return (outputs[0].nodes - targets) ** 2 - - self._compute_loss_per_node = jax.jit(compute_loss_per_node) - - def compute_loss_per_graph(params, graph, targets): - outputs = self._compute_output(params, graph) - node_features = jnp.squeeze(targets) # n_node_total x n_feat - # graph id for each node - i = int(0) - for n in graph.n_node: - if i == 0: - graph_ids = jnp.zeros(n) + i - else: - graph_id = jnp.zeros(n) + i - graph_ids = jnp.concatenate([graph_ids, graph_id], axis=0) - i = i + 1 - graph_ids = jnp.concatenate( - [jnp.zeros(n) + i for i, n in enumerate(graph.n_node)], axis=0 - ) - assert graph_ids.shape[0] == node_features.shape[0] - summed_outputs = jop.segment_sum(outputs[0].nodes, graph_ids.astype(int)) - summed_node_features = jop.segment_sum(node_features, graph_ids.astype(int)) - assert summed_node_features.shape[0] == graph.n_node.shape[0] - denom = graph.n_node - denom = jnp.where(denom == 0, 1, denom) - - mean_node_features = summed_node_features / denom - mean_outputs = jnp.squeeze(summed_outputs) / denom - return (mean_node_features - mean_outputs) ** 2 - - self._compute_loss_per_graph = compute_loss_per_graph - - #def compute_loss_per_graph(params, graph, targets, n_node): - # outputs = self._compute_output(params, graph) - # node_features = jnp.squeeze(targets) # n_node_total x n_feat - # n_graph = n_node.shape[0] - # sum_n_node = jnp.sum(n_node) - # graph_idx = jnp.arange(n_graph) - # To aggregate nodes and edges from each graph to global features, - # we first construct tensors that map the node to the corresponding graph. - # For example, if you have `n_node=[1,2]`, we construct the tensor - # [0, 1, 1]. We then do the same for edges. - # node_gr_idx = jnp.repeat( - # graph_idx, n_node, axis=0, total_repeat_length=sum_n_node) - # assert node_gr_idx.shape[0] == node_features.shape[0] - # summed_outputs = jop.segment_sum(outputs[0].nodes, node_gr_idx.astype(int)) - # summed_node_features = jop.segment_sum(node_features, node_gr_idx .astype(int)) - # assert summed_node_features.shape[0] == graph.n_node.shape[0] - # denom = graph.n_node - # denom = jnp.where(denom == 0, 1, denom) - # mean_node_features = summed_node_features / denom - # mean_outputs = jnp.squeeze(summed_outputs) / denom - # return (mean_node_features - mean_outputs) ** 2 - - self._compute_loss_per_graph = compute_loss_per_graph - - def compute_loss_nodes_shortest_path(params, graph, targets): - outputs = net_hk.apply(params, graph) - node_features = jnp.squeeze(targets) # n_node_total x n_feat - # graph id for each node - i = int(0) - for n in graph.n_node: - if i == 0: - graph_ids = jnp.zeros(n) + i - else: - graph_id = jnp.zeros(n) + i - graph_ids = jnp.concatenate([graph_ids, graph_id], axis=0) - i = i + 1 - - graph_ids = graph_ids + (jnp.squeeze(targets * i)) - denom = [jnp.size(jnp.where(graph_ids[:] == n)) for n in range((len(graph.n_node)-1)*2+1)] - denom= jnp.asarray(denom) - denom = jnp.where(denom == 0, 1, denom) - assert graph_ids.shape[0] == node_features.shape[0] - summed_outputs = jnp.squeeze( - jop.segment_sum(outputs[0].nodes, graph_ids.astype(int)) - ) - summed_node_features = jop.segment_sum(node_features, graph_ids.astype(int)) - mean_summed_outputs = summed_outputs /denom - mean_summed_node_features=summed_node_features / denom - - return (mean_summed_outputs - mean_summed_node_features) ** 2 # np.concatenate((np.squeeze(loss_per_graph),np.asarray(len_shortest_path)),axis=0) - self._compute_loss_nodes_shortest_path = compute_loss_nodes_shortest_path - - def update_step(params, opt_state): - loss, grads = jax.value_and_grad(compute_loss)( - params, self.graph, self.targets - ) # jits inside of value_and_grad - updates, opt_state = optimizer.update(grads, opt_state, params) - params = optax.apply_updates(params, updates) - return params, opt_state, loss - self._update_step = jax.jit(update_step) - - def evaluate(params, inputs, target, wse_value=True, indices=None): - outputs = net_hk.apply(params, inputs) - if wse_value: - roc_auc = roc_auc_score( - jnp.squeeze(target), jnp.squeeze(outputs[0].nodes) - ) - MCC = matthews_corrcoef( - jnp.squeeze(target), round(jnp.squeeze(outputs[0].nodes)) - ) - else: - output = outputs[0].nodes - for ind in indices: - output = output.at[ind].set(0) - - MCC = matthews_corrcoef(jnp.squeeze(target), round(jnp.squeeze(output))) - roc_auc = False - - return outputs, roc_auc, MCC - - - self._evaluate = evaluate - - wandb_logs = { - "train_on_shortest_path": train_on_shortest_path, - "resample": resample, - "batch_size_test": batch_size_test, - "nx_min_test": nx_min_test, # This is thought of the state density - "nx_max_test": nx_max_test, # This is thought of the state density - "batch_size": batch_size, - "nx_min": nx_min, # This is thought of the state density - "nx_max": nx_max, - "seed": seed, - "feature_position": feature_position, - "weighted": weighted, - "num_hidden": num_hidden, - "num_layers": num_layers, - "num_message_passing_steps": num_message_passing_steps, - "learning_rate": learning_rate, - "num_training_steps": num_training_steps, - "param_count": param_count, - "residual": residual, - "layer_norm": layer_norm, - } - - if self.wandb_on: - wandb.log(wandb_logs) - - else: - dateTimeObj = datetime.now() - save_path = os.path.join(Path(os.getcwd()).resolve(), "results") - os.mkdir( - os.path.join( - save_path, - self.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"), - ) - ) - self.save_path = os.path.join( - os.path.join( - save_path, - self.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"), - ) - ) - self.saving_run_parameters() - - def saving_run_parameters(self): - path = os.path.join(self.save_path, "run.py") - HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023.py") - shutil.copyfile(HERE, path) - - path = os.path.join(self.save_path, "class_Graph_generation.py") - HERE = os.path.join( - Path(os.getcwd()).resolve(), "domine_2023_extras/class_Graph_generation.py" - ) - shutil.copyfile(HERE, path) - - path = os.path.join(self.save_path, "class_utils.py") - HERE = os.path.join( - Path(os.getcwd()).resolve(), "domine_2023_extras/class_utils.py" - ) - shutil.copyfile(HERE, path) - - path = os.path.join(self.save_path, "class_plotting_utils.py") - HERE = os.path.join( - Path(os.getcwd()).resolve(), "domine_2023_extras/class_plotting_utils.py" - ) - shutil.copyfile(HERE, path) - #ToDo:change that eventually because it is not saving the right things - path = os.path.join(self.save_path, "class_config_run.yaml") - HERE = os.path.join( - Path(os.getcwd()).resolve(), "domine_2023_extras/class_config_base.yaml" - ) - shutil.copyfile(HERE, path) - - def reset(self, a=1): - self.obs_history = [] # Initialize observation history to update weights later - self.grad_history = [] - self.global_steps = 0 - self.losses_train = [] - self.losses_test = [] - self.log_losses_per_node_test = [] - self.log_losses_per_graph_test = [] - self.log_losses_per_shortest_path_test = [] - self.losses_train_wse = [] - self.losses_test_wse = [] - self.roc_aucs_train = [] - self.MCCs_train = [] - self.MCCs_test = [] - self.roc_aucs_test = [] - self.MCCs_train_wse = [] - self.MCCs_test_wse = [] - return - - def update(self): - rng = next(self.rng_seq) - if self.resample: - if self.train_on_shortest_path: - self.graph, self.targets = sample_padded_batch_graph( - rng, - self.batch_size, - self.feature_position, - self.weighted, - self.nx_min, - self.nx_max, - self.grid, - self.dist_cutoff, - self.n_std_dist_cutoff[0], - ) - else: - rng = next(self.rng_seq) - # Sample - self.graph, self.targets = sample_padded_batch_graph( - rng, - self.batch_size, - self.feature_position, - self.weighted, - self.nx_min, - self.nx_max, - self.grid, - self.dist_cutoff, - self.n_std_dist_cutoff[0], - ) - self.targets = jnp.reshape( - self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1) - ) - - if self.feature_position: - self.indices_train = jnp.where(self.graph.nodes[:] == 1)[0] - self.target_wse = self.targets - jnp.reshape( - self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1) - ) - else: - self.indices_train = jnp.where(self.graph.nodes[:] == 1)[0] - self.target_wse = self.targets - self.graph.nodes[:] - - # Train - self.params, self.opt_state, loss = self._update_step( - self.params, self.opt_state - ) - self.losses_train.append(loss) - self.outputs_train, roc_auc_train, MCC_train = self._evaluate( - self.params, self.graph, self.targets, True - ) - - self.roc_aucs_train.append(roc_auc_train) - self.MCCs_train.append(MCC_train) - - # Train without end start in the target - loss_wse = self._compute_loss(self.params, self.graph, self.target_wse) - self.losses_train_wse.append(loss_wse) - outputs_train_wse_wrong, roc_auc_train_wse, MCC_train_wse = self._evaluate( - self.params, self.graph, self.target_wse, False, self.indices_train - ) - self.outputs_train_wse = update_outputs_test( - outputs_train_wse_wrong, self.indices_train - ) - self.MCCs_train_wse.append(MCC_train_wse) - - # Test - loss_test_per_node = self._compute_loss_per_node( - self.params, self.graph_test, self.target_test - ) - loss_test_per_graph = self._compute_loss_per_graph( - self.params, self.graph_test, self.target_test - ) - loss_nodes_shortest_path = self._compute_loss_nodes_shortest_path( - self.params, self.graph_test, self.target_test - ) - self.log_losses_per_node_test.append(jnp.log(jnp.squeeze(loss_test_per_node))) - self.log_losses_per_graph_test.append(jnp.log(loss_test_per_graph)) - self.log_losses_per_shortest_path_test.append(jnp.log(loss_nodes_shortest_path)) - - loss_test = self._compute_loss(self.params, self.graph_test, self.target_test) - self.losses_test.append(loss_test) - self.outputs_test, roc_auc_test, MCC_test = self._evaluate( - self.params, self.graph_test, self.target_test, True - ) - self.roc_aucs_test.append(roc_auc_test) - self.MCCs_test.append(MCC_test) - - # Test without end start in the target - loss_test_wse = self._compute_loss( - self.params, self.graph_test, self.target_test_wse - ) - self.losses_test_wse.append(loss_test_wse) - outputs_test_wse_wrong, roc_auc_test_wse, MCC_test_wse = self._evaluate( - self.params, self.graph_test, self.target_test_wse, False, self.indices_test - ) - self.outputs_test_wse = update_outputs_test( - outputs_test_wse_wrong, self.indices_test - ) - self.MCCs_test_wse.append(MCC_test_wse) - - # Log - wandb_logs = { - "loss_test_per_node": jnp.log(jnp.squeeze(loss_test_per_node)), - "log_loss_test": jnp.log(loss_test), - "log_loss_test_wse": jnp.log(loss_test_wse), - "log_loss": jnp.log(loss), - "log_loss_wse": jnp.log(loss_wse), - "roc_auc_test": roc_auc_test, - "roc_auc_test_wse": roc_auc_test_wse, - "roc_auc_train": roc_auc_train, - "roc_auc_train_wse": roc_auc_train_wse, - "MCC_test": MCC_test, - "MCC_test_wse": MCC_test_wse, - "MCC_train": MCC_train, - "MCC_train_wse": MCC_train_wse, - } - if self.wandb_on: - wandb.log(wandb_logs) - self.global_steps = self.global_steps + 1 - if self.global_steps % self.log_every == 0: - # if self.plot == True: - # Uncomment if one wants to plot the activation at different time points - # self.plot_learning_curves(str(self.global_steps)) - # self.plot_activation(str(self.global_steps)) - print( - f"Training step {self.global_steps}: log_loss = {jnp.log(loss)} , log_loss_test = {jnp.log(loss_test)}, roc_auc_test = {roc_auc_test}, roc_auc_train = {roc_auc_train}" - ) - - if self.global_steps == self.num_training_steps: - if self.wandb_on: - with open("readme.txt", "w") as f: - f.write("readme") - with open(os.path.join(self.save_path, "Constant.txt"), "w") as outfile: - outfile.write( - "num_message_passing_steps" - + str(self.num_message_passing_steps) - + "\n" - ) - outfile.write("Learning_rate:" + str(self.learning_rate) + "\n") - outfile.write( - "num_training_steps:" + str(self.num_training_steps) + "\n" - ) - outfile.write("roc_auc" + str(roc_auc_test) + "\n") - outfile.write("MCC" + str(MCC_test) + "\n") - outfile.write("roc_auc_wse" + str(roc_auc_test_wse) + "\n") - outfile.write("MCC_wse" + str(MCC_test_wse) + "\n") - wandb.finish() - - if self.plot == True: - print("Plotting and Saving Figures") - self.plot_learning_curves(str(self.global_steps)) - self.plot_activation(str(self.global_steps)) - return - - def plot_learning_curves(self, trainning_step): - plot_curves( - [ - self.losses_train, - self.losses_test, - self.losses_train_wse, - self.losses_test_wse, - ], - os.path.join(self.save_path, "Losses_" + trainning_step + ".pdf"), - "All_Losses", - legend_labels=["loss", "loss test", "loss_wse", "loss_test_wse"], - ) - - plot_curves( - [ - jnp.log(jnp.asarray(self.losses_train)), - jnp.log(jnp.asarray(self.losses_test)), - jnp.log(jnp.asarray(self.losses_train_wse)), - jnp.log(jnp.asarray(self.losses_test_wse)), - ], - os.path.join(self.save_path, "Log_Losses_" + trainning_step + ".pdf"), - "All_log_Losses", - legend_labels=[ - "log_loss", - "log_loss test", - "log_loss_wse", - "log_loss_test_wse", - ], - ) - - plot_curves( - [self.losses_train], - os.path.join(self.save_path, "Losses_train_" + trainning_step + ".pdf"), - "Losses", - ) - plot_curves( - [self.losses_test], - os.path.join(self.save_path, "losses_test_" + trainning_step + ".pdf"), - "losses_test", - ) - - transposed_list = [list(item) for item in zip(*self.log_losses_per_node_test)] - plot_curves( - transposed_list, - os.path.join( - self.save_path, "Log_Losses_per_node_test_" + trainning_step + ".pdf" - ), - "Log_Losse_per_node", - ) - transposed_list = [list(item) for item in zip(*self.log_losses_per_graph_test)] - - plot_curves( - transposed_list, - os.path.join( - self.save_path, "Log_Losses_per_graph_test_" + trainning_step + ".pdf" - ), - "Log_Loss_per_graph " , - ["GRAPH" + str(n) for n in range(self.batch_size_test + 1)], - ) - - transposed_list = [ - list(item) for item in zip(*self.log_losses_per_shortest_path_test) - ] - shortest_path_length = get_length_shortest_path(self.graph_test, self.target_test) - - plot_curves( - [self.losses_train_wse], - os.path.join(self.save_path, "Losses_wse_" + trainning_step + ".pdf"), - "Losses_wse", - ) - plot_curves( - [self.losses_test_wse], - os.path.join(self.save_path, "losses_test_wse_" + trainning_step + ".pdf"), - "losses_test_wse", - ) - plot_curves( - transposed_list, - os.path.join( - self.save_path, "Log_Loss_on_shortest_path" + trainning_step + ".pdf"), - "Log_Loss_on shortest_path", - ["Other_node graph" + str(n) for n in range(self.batch_size_test + 1)]+ - ["SHORTEST_PATH graph_len_" + str(shortest_path_length[n]) + "graph_size" + str(p) for n, p in - enumerate(self.graph_test.n_node[:-1])]) - - plot_curves( - [self.roc_aucs_test, self.roc_aucs_train], - os.path.join(self.save_path, "auc_rocs_" + trainning_step + ".pdf"), - "All_auc_roc", - legend_labels=["auc_roc_test", "auc_roc_train_" + trainning_step + ".pdf"], - ) - plot_curves( - [self.roc_aucs_test], - os.path.join(self.save_path, "auc_roc_test_" + trainning_step + ".pdf"), - "auc_roc_test", - ) - plot_curves( - [self.roc_aucs_train], - os.path.join(self.save_path, "auc_roc_train_" + trainning_step + ".pdf"), - "auc_roc_train", - ) - - plot_curves( - [self.MCCs_train, self.MCCs_test, self.MCCs_train_wse, self.MCCs_test_wse], - os.path.join(self.save_path, "MCCs_" + trainning_step + ".pdf"), - "All_MCCs", - legend_labels=["MCC", "MCC test", "MCC_wse", "MCC_test_wse"], - ) - plot_curves( - [self.MCCs_train], - os.path.join(self.save_path, "MCC_train_" + trainning_step + ".pdf"), - "MCC_train", - ) - plot_curves( - [self.MCCs_test], - os.path.join(self.save_path, "MCC_test_" + trainning_step + ".pdf"), - "MCC_test", - ) - plot_curves( - [self.MCCs_train_wse], - os.path.join(self.save_path, "MCC_train_wse_" + trainning_step + ".pdf"), - "MCC_train_wse", - ) - plot_curves( - [self.MCCs_test_wse], - os.path.join(self.save_path, "MCC_test_wse_" + trainning_step + ".pdf"), - "MCC_test_wse", - ) - - def plot_activation(self, trainning_step): - # PLOTTING ACTIVATION FOR TEST AND THE TARGET OF THE THING ( NOTE THAT IS WAS TRANED ON THE ALL THING) - plot_input_target_output( - list(self.graph_test.nodes.sum(-1)), - self.target_test.sum(-1), - jnp.squeeze(self.outputs_test[0].nodes).tolist(), - self.graph_test, - 2, - self.edge_lables, - os.path.join(self.save_path, "in_out_targ_test_" + trainning_step + ".pdf"), - "in_out_targ_test", - ) - - new_vector = [1 if val > 0.3 else 0 for val in self.outputs_test[0].nodes] - plot_input_target_output( - list(self.graph_test.nodes.sum(-1)), - self.target_test.sum(-1), - new_vector, - self.graph_test, - 2, - self.edge_lables, - os.path.join( - self.save_path, "in_out_targ_test_threshold_" + trainning_step + ".pdf" - ), - "in_out_targ_test", - ) - - plot_message_passing_layers( - list(self.graph_test.nodes.sum(-1)), - self.outputs_test[1], - self.target_test.sum(-1), - jnp.squeeze(self.outputs_test[0].nodes).tolist(), - self.graph_test, - 2, - self.num_message_passing_steps, - self.edge_lables, - os.path.join( - self.save_path, - "message_passing_graph_test.pdf", - ), - "message_passing_graph_test", - ) - - plot_input_target_output( - list(self.graph_test.nodes.sum(-1)), - self.target_test_wse.sum(-1), - jnp.squeeze(self.outputs_test_wse).tolist(), - self.graph_test, - 2, - self.edge_lables, - os.path.join( - self.save_path, "in_out_targ_test_wse_" + trainning_step + ".pdf" - ), - "in_out_targ_test_wse", - ) - - # Train - # PLOTTING ACTIVATION OF THE FIRST 2 GRAPH OF THE BATCH - new_vector = [1 if val > 0.3 else 0 for val in self.outputs_train[0].nodes] - plot_input_target_output( - list(self.graph.nodes.sum(-1)), - self.targets.sum(-1), - new_vector, - self.graph, - 2, - self.edge_lables, - os.path.join( - self.save_path, "in_out_targ_train_threshold_" + trainning_step + ".pdf" - ), - "in_out_targ_train", - ) - - plot_input_target_output( - list(self.graph.nodes.sum(-1)), - self.target_wse.sum(-1), - jnp.squeeze(self.outputs_train_wse).tolist(), - self.graph, - 2, - self.edge_lables, - os.path.join( - self.save_path, "in_out_targ_train_wse_" + trainning_step + ".pdf" - ), - "in_out_targ_train_wse", - ) - - plot_input_target_output( - list(self.graph.nodes.sum(-1)), - self.targets.sum(-1), - jnp.squeeze(self.outputs_train[0].nodes).tolist(), - self.graph, - 2, - self.edge_lables, - os.path.join( - self.save_path, "in_out_targ_train_" + trainning_step + ".pdf" - ), - "in_out_targ_train", - ) - - plot_message_passing_layers( - list(self.graph.nodes.sum(-1)), - self.outputs_train[1], - self.targets.sum(-1), - jnp.squeeze(self.outputs_train[0].nodes).tolist(), - self.graph, - 2, - self.num_message_passing_steps, - self.edge_lables, - os.path.join( - self.save_path, "message_passing_graph_train_" + trainning_step + ".pdf" - ), - "message_passing_graph_train", - ) - - print("End") - - -if __name__ == "__main__": - from neuralplayground.arenas import Simple2D - - # @title Graph net functions - parser = argparse.ArgumentParser() - parser.add_argument( - "--config_path", - metavar="-C", - default="domine_2023_extras/class_config_base.yaml", - help="path to base configuration file.", - ) - - args = parser.parse_args() - set_device() - config_class = GridConfig - config = config_class(args.config_path) - - # Init environment - arena_x_limits = [-100, 100] - arena_y_limits = [-100, 100] - - agent = Domine2023( - experiment_name=config.experiment_name, - train_on_shortest_path=config.train_on_shortest_path, - resample=config.resample, # @param - wandb_on=config.wandb_on, - seed=config.seed, - feature_position=config.feature_position, - weighted=config.weighted, - num_hidden=config.num_hidden, # @param - num_layers=config.num_layers, # @param - num_message_passing_steps=config.num_message_passing_steps, # @param - learning_rate=config.learning_rate, # @param - num_training_steps=config.num_training_steps, # @param - batch_size=config.batch_size, - nx_min=config.nx_min, - nx_max=config.nx_max, - batch_size_test=config.batch_size_test, - nx_min_test=config.nx_min_test, - nx_max_test=config.nx_max_test, - arena_y_limits=arena_y_limits, - arena_x_limits=arena_x_limits, - residual=config.residual, - layer_norm=config.layer_norm, - grid = config.grid, - plot = config.plot, - dist_cutoff=config.dist_cutoff, - n_std_dist_cutoff= config.n_std_dist_cutoff, - ) - - for n in range(config.num_training_steps): - agent.update() diff --git a/neuralplayground/agents/domine_2023_2.py b/neuralplayground/agents/domine_2023_2.py index bd71e3e4..c85e47c0 100644 --- a/neuralplayground/agents/domine_2023_2.py +++ b/neuralplayground/agents/domine_2023_2.py @@ -12,7 +12,7 @@ import numpy as np from neuralplayground.agents.agent_core import AgentCore from neuralplayground.agents.domine_2023_extras_2.utils.plotting_utils import plot_curves, plot_curves_2, plot_2dgraphs -from neuralplayground.agents.domine_2023_extras_2.models.GCN_model import GCNModel +from neuralplayground.agents.domine_2023_extras_2.models.GCN_model import GCNModel, MLP ,GCNModel_2 from neuralplayground.agents.domine_2023_extras_2.class_grid_run_config import GridConfig from neuralplayground.agents.domine_2023_extras_2.utils.utils import set_device from neuralplayground.agents.domine_2023_extras_2.processing.Graph_generation import sample_graph, sample_target, sample_omniglot_graph, sample_fixed_graph @@ -24,10 +24,10 @@ class Domine2023(AgentCore): def __init__(self, experiment_name="smaller size generalisation graph with no position feature", - train_on_shortest_path=True, resample=True, wandb_on=False, seed=41, dataset = 'random', - weighted=True, num_hidden=100, num_layers=2, num_message_passing_steps=3, learning_rate=0.001, - num_training_steps=10, residual=True, layer_norm=True, batch_size=4, num_features=4, num_nodes_max=7, - batch_size_test=4, num_nodes_min_test=4, num_nodes_max_test=[7], plot=True, **mod_kwargs): + train_on_shortest_path=True, wandb_on=False, seed=41, dataset = 'random', + num_hidden=100, num_layers=2, num_message_passing_steps=3, learning_rate=0.001, + num_training_steps=10, residual=True, batch_size=4, num_features=4, num_nodes_max=7, + batch_size_test=4, num_nodes_max_test=[7], plot=True, **mod_kwargs): super(Domine2023, self).__init__() # General @@ -48,11 +48,10 @@ def __init__(self, experiment_name="smaller size generalisation graph with no po self.num_training_steps = num_training_steps self.batch_size = batch_size self.residual = residual - self.layer_norm = layer_norm + # Task self.dataset = dataset - self.weighted = weighted self.num_features = num_features self.num_nodes_max = num_nodes_max self.num_nodes_max_test = num_nodes_max_test @@ -74,7 +73,6 @@ def set_initial_seed(seed): # This is usually needed for reproducibility with certain layers like convolution. torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False - set_initial_seed(seed) self.batch_size_test = batch_size_test @@ -82,19 +80,33 @@ def set_initial_seed(seed): self.arena_y_limits = mod_kwargs["arena_y_limits"] save_path = mod_kwargs["save_path"] self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if self.dataset == 'random': - self.model = GCNModel(self.num_hidden, self.num_features + 2, self.num_layers, - self.num_message_passing_steps, self.residual, - self.layer_norm).to(self.device) - elif self.dataset == 'positional': - self.model = GCNModel(self.num_hidden, self.num_features + 3, self.num_layers, - self.num_message_passing_steps, self.residual, - self.layer_norm).to(self.device) + if self.num_message_passing_steps == 0: # If no message passing steps, use MLP + if self.dataset == 'random': + self.model = MLP(self.num_hidden, self.num_features + 2, self.num_layers, self.num_message_passing_steps, self.residual).to(self.device) + + elif self.dataset == 'positional': + self.model = MLP(self.num_hidden, self.num_features + 3, self.num_layers, self.num_message_passing_steps, self.residual).to(self.device) + + elif self.dataset == 'positional_no_edges': + self.model = MLP(self.num_hidden, self.num_features + 3, self.num_layers,self.num_message_passing_steps, self.residual).to(self.device) else: - num_features = 784 - self.model = GCNModel(self.num_hidden, num_features + 2, self.num_layers, - self.num_message_passing_steps, self.residual, - self.layer_norm).to(self.device) + if self.dataset == 'random': + self.model = GCNModel(self.num_hidden, self.num_features + 2, self.num_layers, + self.num_message_passing_steps, self.residual + ).to(self.device) + elif self.dataset == 'positional': + self.model = GCNModel(self.num_hidden, self.num_features + 3, self.num_layers, + self.num_message_passing_steps, self.residual + ).to(self.device) + elif self.dataset == 'positional_no_edges': + self.model = GCNModel_2(self.num_hidden, self.num_features + 3, self.num_layers, + self.num_message_passing_steps, self.residual + ).to(self.device) + else: + num_features = 784 + self.model = GCNModel(self.num_hidden, num_features + 2, self.num_layers, + self.num_message_passing_steps, self.residual, + ).to(self.device) self.auroc = AUROC(task="binary") @@ -113,17 +125,15 @@ def set_initial_seed(seed): self.reset() self.wandb_logs = { # This is thought of the state density "batch_size": self.batch_size, - "num_node_min": self.num_nodes_max, # This is thought of the state density + "num_node_max": self.num_nodes_max, # This is thought of the state density "seed": self.seed, "dataset": self.dataset, - "weighted": self.weighted, "num_hidden": self.num_hidden, "num_layers": self.num_layers, "num_message_passing_steps": self.num_message_passing_steps, "learning_rate": self.learning_rate, "num_training_steps": self.num_training_steps, "residual": self.residual, - "layer_norm": self.layer_norm, } if self.wandb_on: wandb.log(self.wandb_logs) @@ -264,12 +274,12 @@ def train(self): #for i in len(self.num_nodes_max_test): # This is an attemp - node_features_val_f, edges_val_f, edge_features_tensor_val_f, target_val_f = self.load_data(fixed=True, - dataset=self.dataset, - batch_size=self.batch_size, - num_nodes= - self.num_nodes_max_test[0], - ) + # node_features_val_f, edges_val_f, edge_features_tensor_val_f, target_val_f = self.load_data(fixed=True, + # dataset=self.dataset, + # batch_size=self.batch_size, + # num_nodes= + # self.num_nodes_max_test[0], + # ) # need to save the fixed one node_featur for epoch in range(self.num_training_steps): @@ -313,11 +323,10 @@ def train(self): self.global_steps += 1 print("Finished training") - if self.plot: os.makedirs(os.path.join(self.save_path, "results"), exist_ok=True) self.save_path = os.path.join(self.save_path, "results") - file_name = f"Losses_{seed}.pdf" + file_name = f"Losses_{self.seed}.pdf" # Combine the path and file name list_of_lists = [value for value in self.losses_val.values()] @@ -327,19 +336,19 @@ def train(self): list_of_list_name.append('loss_train') plot_curves( - list_of_lists , - os.path.join(self.save_path, file_name ), + list_of_lists, + os.path.join(self.save_path, file_name), "All_Losses", legend_labels=list_of_list_name, ) - file_name = f"ACCs_val_{seed}.pdf" - plot_curves( [value for value in self.ACCs_val.values()], - os.path.join(self.save_path, file_name), - "ACC Val", - legend_labels=[f'ACC_val_len{value}' for value in self.losses_val], - ) - file_name = f"ACCs_train_{seed}.pdf" + file_name = f"ACCs_val_{self.seed}.pdf" + plot_curves([value for value in self.ACCs_val.values()], + os.path.join(self.save_path, file_name), + "ACC Val", + legend_labels=[f'ACC_val_len{value}' for value in self.losses_val], + ) + file_name = f"ACCs_train_{self.seed}.pdf" plot_curves( [ @@ -350,6 +359,29 @@ def train(self): legend_labels=["ACC train"], ) + def sample_and_store(n): + # Initialize empty lists to store each sample's output + node_features_list = [] + edges_list = [] + edge_features_tensor_list = [] + target_list = [] + # Loop n times to sample data and store the outputs + for _ in range(n): + # Sample data by calling load_data + node_features, edges, edge_features_tensor, target = self.load_data(train=False, + dataset=self.dataset, + batch_size=self.batch_size) + # Append the results to the corresponding lists + node_features_list.append(node_features) + edges_list.append(edges) + edge_features_tensor_list.append(edge_features_tensor) + target_list.append(target) + return node_features_list, edges_list, edge_features_tensor_list, target_list + + n = 2 + # node_features_list, edges_list, edge_features_tensor_list, target_list = sample_and_store(n) + # plot_2dgraphs(edges_list, node_features_list, edge_features_tensor_list,['',''], os.path.join(self.save_path, "graph.pdf"), colorscale='Plasma',size=5,show=True) + return self.losses_train, self.ACCs_train, self.losses_val, self.ACCs_val def sample_and_store(n): # Initialize empty lists to store each sample's output @@ -384,175 +416,3 @@ def reset(self): return - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--config_path", metavar="-C", default="domine_2023_extras_2/config.yaml", - help="path to base configuration file.") - args = parser.parse_args() - set_device() - config_class = GridConfig - config = config_class(args.config_path) - - arena_x_limits = [-100, 100] - arena_y_limits = [-100, 100] - - seeds = [41,42] - losses_train = {seed: [] for seed in seeds} - losses_val = {seed: [] for seed in seeds} - ACCs_train = {seed: [] for seed in seeds} - ACCs_val = {seed: [] for seed in seeds} - dateTimeObj = datetime.now() - save_path = os.path.join(Path(os.getcwd()).resolve(), "results") - os.mkdir( - os.path.join( - save_path, - config.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"), - ) - ) - save_path = os.path.join( - os.path.join( - save_path, - config.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"), - ) - ) - for seed in seeds: - agent = Domine2023( - experiment_name=config.experiment_name, - resample=config.resample, - wandb_on=config.wandb_on, - seed=seed, - dataset=config.dataset, - weighted=config.weighted, - num_hidden=config.num_hidden, - num_layers=config.num_layers, - num_message_passing_steps=config.num_message_passing_steps, - learning_rate=config.learning_rate, - num_training_steps=config.num_training_steps, - batch_size=config.batch_size, - num_features=config.num_features, - num_nodes_max=config.num_nodes_max, - num_nodes_min=config.num_nodes_min, - batch_size_test=config.batch_size_test, - num_nodes_min_test=config.num_nodes_min_test, - num_nodes_max_test=config.num_nodes_max_test, - arena_y_limits=arena_y_limits, - arena_x_limits=arena_x_limits, - residual=config.residual, - layer_norm=config.layer_norm, - plot=config.plot, - save_path = save_path - ) - - - - losse_train, ACC_train, losse_val, ACC_val = agent.train() - losses_train[seed] = losse_train - losses_val[seed] = losse_val - ACCs_train[seed]= ACC_train - ACCs_val[seed] = ACC_val - - save_path = os.path.join(save_path, "results") - num_training_steps = config.num_training_steps - # Initialize lists to store standard deviation results - std_losses_train = [] - std_accs_train = [] - - # Compute average and standard deviation for training loss - - avg_losses_train = [] - for epoch_idx in range(num_training_steps): - # Average the loss for this epoch over all seeds - avg_epoch_loss = sum(losses_train[seed][epoch_idx] for seed in seeds) / len(seeds) - avg_losses_train.append(avg_epoch_loss) - - # Compute standard deviation for this epoch - variance_loss = sum((losses_train[seed][epoch_idx] - avg_epoch_loss) ** 2 for seed in seeds) / len(seeds) - std_epoch_loss = math.sqrt(variance_loss) - std_losses_train.append(std_epoch_loss) - - # Compute average and standard deviation for training accuracy - avg_accs_train = [] - for epoch_idx in range(num_training_steps): - # Average the accuracy for this epoch over all seeds - avg_epoch_acc = sum(ACCs_train[seed][epoch_idx] for seed in seeds) / len(seeds) - avg_accs_train.append(avg_epoch_acc) - - # Compute standard deviation for this epoch - variance_acc = sum((ACCs_train[seed][epoch_idx] - avg_epoch_acc) ** 2 for seed in seeds) / len(seeds) - std_epoch_acc = math.sqrt(variance_acc) - std_accs_train.append(std_epoch_acc) - - - # Compute average and standard deviation for validation loss - avg_losses_val_len = [] - std_losses_val_len = [] - for i in config.num_nodes_max_test: - avg_losses_val = [] - std_losses_val = [] - for epoch_idx in range(num_training_steps): - avg_epoch_loss_val = sum(losses_val[seed][i][epoch_idx] for seed in seeds) / len(seeds) - avg_losses_val.append(avg_epoch_loss_val) - variance_loss_val = sum( - (losses_val[seed][i][epoch_idx] - avg_epoch_loss_val) ** 2 for seed in seeds) / len(seeds) - std_epoch_loss_val = math.sqrt(variance_loss_val) - std_losses_val.append(std_epoch_loss_val) - avg_losses_val_len.append(avg_losses_val) - std_losses_val_len.append(std_losses_val) - - #Compute average and standard deviation for validation accuracy - avg_accs_val_len = [] - std_accs_val_len = [] - for i in config.num_nodes_max_test: - avg_accs_val = [] - std_accs_val = [] - for epoch_idx in range(num_training_steps): - avg_epoch_acc_val = sum(ACCs_val[seed][i][epoch_idx] for seed in seeds) / len(seeds) - avg_accs_val.append(avg_epoch_acc_val) - - # Compute standard deviation for this epoch - variance_acc_val = sum((ACCs_val[seed][i][epoch_idx] - avg_epoch_acc_val) ** 2 for seed in seeds) / len(seeds) - std_epoch_acc_val = math.sqrt(variance_acc_val) - std_accs_val.append(std_epoch_acc_val) - avg_accs_val_len.append(avg_accs_val) - std_accs_val_len.append(std_accs_val) - - - - list_of_list_name = [f'loss_val_len{value}' for value in losses_val[seed]] - # Append losses_train to the list of lists - avg_losses_val_len.append(avg_losses_train) - list_of_list_name.append('loss_train') - std_losses_val_len.append(std_accs_train) - - plot_curves_2( - avg_losses_val_len,std_losses_val_len, - os.path.join(save_path, "Losses.pdf"), - "All_Losses", - legend_labels= list_of_list_name, - ) - plot_curves_2( - [ - avg_accs_train , - ],[std_accs_train], - os.path.join(save_path, "ACCs_train.pdf"), - "ACC Train", - legend_labels=["ACC Train"], - ) - - list_of_list_name = [f'loss_val_len{value}' for value in losses_val[seed]] - plot_curves_2( - avg_accs_val_len,std_accs_val_len, os.path.join(save_path, "ACCs_val.pdf"), - "ACC val", - legend_labels=list_of_list_name, - ) - print() - - #TODO: They all have different evaluation ( netwokr ) do we want ot eval ( for the average it should be ifne) - #TODO: Think about nice visualisaiton - #TODO: update the plotting for the other curves - # I need to check the logging of the results - # TODO. : plan a set of experiements to run, sudy how different initialisiton - #TODO: Get a different set of valisaion lenght for each run - - # TODO : the set of seed changes every run. so it is fine. The question is \ No newline at end of file diff --git a/neuralplayground/agents/domine_2023_2_mp.py b/neuralplayground/agents/domine_2023_2_mp.py new file mode 100644 index 00000000..91657d68 --- /dev/null +++ b/neuralplayground/agents/domine_2023_2_mp.py @@ -0,0 +1,117 @@ +import argparse +import os +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" +import math +import torch +import shutil +from datetime import datetime +from pathlib import Path +import torch.nn as nn +import torch.optim as optim +import wandb +import numpy as np +from neuralplayground.agents.agent_core import AgentCore +from neuralplayground.agents.domine_2023_extras_2.utils.plotting_utils import plot_curves, plot_curves_2, plot_2dgraphs +from neuralplayground.agents.domine_2023_extras_2.class_grid_run_config import GridConfig +from neuralplayground.agents.domine_2023_extras_2.utils.utils import set_device +from neuralplayground.agents.domine_2023_2 import Domine2023 + +# from neuralplayground.agents.domine_2023_extras_2.evaluate import Evaluator +os.environ["KMP_DUPLICATE_LIB_OK"] = "True" + + +parser = argparse.ArgumentParser() +parser.add_argument("--config_path", metavar="-C", default="domine_2023_extras_2/config.yaml", + help="path to base configuration file.") +args = parser.parse_args() +set_device() +config_class = GridConfig +config = config_class(args.config_path) + +arena_x_limits = [-100, 100] +arena_y_limits = [-100, 100] + +mps = [0,1,2,3,4,5,6,7] +losses_train = {mp: [] for mp in mps} +losses_val = {mp: [] for mp in mps} +ACCs_train = {mp: [] for mp in mps} +ACCs_val = {mp: [] for mp in mps} + +dateTimeObj = datetime.now() +save_path = os.path.join(Path(os.getcwd()).resolve(), "results") +os.mkdir( +os.path.join( + save_path, + config.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"), +) +) +save_path = os.path.join( +os.path.join( + save_path, + config.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"), +) +) + +for mp in mps: + agent = Domine2023( + experiment_name=config.experiment_name, + wandb_on=config.wandb_on, + seed=config.seed, + dataset=config.dataset, + num_hidden=config.num_hidden, + num_layers=config.num_layers, + num_message_passing_steps= mp, + learning_rate=config.learning_rate, + num_training_steps=config.num_training_steps, + batch_size=config.batch_size, + num_features=config.num_features, + num_nodes_max=config.num_nodes_max, + batch_size_test=config.batch_size_test, + num_nodes_max_test=config.num_nodes_max_test, + arena_y_limits=arena_y_limits, + arena_x_limits=arena_x_limits, + residual=config.residual, + plot=config.plot, + save_path = save_path + ) + + losse_train, ACC_train, losse_val, ACC_val = agent.train() + losses_train[mp] = losse_train + losses_val[mp] = losse_val + ACCs_train[mp]= ACC_train + ACCs_val[mp] = ACC_val + +for i in config.num_nodes_max_test: + third_elements = {key: value[i] for key, value in losses_val.items()} + list_of_lists = list( third_elements.values()) + file_name = f"Losses_val_len{i}.pdf" + list_of_list_name = [ f"mp_{i}.pdf" for i in list(third_elements.keys())] + + plot_curves( + list_of_lists, + os.path.join(save_path, file_name), + f"Losses_val{i}", + legend_labels=list_of_list_name, + ) + +list_of_lists = list(losses_train.values()) +file_name = f"Losses_train_mp.pdf" +list_of_list_name = [f"mp_{i}.pdf" for i in list(losses_train.keys())] +plot_curves( + list_of_lists, + os.path.join(save_path, file_name), + "Losses_train_mp", + legend_labels=list_of_list_name, + ) + +list_of_lists = list(ACCs_train.values()) +file_name = f"ACCs_mp.pdf" +list_of_list_name = [f"mp_{i}.pdf" for i in list(ACCs_train.keys())] +plot_curves( + list_of_lists, + os.path.join(save_path, file_name), + "ACC_mp", + legend_labels=list_of_list_name, + ) + + diff --git a/neuralplayground/agents/domine_2023_2seed.py b/neuralplayground/agents/domine_2023_2seed.py new file mode 100644 index 00000000..75d1555a --- /dev/null +++ b/neuralplayground/agents/domine_2023_2seed.py @@ -0,0 +1,239 @@ +import argparse +import os +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" +import math +import torch +import shutil +from datetime import datetime +from pathlib import Path +import torch.nn as nn +import torch.optim as optim +import wandb +import numpy as np +from neuralplayground.agents.agent_core import AgentCore +from neuralplayground.agents.domine_2023_extras_2.utils.plotting_utils import plot_curves, plot_curves_2, plot_2dgraphs +from neuralplayground.agents.domine_2023_extras_2.models.GCN_model import GCNModel, MLP ,GCNModel_2 +from neuralplayground.agents.domine_2023_extras_2.class_grid_run_config import GridConfig +from neuralplayground.agents.domine_2023_extras_2.utils.utils import set_device +from neuralplayground.agents.domine_2023_extras_2.processing.Graph_generation import sample_graph, sample_target, sample_omniglot_graph, sample_fixed_graph +from torchmetrics import Accuracy, Precision, AUROC, Recall, MatthewsCorrCoef +from torchmetrics.classification import BinaryAccuracy +from neuralplayground.agents.domine_2023_2 import Domine2023 + +# from neuralplayground.agents.domine_2023_extras_2.evaluate import Evaluator +os.environ["KMP_DUPLICATE_LIB_OK"] = "True" + + +parser = argparse.ArgumentParser() +parser.add_argument("--config_path", metavar="-C", default="domine_2023_extras_2/config.yaml", + help="path to base configuration file.") +args = parser.parse_args() +set_device() +config_class = GridConfig +config = config_class(args.config_path) + +arena_x_limits = [-100, 100] +arena_y_limits = [-100, 100] + +seeds = [41,42] +losses_train = {seed: [] for seed in seeds} +losses_val = {seed: [] for seed in seeds} +ACCs_train = {seed: [] for seed in seeds} +ACCs_val = {seed: [] for seed in seeds} +dateTimeObj = datetime.now() +save_path = os.path.join(Path(os.getcwd()).resolve(), "results") +os.mkdir( + os.path.join( + save_path, + config.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"), + ) +) +save_path = os.path.join( + os.path.join( + save_path, + config.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"), + ) +) +for seed in seeds: + agent = Domine2023( + experiment_name=config.experiment_name, + wandb_on=config.wandb_on, + seed=seed, + dataset=config.dataset, + num_hidden=config.num_hidden, + num_layers=config.num_layers, + num_message_passing_steps=config.num_message_passing_steps, + learning_rate=config.learning_rate, + num_training_steps=config.num_training_steps, + batch_size=config.batch_size, + num_features=config.num_features, + num_nodes_max=config.num_nodes_max, + batch_size_test=config.batch_size_test, + num_nodes_max_test=config.num_nodes_max_test, + arena_y_limits=arena_y_limits, + arena_x_limits=arena_x_limits, + residual=config.residual, + plot=config.plot, + save_path = save_path + ) + + + + losse_train, ACC_train, losse_val, ACC_val = agent.train() + losses_train[seed] = losse_train + losses_val[seed] = losse_val + ACCs_train[seed]= ACC_train + ACCs_val[seed] = ACC_val + +save_path = os.path.join(save_path, "results") + +num_training_steps = config.num_training_steps +# Initialize lists to store standard deviation results +std_losses_train = [] +std_accs_train = [] + +# Compute average and standard deviation for training loss + +avg_losses_train = [] +for epoch_idx in range(num_training_steps): + # Average the loss for this epoch over all seeds + avg_epoch_loss = sum(losses_train[seed][epoch_idx] for seed in seeds) / len(seeds) + avg_losses_train.append(avg_epoch_loss) + + # Compute standard deviation for this epoch + variance_loss = sum((losses_train[seed][epoch_idx] - avg_epoch_loss) ** 2 for seed in seeds) / len(seeds) + std_epoch_loss = math.sqrt(variance_loss) + std_losses_train.append(std_epoch_loss) + +# Compute average and standard deviation for training accuracy +avg_accs_train = [] +for epoch_idx in range(num_training_steps): + # Average the accuracy for this epoch over all seeds + avg_epoch_acc = sum(ACCs_train[seed][epoch_idx] for seed in seeds) / len(seeds) + avg_accs_train.append(avg_epoch_acc) + + # Compute standard deviation for this epoch + variance_acc = sum((ACCs_train[seed][epoch_idx] - avg_epoch_acc) ** 2 for seed in seeds) / len(seeds) + std_epoch_acc = math.sqrt(variance_acc) + std_accs_train.append(std_epoch_acc) + + +# Compute average and standard deviation for validation loss +avg_losses_val_len = [] +std_losses_val_len = [] +for i in config.num_nodes_max_test: + avg_losses_val = [] + std_losses_val = [] + for epoch_idx in range(num_training_steps): + avg_epoch_loss_val = sum(losses_val[seed][i][epoch_idx] for seed in seeds) / len(seeds) + avg_losses_val.append(avg_epoch_loss_val) + variance_loss_val = sum( + (losses_val[seed][i][epoch_idx] - avg_epoch_loss_val) ** 2 for seed in seeds) / len(seeds) + std_epoch_loss_val = math.sqrt(variance_loss_val) + std_losses_val.append(std_epoch_loss_val) + avg_losses_val_len.append(avg_losses_val) + std_losses_val_len.append(std_losses_val) + +#Compute average and standard deviation for validation accuracy +avg_accs_val_len = [] +std_accs_val_len = [] +for i in config.num_nodes_max_test: + avg_accs_val = [] + std_accs_val = [] + for epoch_idx in range(num_training_steps): + avg_epoch_acc_val = sum(ACCs_val[seed][i][epoch_idx] for seed in seeds) / len(seeds) + avg_accs_val.append(avg_epoch_acc_val) + + # Compute standard deviation for this epoch + variance_acc_val = sum((ACCs_val[seed][i][epoch_idx] - avg_epoch_acc_val) ** 2 for seed in seeds) / len(seeds) + std_epoch_acc_val = math.sqrt(variance_acc_val) + std_accs_val.append(std_epoch_acc_val) + avg_accs_val_len.append(avg_accs_val) + std_accs_val_len.append(std_accs_val) + + + +list_of_list_name = [f'loss_val_len{value}' for value in losses_val[seed]] +# Append losses_train to the list of lists +# avg_losses_val_len.append(avg_losses_train) +# list_of_list_name.append('loss_train') +# std_losses_val_len.append(std_accs_train) + +plot_curves_2( + avg_losses_val_len,std_losses_val_len, + os.path.join(save_path, "Losses.pdf"), + "All_Losses", + legend_labels = list_of_list_name, +) +# Append losses_train to the list of lists +# avg_losses_val_len.append(avg_losses_train) +# list_of_list_name.append('loss_train') +# std_losses_val_len.append(std_accs_train) + +plot_curves_2( + [avg_losses_train], [std_losses_train], + os.path.join(save_path, "Losses_train.pdf"), + "All_Losses", + legend_labels = "Loss Train", +) + +plot_curves_2( + [ + avg_accs_train , + ],[std_accs_train], + os.path.join(save_path, "ACCs_train.pdf"), + "ACC Train", + legend_labels=["ACC Train"], +) + +list_of_list_name = [f'loss_val_len{value}' for value in losses_val[seed]] +plot_curves_2( + avg_accs_val_len,std_accs_val_len, os.path.join(save_path, "ACCs_val.pdf"), + "ACC val", + legend_labels=list_of_list_name, +) +print() + +#TODO: They all have different evaluation ( netwokr ) do we want ot eval ( for the average it should be ifne) +#TODO: Think about nice visualisaiton +#TODO: update the plotting for the other curves +# I need to check the logging of the results +# TODO. : plan a set of experiements to run, sudy how different initialisiton +#TODO: Get a different set of valisaion lenght for each run + +# TODO : the set of seed changes every run. so it is fine. The question is +#TODO: What is the the best way to have the dedges no features if plot: +#os.makedirs(os.path.join(self.save_path, "results"), exist_ok=True) +#self.save_path = os.path.join(self.save_path, "results") +# file_name = f"Losses_{seed}.pdf" + + # Combine the path and file name +# list_of_lists = [value for value in self.losses_val.values()] +# list_of_list_name = [f'loss_val_len{value}' for value in self.losses_val] +# # Append losses_train to the list of lists +# list_of_lists.append(self.losses_train) +# list_of_list_name.append('loss_train') + +# plot_curves( +# list_of_lists, +# os.path.join(self.save_path, file_name), +# "All_Losses", +# legend_labels=list_of_list_name, +# ) + +# file_name = f"ACCs_val_{seed}.pdf" +# plot_curves([value for value in self.ACCs_val.values()], +# os.path.join(self.save_path, file_name), +# "ACC Val", +# legend_labels=[f'ACC_val_len{value}' for value in self.losses_val], +# ) +# file_name = f"ACCs_train_{seed}.pdf" + +# pl#ot_curves( +# [ + # # self.ACCs_train, +# ], +# os.path.join(self.save_path, file_name), +# "ACC train", +# legend_labels=["ACC train"], +# ) diff --git a/neuralplayground/agents/domine_2023_3.py b/neuralplayground/agents/domine_2023_3.py deleted file mode 100644 index 9d37bdb9..00000000 --- a/neuralplayground/agents/domine_2023_3.py +++ /dev/null @@ -1,389 +0,0 @@ -import argparse -import os -os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" -from datetime import datetime -import torch -import shutil -from datetime import datetime -from pathlib import Path -import torch.nn as nn -import torch.optim as optim -import wandb -import numpy as np -from neuralplayground.agents.agent_core import AgentCore -from neuralplayground.agents.domine_2023_extras_2.utils.plotting_utils import plot_curves, plot_2dgraphs -from neuralplayground.agents.domine_2023_extras_2.models.GCN_model import GCNModel -from neuralplayground.agents.domine_2023_extras_2.class_grid_run_config import GridConfig -from neuralplayground.agents.domine_2023_extras_2.utils.utils import set_device -from neuralplayground.agents.domine_2023_extras_2.processing.Graph_generation import sample_random_graph, sample_target, sample_omniglot_graph -from torchmetrics import Accuracy, Precision, AUROC, Recall, MatthewsCorrCoef -from torchmetrics.classification import BinaryAccuracy -# from neuralplayground.agents.domine_2023_extras_2.evaluate import Evaluator -os.environ["KMP_DUPLICATE_LIB_OK"] = "True" - - -class Domine2023(AgentCore): - def __init__(self, experiment_name="smaller size generalisation graph with no position feature", - train_on_shortest_path=True, resample=True, wandb_on=False, seed=41, dataset = 'random', - weighted=True, num_hidden=100, num_layers=2, num_message_passing_steps=3, learning_rate=0.001, - num_training_steps=10, residual=True, layer_norm=True, batch_size=4, num_features=4, num_nodes_max=7, - batch_size_test=4, num_nodes_min_test=4, num_nodes_max_test=7, plot=True, **mod_kwargs): - super(Domine2023, self).__init__() - - # General - self.plot = plot - self.obs_history = [] - self.grad_history = [] - self.experiment_name = experiment_name - self.wandb_on = wandb_on - self.seed = seed - self.log_every = 500 - - # Network - self.num_hidden = num_hidden - self.num_layers = num_layers - self.num_message_passing_steps = num_message_passing_steps - self.learning_rate = learning_rate - self.num_training_steps = num_training_steps - self.batch_size = batch_size - self.residual = residual - self.layer_norm = layer_norm - - # Task - self.dataset = dataset - self.weighted = weighted - self.num_features = num_features - self.num_nodes_max = num_nodes_max - self.num_nodes_max_test = num_nodes_max_test - - self.batch_size_test = batch_size_test - self.arena_x_limits = mod_kwargs["arena_x_limits"] - self.arena_y_limits = mod_kwargs["arena_y_limits"] - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - if self.dataset == 'random': - self.model = GCNModel(self.num_hidden, self.num_features + 2, self.num_layers, - self.num_message_passing_steps, self.residual, - self.layer_norm).to(self.device) - else: - num_features = 784 - self.model = GCNModel(self.num_hidden, num_features + 2, self.num_layers, - self.num_message_passing_steps, self.residual, - self.layer_norm).to(self.device) - - - self.auroc = AUROC(task="binary") - self.MCC = MatthewsCorrCoef(task='binary') - self.metric = BinaryAccuracy() - - - self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate) - self.criterion = nn.MSELoss() - - if self.wandb_on: - dateTimeObj = datetime.now() - wandb.init(project="New", entity="graph-brain", - name=experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S")) - self.wandb_logs = {} - save_path = wandb.run.dir - os.mkdir(os.path.join(save_path,"results")) - self.save_path = os.path.join(save_path, "results") - - self.reset() - self.wandb_logs = { # This is thought of the state density - "batch_size": self.batch_size, - "num_node_min": self.num_nodes_max, # This is thought of the state density - "seed": self.seed, - "dataset": self.dataset, - "weighted": self.weighted, - "num_hidden": self.num_hidden, - "num_layers": self.num_layers, - "num_message_passing_steps": self.num_message_passing_steps, - "learning_rate": self.learning_rate, - "num_training_steps": self.num_training_steps, - "residual": self.residual, - "layer_norm": self.layer_norm, - } - if self.wandb_on: - wandb.log(self.wandb_logs) - else: - dateTimeObj = datetime.now() - save_path = os.path.join(Path(os.getcwd()).resolve(), "results") - os.mkdir( - os.path.join( - save_path, - self.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"), - ) - ) - self.save_path = os.path.join( - os.path.join( - save_path, - self.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"), - ) - ) - self.save_run_parameters() - - def save_run_parameters(self): - """Save configuration files and scripts.""" - files_to_copy = [ - ("run.py", "domine_2023_2.py"), - ("Graph_generation.py", "domine_2023_extras_2/processing/Graph_generation.py"), - ("utils.py", "domine_2023_extras_2/utils/utils.py"), - ("plotting_utils.py", "domine_2023_extras_2/utils/plotting_utils.py"), - ("config_run.yaml", "domine_2023_extras_2/config.yaml"), - ] - for file_name, source in files_to_copy: - shutil.copyfile(os.path.join(Path(os.getcwd()).resolve(), source), os.path.join(self.save_path, file_name)) - - def load_data(self, train, dataset, batch_size): - # Initialize lists to store batch data - node_features_batch, edges_batch, edge_features_batch, target_batch = [], [], [], [] - - # Determine the max nodes to use based on whether it is training or testing - num_nodes = self.num_nodes_max if train else self.num_nodes_max_test - - # Loop to generate a batch of data - for _ in range(batch_size): - # Handle Omniglot dataset - if dataset == 'omniglot': - node_features, edges, edge_features_tensor, source, sink = sample_omniglot_graph(num_nodes) - # Handle Random dataset - elif dataset == 'random': - node_features, edges, edge_features_tensor, source, sink = sample_random_graph(self.num_features, - num_nodes) - - # Sample the target based on source and sink - target = sample_target(source, sink) - - # Append each graph data to the batch list - node_features_batch.append(node_features) - edges_batch.append(edges) - edge_features_batch.append(edge_features_tensor) - target_batch.append(target) - - return node_features_batch, edges_batch, edge_features_batch, target_batch - - def compute_loss(self, outputs, targets): - loss = self.criterion(outputs, targets) - return loss - - def run_model(self, node, edges,edges_features): - outputs = self.model(node,edges,edges_features) - return outputs - - def update_step(self, node_batch, edges_batch, edges_features_batch, target_batch, train, batch_size=1): - if train: - self.model.train() - self.optimizer.zero_grad() - else: - self.model.eval() - - batch_losses = 0 - all_outputs = [] - - # Loop over the batch - for i in range(batch_size): - data = node_batch[i].to(self.device) - edges = edges_batch[i].to(self.device) - edges_features = edges_features_batch[i].to(self.device) - - # Forward pass - outputs = self.run_model(data, edges, edges_features) - all_outputs.append(outputs) - - # Compute loss for this sample - loss = self.compute_loss(outputs, target_batch[i]) - batch_losses += loss - - # Average loss over the batch - avg_loss = batch_losses / batch_size - - if train: - avg_loss.backward() - self.optimizer.step() - - # Concatenate all outputs for evaluation over the batch - all_outputs = torch.cat(all_outputs) - all_target = torch.cat(target_batch) - - # Evaluate using the full batch's predictions - roc_auc, mcc = self.evaluate(all_outputs, all_target) - - return avg_loss.item(), roc_auc, mcc - - def evaluate(self,outputs,targets): - with (torch.no_grad()): - roc_auc = self.auroc(outputs, targets) - # roc_auc_score(targets.cpu(), outputs.cpu()) - mcc = self.MCC(outputs, targets) - acc = self.metric(outputs, targets) - return roc_auc, mcc - - def log_training(self, train_loss, val_loss, train_roc_auc, val_roc_auc, train_mcc, val_mcc): - """Log training and validation metrics.""" - wandb_logs = { - "train_loss": train_loss, - "val_loss": val_loss, - "roc_auc_train": train_roc_auc, - "roc_auc_val": val_roc_auc, - "MCC_train": train_mcc, - "MCC_val": val_mcc - } - if self.wandb_on: - wandb.log(wandb_logs) - - - def train(self): - node_features_val, edges_val, edge_features_tensor_val, target_val = self.load_data(train=False, - dataset=self.dataset, - batch_size=self.batch_size) - - for epoch in range(self.num_training_steps): - train_losses, train_roc_auc, train_mcc = 0, 0, 0 - node_features, edges, edge_features_tensor, target = self.load_data(train=True, - dataset=self.dataset, - batch_size=self.batch_size) - #Train on each batch - batch_losses, batch_roc_auc, batch_mcc = self.update_step(node_features, edges, - edge_features_tensor, target, train=True, batch_size=self.batch_size) - - #Aggregate batch results - train_losses = batch_losses - train_roc_auc = batch_roc_auc.detach().numpy() - train_mcc = batch_mcc.detach().numpy() - # Average batch results over the batch size - - # Store results for plotting - self.losses_train.append(train_losses) - self.roc_aucs_train.append(train_roc_auc) - self.MCCs_train.append(train_mcc) - with torch.no_grad(): - val_losses, val_roc_auc, val_mcc = self.update_step(node_features_val, edges_val, - edge_features_tensor_val, target_val, - train=False,batch_size=self.batch_size) - self.losses_val.append(val_losses) - self.roc_aucs_val.append(val_roc_auc.detach().numpy()) - self.MCCs_val.append(val_mcc) - - # Log training details - self.log_training(train_losses, val_losses, train_roc_auc, - val_roc_auc.detach().numpy(), train_mcc, val_mcc) - - # Plot progress every epoch - if self.global_steps % self.log_every == 0: - print( - f"Epoch {epoch}: Train Loss = {train_losses}, Val Loss = {val_losses}, ROC AUC Train = {train_roc_auc}, ROC AUC Val = {val_roc_auc}") - self.global_steps += 1 - print("Finished training") - - if self.plot: - os.mkdir(os.path.join(self.save_path, "results")) - self.save_path = os.path.join(self.save_path, "results") - plot_curves( - [ - self.losses_train, - self.losses_val], - os.path.join(self.save_path, "Losses.pdf"), - "All_Losses", - legend_labels=["loss", "loss tesft"], - ) - plot_curves( - [ - self.MCCs_train, - ], - os.path.join(self.save_path, "MCCs_train.pdf"), - "MCC Train", - legend_labels=["MCC Train"], - ) - plot_curves( - [ - self.roc_aucs_train, - ], - os.path.join(self.save_path, "AUROC_train.pdf"), - "AUROC Train", - legend_labels=["AUROC Train"], - ) - plot_curves( - [ - self.MCCs_val, - ], - os.path.join(self.save_path, "MCCs_val.pdf"), - "MCC val", - legend_labels=["MCC val"], - ) - - def sample_and_store(n): - # Initialize empty lists to store each sample's output - node_features_list = [] - edges_list = [] - edge_features_tensor_list = [] - target_list = [] - # Loop n times to sample data and store the outputs - for _ in range(n): - # Sample data by calling load_data - node_features, edges, edge_features_tensor, target = self.load_data(train=True,dataset = self.dataset) - # Append the results to the corresponding lists - node_features_list.append(node_features) - edges_list.append(edges) - edge_features_tensor_list.append(edge_features_tensor) - target_list.append(target) - return node_features_list, edges_list, edge_features_tensor_list, target_list - - n=2 - node_features_list, edges_list, edge_features_tensor_list, target_list = sample_and_store(n) - plot_2dgraphs(edges_list, node_features_list, edge_features_tensor_list,['',''], os.path.join(self.save_path, "graph.pdf"), colorscale='Plasma',size=5,show=True) - return - def reset(self): - self.obs_history = [] - self.grad_history = [] - self.global_steps = 0 - self.losses_train = [] - self.losses_val = [] - self.MCCs_train = [] - self.MCCs_val = [] - self.roc_aucs_train = [] - self.roc_aucs_val = [] - return - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--config_path", metavar="-C", default="domine_2023_extras_2/config.yaml", - help="path to base configuration file.") - args = parser.parse_args() - set_device() - config_class = GridConfig - config = config_class(args.config_path) - - arena_x_limits = [-100, 100] - arena_y_limits = [-100, 100] - - agent = Domine2023( - experiment_name=config.experiment_name, - resample=config.resample, - wandb_on=config.wandb_on, - seed=config.seed, - dataset=config.dataset, - weighted=config.weighted, - num_hidden=config.num_hidden, - num_layers=config.num_layers, - num_message_passing_steps=config.num_message_passing_steps, - learning_rate=config.learning_rate, - num_training_steps=config.num_training_steps, - batch_size=config.batch_size, - num_features=config.num_features, - num_nodes_max=config.num_nodes_max, - num_nodes_min=config.num_nodes_min, - batch_size_test=config.batch_size_test, - num_nodes_min_test=config.num_nodes_min_test, - num_nodes_max_test=config.num_nodes_max_test, - arena_y_limits=arena_y_limits, - arena_x_limits=arena_x_limits, - residual=config.residual, - layer_norm=config.layer_norm, - plot=config.plot, - ) - - agent.train() - - #TO DO : figure out how to build the graph and the task in that setting, will it be a batch of multople graphs, how to i compute the loss on asingle param?? Global ?? - # I need to check the saving and the logging of the results \ No newline at end of file diff --git a/neuralplayground/agents/domine_2023_4.py b/neuralplayground/agents/domine_2023_4.py deleted file mode 100644 index 423828d0..00000000 --- a/neuralplayground/agents/domine_2023_4.py +++ /dev/null @@ -1,736 +0,0 @@ -# TODO: NOTE to self: This is a work in progress, it has not been tested to work, I think Jax is not a good way to implement in object oriented coding. -# I think if I want to implement it here I should use neuralplayground it would be in pytorch. - -import argparse -import os -import shutil -from datetime import datetime -from typing import Union -from pathlib import Path -import haiku as hk -import jax -import jax.numpy as jnp -import numpy as np -import optax -import wandb -from neuralplayground.agents.agent_core import AgentCore - -os.environ["KMP_DUPLICATE_LIB_OK"] = "True" -from neuralplayground.agents.domine_2023_extras.class_Graph_generation import ( - sample_padded_grid_batch_shortest_path, -) -from neuralplayground.agents.domine_2023_extras.class_grid_run_config import GridConfig -from neuralplayground.agents.domine_2023_extras.class_models import get_forward_function -from neuralplayground.agents.domine_2023_extras.class_plotting_utils import ( - plot_graph_grid_activations, - plot_input_target_output, - plot_message_passing_layers, - plot_curves, - -) -from neuralplayground.agents.domine_2023_extras.class_utils import ( - rng_sequence_from_rng, - set_device, - update_outputs_test, -) -from sklearn.metrics import matthews_corrcoef, roc_auc_score - - -class Domine2023( - AgentCore, -): - def __init__( # autogenerated - self, - # agent_name: str = "SR", - experiment_name="smaller size generalisation graph with no position feature", - train_on_shortest_path: bool = True, - resample: bool = True, - wandb_on: bool = False, - seed: int = 41, - feature_position: bool = False, - weighted: bool = True, - num_hidden: int = 100, - num_layers: int = 2, - num_message_passing_steps: int = 3, - learning_rate: float = 0.001, - num_training_steps: int = 10, - batch_size: int = 4, - nx_min: int = 4, - nx_max: int = 7, - batch_size_test: int = 4, - nx_min_test: int = 4, - nx_max_test: int = 7, - **mod_kwargs, - ): - self.obs_history = [] - self.grad_history = [] - self.train_on_shortest_path = train_on_shortest_path - self.experiment_name = experiment_name - self.resample = resample - self.wandb_on = wandb_on - - self.seed = seed - self.feature_position = feature_position - self.weighted = weighted - - self.num_hidden = num_hidden - self.num_layers = num_layers - self.num_message_passing_steps = num_message_passing_steps - self.learning_rate = learning_rate - self.num_training_steps = num_training_steps - # cconfig.num_training_steps # @param - - # This can be tought of the brain making different rep of different granularity - # Could be explained during sleep - self.batch_size_test = batch_size_test - self.nx_min_test = nx_min_test # This is thought of the state density - self.nx_max_test = nx_max_test # This is thought of the state density - self.batch_size = batch_size - self.nx_min = nx_min # This is thought of the state density - self.nx_max = nx_max - - self.arena_x_limits = mod_kwargs["arena_y_limits"] - self.arena_y_limits = mod_kwargs["arena_y_limits"] - self.room_width = np.diff(self.arena_x_limits)[0] - self.room_depth = np.diff(self.arena_y_limits)[0] - self.agent_step_size = 0 - - self.log_every = num_training_steps // 10 - if self.weighted: - self.edge_lables = True - else: - self.edge_lables = True - - if self.wandb_on: - dateTimeObj = datetime.now() - wandb.init( - project="graph-test", - entity="graph-brain", - name="Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M_%S"), - ) - self.wandb_logs = {} - save_path = wandb.run.dir - os.mkdir(os.path.join(save_path, "results")) - self.save_path = os.path.join(save_path, "results") - - wandb_logs = { - "train_on_shortest_path": train_on_shortest_path, - "resample": resample, - - "batch_size_test": batch_size_test, - "nx_min_test": nx_min_test, # This is thought of the state density - "nx_max_test": nx_max_test, # This is thought of the state density - "batch_size": batch_size, - "nx_min": nx_min, # This is thought of the state density - "nx_max": nx_max, - - "seed": seed, - "feature_position": feature_position, - "weighted": weighted, - - "num_hidden": num_hidden, - "num_layers": num_layers, - "num_message_passing_steps": num_message_passing_steps, - "learning_rate": learning_rate, - "num_training_steps": num_training_steps, - } - - if self.wandb_on: - wandb.log(wandb_logs) - - else: - dateTimeObj = datetime.now() - save_path = os.path.join(Path(os.getcwd()).resolve(), "results") - os.mkdir( - os.path.join( - save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M_%S") - ) - ) - self.save_path = os.path.join( - os.path.join( - save_path, "Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M_%S") - ) - ) - - self.reset() - self.saving_run_parameters() - - rng = jax.random.PRNGKey(self.seed) - self.rng_seq = rng_sequence_from_rng(rng) - - if self.train_on_shortest_path: - self.graph, self.targets = sample_padded_grid_batch_shortest_path( - rng, - self.batch_size, - self.feature_position, - self.weighted, - self.nx_min, - self.nx_max, - ) - else: - self.graph, self.targets = sample_padded_grid_batch_shortest_path( - rng, - self.batch_size, - self.feature_position, - self.weighted, - self.nx_min, - self.nx_max, - ) - forward = get_forward_function( - self.num_hidden, self.num_layers, self.num_message_passing_steps - ) - net_hk = hk.without_apply_rng(hk.transform(forward)) - params = net_hk.init(rng, self.graph) - self.params = params - optimizer = optax.adam(self.learning_rate) - opt_state = optimizer.init(self.params) - self.opt_state = opt_state - - def compute_loss(params, inputs, targets): - outputs = net_hk.apply(params, inputs) - return jnp.mean((outputs[0].nodes - targets) ** 2) - - self._compute_loss = jax.jit(compute_loss) - - def compute_loss_per_graph(params, inputs, targets): - outputs = net_hk.apply(params, inputs) - #for each graph: - return (outputs[0].nodes - targets) ** 2 - - - def update_step(params, opt_state): - loss, grads = jax.value_and_grad(compute_loss)( - params, self.graph, self.targets - ) # jits inside of value_and_grad - updates, opt_state = optimizer.update(grads, opt_state, params) - params = optax.apply_updates(params, updates) - return params, opt_state, loss - - self._update_step = jax.jit(update_step) - - def evaluate(params, inputs, target,wse_value=True,indices=None): - outputs = net_hk.apply(params, inputs) - if wse_value: - roc_auc = roc_auc_score( - np.squeeze(target), np.squeeze(outputs[0].nodes) - ) - MCC = matthews_corrcoef( - np.squeeze(target), round(np.squeeze(outputs[0].nodes)) - ) - else: - output = outputs[0].nodes - for ind in indices: - output = output.at[ind].set(0) - - MCC = matthews_corrcoef( - np.squeeze(target), round(np.squeeze(output)) - ) - roc_auc = False - - return outputs, roc_auc, MCC - - self._evaluate = evaluate - - def saving_run_parameters(self): - path = os.path.join(self.save_path, "run.py") - HERE = os.path.join(Path(os.getcwd()).resolve(), "domine_2023.py") - shutil.copyfile(HERE, path) - - path = os.path.join(self.save_path, "class_Graph_generation.py") - HERE = os.path.join( - Path(os.getcwd()).resolve(), "domine_2023_extras/class_Graph_generation.py" - ) - shutil.copyfile(HERE, path) - - path = os.path.join(self.save_path, "class_utils.py") - HERE = os.path.join( - Path(os.getcwd()).resolve(), "domine_2023_extras/class_utils.py" - ) - shutil.copyfile(HERE, path) - - path = os.path.join(self.save_path, "class_plotting_utils.py") - HERE = os.path.join( - Path(os.getcwd()).resolve(), "domine_2023_extras/class_plotting_utils.py" - ) - shutil.copyfile(HERE, path) - - path = os.path.join(self.save_path, "class_config_run.yaml") - HERE = os.path.join( - Path(os.getcwd()).resolve(), "domine_2023_extras/class_config.yaml" - ) - shutil.copyfile(HERE, path) - - def reset(self, a=1): - self.obs_history = [] # Initialize observation history to update weights later - self.grad_history = [] - self.global_steps = 0 - self.losses_train = [] - self.losses_test = [] - self.losses_train_wse = [] - self.losses_test_wse = [] - self.roc_aucs_train = [] - self.MCCs_train = [] - self.MCCs_test = [] - self.roc_aucs_test = [] - self.MCCs_train_wse = [] - self.MCCs_test_wse = [] - return - - def update(self): - rng = next(self.rng_seq) - if self.train_on_shortest_path: - graph_test, target_test = sample_padded_grid_batch_shortest_path( - rng, - self.batch_size, - self.feature_position, - self.weighted, - self.nx_min, - self.nx_max, - ) - rng = next(self.rng_seq) - - if self.resample: - self.graph, self.targets = sample_padded_grid_batch_shortest_path( - rng, - self.batch_size, - self.feature_position, - self.weighted, - self.nx_min, - self.nx_max, - ) - else: - graph_test, target_test = sample_padded_grid_batch_shortest_path( - rng, - self.batch_size_test, - self.feature_position, - self.weighted, - self.nx_min_test, - self.nx_max_test, - ) - target_test = np.reshape( - graph_test.nodes[:, 0], (graph_test.nodes[:, 0].shape[0], -1) - ) - - rng = next(self.rng_seq) - # Sample - if self.resample: - self.graph, self.targets = sample_padded_grid_batch_shortest_path( - rng, - self.batch_size, - self.feature_position, - self.weighted, - self.nx_min, - self.nx_max, - ) - self.targets = np.reshape( - self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1) - ) - - if self.feature_position: - indices_train = np.where(self.graph.nodes[:, 0] == 1)[0] - indices_test = np.where(graph_test.nodes[:, 0] == 1)[0] - target_test_wse = target_test - np.reshape( - graph_test.nodes[:, 0], (graph_test.nodes[:, 0].shape[0], -1) - ) - target_wse = self.targets - np.reshape( - self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1) - ) - else: - indices_train = np.where( self.graph.nodes[:]== 1)[0] - indices_test = np.where(graph_test.nodes[:] == 1)[0] - target_test_wse = target_test - graph_test.nodes[:] - target_wse = self.targets - self.graph.nodes[:] - - # Train - self.params, self.opt_state, loss = self._update_step( - self.params, self.opt_state - ) - - self.losses_train.append(loss) - outputs_train, roc_auc_train, MCC_train = self._evaluate( - self.params, self.graph, self.targets, True - ) - self.roc_aucs_train.append(roc_auc_train) - self.MCCs_train.append(MCC_train) - - # Train without end start in the target - loss_wse = self._compute_loss(self.params, self.graph, target_wse) - self.losses_train_wse.append(loss_wse) - outputs_train_wse, roc_auc_train_wse, MCC_train_wse = self._evaluate( - self.params, self.graph, target_wse, False, indices_train - ) - - self.MCCs_train_wse.append(MCC_train_wse) - - # Test - loss_test = self._compute_loss(self.params, graph_test, target_test) - self.losses_test.append(loss_test) - outputs_test, roc_auc_test, MCC_test = self._evaluate( - self.params, graph_test, target_test, True - ) - self.roc_aucs_test.append(roc_auc_test) - self.MCCs_test.append(MCC_test) - - # Test without end start in the target - loss_test_wse = self._compute_loss(self.params, graph_test, target_test_wse) - self.losses_test_wse.append(loss_test_wse) - outputs_test_wse_wrong, roc_auc_test_wse, MCC_test_wse = self._evaluate( - self.params, graph_test, target_test_wse, False, indices_test - ) - self.MCCs_test_wse.append(MCC_test_wse) - - # Log - wandb_logs = { - - "log_loss_test": np.log(loss_test), - "log_loss_test_wse": np.log(loss_test_wse), - "log_loss": np.log(loss), - "log_loss_wse": np.log(loss_wse), - - "roc_auc_test": roc_auc_test, - "roc_auc_test_wse": roc_auc_test_wse, - "roc_auc_train": roc_auc_train, - "roc_auc_train_wse": roc_auc_train_wse, - - - "MCC_test": MCC_test, - "MCC_test_wse": MCC_test_wse, - "MCC_train": MCC_train, - "MCC_train_wse": MCC_train_wse, - - } - if self.wandb_on: - wandb.log(wandb_logs) - self.global_steps = self.global_steps + 1 - if self.global_steps % self.log_every == 0: - print( - f"Training step {self.global_steps}: log_loss = {np.log(loss)} , log_loss_test = {np.log(loss_test)}, roc_auc_test = {roc_auc_test}, roc_auc_train = {roc_auc_train}" - ) - return - - def print_and_plot(self): - # EVALUATE - rng = next(self.rng_seq) - if self.train_on_shortest_path: - graph_test, target_test = sample_padded_grid_batch_shortest_path( - rng, - self.batch_size_test, - self.feature_position, - self.weighted, - self.nx_min_test, - self.nx_max_test, - ) - else: - rng = next(self.rng_seq) - graph_test, target_test = sample_padded_grid_batch_shortest_path( - rng, - self.batch_size_test, - self.feature_position, - self.weighted, - self.nx_min_test, - self.nx_max_test, - ) - target_test = np.reshape( - graph_test.nodes[:, 0], (graph_test.nodes[:, 0].shape[0], -1) - ) - - if self.feature_position: - indices_train = np.where(self.graph.nodes[:, 0] == 1)[0] - indices_test = np.where(graph_test.nodes[:, 0] == 1)[0] - target_test_wse = target_test - np.reshape( - graph_test.nodes[:, 0], (graph_test.nodes[:, 0].shape[0], -1) - ) - target_wse = self.targets - np.reshape( - self.graph.nodes[:, 0], (self.graph.nodes[:, 0].shape[0], -1) - ) - else: - indices_train = np.where(self.graph.nodes[:] == 1)[0] - indices_test = np.where(graph_test.nodes[:] == 1)[0] - target_test_wse = target_test - graph_test.nodes[:] - target_wse = self.targets - self.graph.nodes[:] - - - - outputs_test, roc_auc_test_wse, MCC_test_wse = self._evaluate( - self.params, graph_test, target_test_wse, False,indices_test - ) - outputs_test_wse= update_outputs_test(outputs_test, indices_test) - - outputs_test, roc_auc_test, MCC_test = self._evaluate( - self.params, graph_test, target_test, True - ) - - outputs, roc_auc_wse, MCC_wse = self._evaluate( - self.params, self.graph, target_wse, False, indices_train - ) - outputs_train_wse = update_outputs_test(outputs, indices_train) - - outputs, roc_auc, MCC = self._evaluate( - self.params, self.graph, self.targets, True - ) - - # SAVE PARAMETER (NOT WE SAVE THE FILES SO IT SHOULD BE THERE AS WELL ) - if self.wandb_on: - with open("readme.txt", "w") as f: - f.write("readme") - with open(os.path.join(self.save_path, "Constant.txt"), "w") as outfile: - outfile.write( - "num_message_passing_steps" - + str(self.num_message_passing_steps) - + "\n" - ) - outfile.write("Learning_rate:" + str(self.learning_rate) + "\n") - outfile.write("num_training_steps:" + str(self.num_training_steps)) - outfile.write("roc_auc" + str(roc_auc_test)) - outfile.write("MCC" + str(MCC_test)) - outfile.write("roc_auc_wse" + str(roc_auc_test_wse)) - outfile.write("MCC_wse" + str(MCC_test_wse)) - - # PLOTTING THE LOSS and AUC RO - plot_curves( - [self.losses_train, self.losses_test, self.losses_train_wse, self.losses_test_wse], - os.path.join(self.save_path, "Losses.pdf"), - "All_Losses", - legend_labels=["loss", "loss test", "loss_wse", "loss_test_wse"], - ) - - plot_curves( - [ - np.log(self.losses_train), - np.log(self.losses_test), - np.log(self.losses_train_wse), - np.log(self.losses_test_wse), - ], - os.path.join(self.save_path, "Log_Losses.pdf"), - "All_log_Losses", - legend_labels=[ - "log_loss", - "log_loss test", - "log_loss_wse", - "log_loss_test_wse", - ], - ) - - plot_curves([self.losses_train], os.path.join(self.save_path, "Losses_train.pdf"), "Losses") - plot_curves( - [self.losses_test], - os.path.join(self.save_path, "losses_test.pdf"), - "losses_test", - ) - plot_curves( - [self.losses_train_wse], - os.path.join(self.save_path, "Losses_wse.pdf"), - "Losses_wse", - ) - plot_curves( - [ self.losses_test_wse], - os.path.join(self.save_path, "losses_test_wse.pdf"), - "losses_test_wse", - ) - - plot_curves( - [self.roc_aucs_test, self.roc_aucs_train], - os.path.join(self.save_path, "auc_rocs.pdf"), - "All_auc_roc", - legend_labels=["auc_roc_test", "auc_roc_train"], - ) - plot_curves( - [self.roc_aucs_test], - os.path.join(self.save_path, "auc_roc_test.pdf"), - "auc_roc_test", - ) - plot_curves( - [self.roc_aucs_train], - os.path.join(self.save_path, "auc_roc_train.pdf"), - "auc_roc_train", - ) - - plot_curves( - [self.MCCs_train, self.MCCs_test, self.MCCs_train_wse, self.MCCs_test_wse], - os.path.join(self.save_path, "MCCs.pdf"), - "All_MCCs", - legend_labels=["MCC", "MCC test", "MCC_wse", "MCC_test_wse"], - ) - plot_curves( - [self.MCCs_train], os.path.join(self.save_path, "MCC_train.pdf"), "MCC_train" - ) - plot_curves( - [self.MCCs_test], os.path.join(self.save_path, "MCC_test.pdf"), "MCC_test" - ) - plot_curves( - [self.MCCs_train_wse], - os.path.join(self.save_path, "MCC_train_wse.pdf"), - "MCC_train_wse", - ) - plot_curves( - [self.MCCs_test_wse], - os.path.join(self.save_path, "MCC_test_wse.pdf"), - "MCC_test_wse", - ) - - # PLOTTING ACTIVATION FOR TEST AND THE TARGET OF THE THING ( NOTE THAT IS WAS TRANED ON THE ALL THING) - plot_input_target_output( - list(graph_test.nodes.sum(-1)), - target_test.sum(-1), - np.squeeze(outputs_test[0].nodes).tolist(), - graph_test, - 2, - self.edge_lables, - os.path.join(self.save_path, "in_out_targ_test.pdf"), - "in_out_targ_test", - ) - - new_vector = [1 if val > 0.3 else 0 for val in outputs_test[0].nodes] - plot_input_target_output( - list(graph_test.nodes.sum(-1)), - target_test.sum(-1), - new_vector, - graph_test, - 2, - self.edge_lables, - os.path.join(self.save_path, "in_out_targ_test_threshold.pdf"), - "in_out_targ_test", - ) - - plot_message_passing_layers( - list(graph_test.nodes.sum(-1)), - outputs_test[1], - target_test.sum(-1), - np.squeeze( - outputs_test[0].nodes).tolist(), - graph_test, - 2, - self.num_message_passing_steps, - self.edge_lables, - os.path.join( - self.save_path, - "message_passing_graph_test.pdf", - ), - "message_passing_graph_test", - ) - - plot_input_target_output( - list(graph_test.nodes.sum(-1)), - target_test_wse.sum(-1), - np.squeeze(outputs_test_wse).tolist(), - graph_test, - 2, - self.edge_lables, - os.path.join(self.save_path, "in_out_targ_test_wse.pdf"), - "in_out_targ_test_wse", - ) - - # Train - - # PLOTTING ACTIVATION OF THE FIRST 2 GRAPH OF THE BATCH - new_vector = [1 if val > 0.3 else 0 for val in outputs[0].nodes] - plot_input_target_output( - list(self.graph.nodes.sum(-1)), - self.targets.sum(-1), - new_vector, - self.graph, - 2, - self.edge_lables, - os.path.join(self.save_path, "in_out_targ_train_threshol.pdf"), - "in_out_targ_train", - ) - - plot_input_target_output( - list(self.graph.nodes.sum(-1)), - target_wse.sum(-1), - np.squeeze(outputs_train_wse).tolist(), - self.graph, - 2, - self.edge_lables, - os.path.join(self.save_path, "in_out_targ_train_wse.pdf"), - "in_out_targ_train_wse", - ) - - plot_input_target_output( - list(self.graph.nodes.sum(-1)), - self.targets.sum(-1), - np.squeeze( - outputs[0].nodes).tolist(), - self.graph, - 2, - self.edge_lables, - os.path.join(self.save_path, "in_out_targ_train.pdf"), - "in_out_targ_train", - ) - - plot_message_passing_layers( - list(self.graph.nodes.sum(-1)), - outputs[1], - self.targets.sum(-1), - np.squeeze( - outputs[0].nodes).tolist(), - self.graph, - 2, - self.num_message_passing_steps, - self.edge_lables, - os.path.join(self.save_path, "message_passing_graph_train.pdf"), - "message_passing_graph_train", - ) - - print('End') - - -if __name__ == "__main__": - from neuralplayground.arenas import Simple2D - - # @title Graph net functions - parser = argparse.ArgumentParser() - parser.add_argument( - "--config_path", - metavar="-C", - default="domine_2023_extras/class_config.yaml", - help="path to base configuration file.", - ) - - args = parser.parse_args() - set_device() - config_class = GridConfig - config = config_class(args.config_path) - - # Init environment - arena_x_limits = [-100, 100] - arena_y_limits = [-100, 100] - - - agent = Domine2023( - experiment_name=config.experiment_name, - train_on_shortest_path=config.train_on_shortest_path, - resample=config.resample, # @param - wandb_on=config.wandb_on, - seed=config.seed, - feature_position=config.feature_position, - weighted=config.weighted, - num_hidden=config.num_hidden, # @param - num_layers=config.num_layers, # @param - num_message_passing_steps=config.num_message_passing_steps, # @param - learning_rate=config.learning_rate, # @param - num_training_steps=config.num_training_steps, # @param - batch_size=config.batch_size, - nx_min=config.nx_min, - nx_max=config.nx_max, - batch_size_test=config.batch_size_test, - nx_min_test=config.nx_min_test, - nx_max_test=config.nx_max_test, - arena_y_limits=arena_y_limits, - arena_x_limits=arena_x_limits, - ) - - for n in range(config.num_training_steps): - agent.update() - agent.print_and_plot() - -# TODO: Run manadger (not possible for now), to get a seperated code we would juste need to change the paths and config this would mean get rid of the comfig -# The other alternative is to see that we have multiple env that we resample every time -# TODO: Make juste an env type (so that is accomodates for not only 2 d env// different transmats) -# TODO: Make The plotting in the general plotting utilse -# if __name__ == "__main__": -# x = Domine2023() -# x = x.replace(obs_history=[1, 2], num_hidden=2) -# x.num_hidden = 5 -# -# x.update() diff --git a/neuralplayground/agents/domine_2023_extras_2/class_config_template.py b/neuralplayground/agents/domine_2023_extras_2/class_config_template.py index 87bdf59c..5a5bffd7 100644 --- a/neuralplayground/agents/domine_2023_extras_2/class_config_template.py +++ b/neuralplayground/agents/domine_2023_extras_2/class_config_template.py @@ -8,10 +8,6 @@ class ConfigTemplate: name="experiment_name", types=[str, type(None)], ), - config_field.Field( - name="resample", - types=[bool], - ), config_field.Field( name="wandb_on", types=[bool], @@ -20,10 +16,6 @@ class ConfigTemplate: name="batch_size", types=[int], ), - config_field.Field( - name="num_nodes_min", - types=[int], - ), config_field.Field( name="num_features", types=[int], @@ -36,10 +28,6 @@ class ConfigTemplate: name="batch_size_test", types=[int], ), - config_field.Field( - name="num_nodes_min_test", - types=[int], - ), config_field.Field( name="num_nodes_max_test", types=[list], @@ -70,10 +58,6 @@ class ConfigTemplate: name="dataset", types=[str], ), - config_field.Field( - name="weighted", - types=[bool], - ), config_field.Field( name="num_training_steps", types=[float, int], @@ -82,10 +66,6 @@ class ConfigTemplate: name="residual", types=[bool], ), - config_field.Field( - name="layer_norm", - types=[bool], - ), config_field.Field( name="plot", types=[bool], diff --git a/neuralplayground/agents/domine_2023_extras_2/config.yaml b/neuralplayground/agents/domine_2023_extras_2/config.yaml index 21a38fc4..e88c784e 100644 --- a/neuralplayground/agents/domine_2023_extras_2/config.yaml +++ b/neuralplayground/agents/domine_2023_extras_2/config.yaml @@ -1,25 +1,20 @@ -experiment_name: 'hello' -resample: False # @param +experiment_name: 'random_mp_0' wandb_on: False seed: 45 plot: True -dataset: 'positional' # 'random' or 'omniglot' or positional'positional_no_edges -weighted: False +dataset: 'random' # 'random' or 'omniglot' or positional'positional_no_edges num_hidden: 15 # @param num_layers: 1 # @param -num_message_passing_steps: 0 # @param -learning_rate: 0.005 # @param -num_training_steps: 80 # @param +num_message_passing_steps: 0 # @param which is effectivly 5 +learning_rate: 0.001 # @param +num_training_steps: 15000 # @param residual: True -layer_norm: False # Env Stuff batch_size: 2 -num_nodes_max: 5 -num_nodes_min: 5 +num_nodes_max: 15 num_features: 1 - batch_size_test: 2 -num_nodes_max_test: [6,7,8,9] #min2 -num_nodes_min_test: 10 \ No newline at end of file +num_nodes_max_test: [4,5,6,7,8,10] + diff --git a/neuralplayground/agents/domine_2023_extras_2/models/GCN_model.py b/neuralplayground/agents/domine_2023_extras_2/models/GCN_model.py index 4418d323..95798f7a 100644 --- a/neuralplayground/agents/domine_2023_extras_2/models/GCN_model.py +++ b/neuralplayground/agents/domine_2023_extras_2/models/GCN_model.py @@ -4,21 +4,14 @@ from torch_geometric.nn import GCNConv,global_mean_pool # TODO(clementine): set up object oriented GNN classes (eventually) class GCNModel(nn.Module): - def __init__(self, num_hidden, num_feature, num_layers, num_message_passing_steps, residual, layer_norm): + def __init__(self, num_hidden, num_feature, num_layers, num_message_passing_steps, residual): super(GCNModel, self).__init__() self.num_message_passing_steps = num_message_passing_steps self.conv_1 = GCNConv(num_feature, num_hidden) self.conv_layers = nn.ModuleList([GCNConv(num_hidden, num_hidden) for _ in range(num_message_passing_steps)]) - # Output layer with size 2 for binary classification logits self.fc = nn.Linear(num_hidden, 2) - self.residual = residual - self.layer_norm = layer_norm - self.norm_layers = nn.ModuleList( - [nn.LayerNorm(num_hidden) for _ in range(num_message_passing_steps)]) if layer_norm else None - # Define the softmax layer - self.softmax = nn.Softmax(dim=1) # Apply softmax across the logits (along the class dimension) def forward(self, node, edges, edges_attr): x, edge_index = node, edges @@ -27,21 +20,56 @@ def forward(self, node, edges, edges_attr): for i, conv in enumerate(self.conv_layers): x_res = x x = conv(x, edge_index, edges_attr) - if self.layer_norm: - x = self.norm_layers[i](x) if self.residual: x += x_res x = torch.relu(x) - # The output layer now produces 2 logits for each graph node x = self.fc(x) - # Pooling to get a graph-level representation x = global_mean_pool(x, batch=None) - - # Apply softmax to the logits to convert them to probabilities - # x = x.view(-1) #Flatten the tensor to 1D if necessary #TODO: ask here if this makes sense + return x + +# This is juste to test what happens when we don't have the message passing layers + + +class MLP(nn.Module): + def __init__(self, num_hidden, num_feature, num_layers, num_message_passing_steps, residual): + super(MLP, self).__init__() + self.fc1 = nn.Linear(num_feature, num_hidden) # First fully connected layer + self.fc2 = nn.Linear(num_hidden, 2) + def forward(self, node, edges, edges_attr): + x, edge_index = node, edges + # The output layer now produces 2 logits for each graph node + x = torch.relu(self.fc1(x)) # Apply ReLU activation to first layer's output + x = self.fc2(x) # Pass through the second layer + # Pooling to get a graph-level representation + x = global_mean_pool(x, batch=None) + return x + +class GCNModel_2(nn.Module): + def __init__(self, num_hidden, num_feature, num_layers, num_message_passing_steps, residual): + super(GCNModel_2, self).__init__() + self.num_message_passing_steps = num_message_passing_steps - 1 + self.conv_1 = GCNConv(num_feature, num_hidden) + self.conv_layers = nn.ModuleList([GCNConv(num_hidden, num_hidden) for _ in range(num_message_passing_steps)]) + # Output layer with size 2 for binary classification logits + self.fc = nn.Linear(num_hidden, 2) + self.residual = residual + + def forward(self, node, edges,edges_attr): + - return x \ No newline at end of file + x, edge_index = node, edges + x = self.conv_1(x, edge_index) + x = torch.relu(x) + for i, conv in enumerate(self.conv_layers): + x_res = x + x = conv(x, edge_index) + if self.residual: + x += x_res + x = torch.relu(x) + x = self.fc(x) + x = global_mean_pool(x, batch=None) + return x diff --git a/neuralplayground/agents/domine_2023_extras_2/processing/Graph_generation.py b/neuralplayground/agents/domine_2023_extras_2/processing/Graph_generation.py index 0dd2fc48..d532ce6e 100644 --- a/neuralplayground/agents/domine_2023_extras_2/processing/Graph_generation.py +++ b/neuralplayground/agents/domine_2023_extras_2/processing/Graph_generation.py @@ -163,7 +163,9 @@ def sample_graph(num_features, num_nodes, feature_type='random'): # Append position features if specified if feature_type == 'positional' or feature_type == 'positional_no_edges': - position = torch.tensor(np.arange(num_nodes)).unsqueeze(1) # Shape: (num_nodes, 1) + # Position also do the moving + position = torch.tensor(np.linspace(0,1,num_nodes)).unsqueeze(1) + # Shape: (num_nodes, 1) combined_node_features = np.concatenate([combined_node_features, position], axis=1) # Convert combined node features to a tensor @@ -171,7 +173,7 @@ def sample_graph(num_features, num_nodes, feature_type='random'): # Return based on feature_type if feature_type == 'positional_no_edges': - return node_features, edges, source, sink # No edge features + return node_features, edges, edge_features_tensor, source, sink # No edge features else: return node_features, edges, edge_features_tensor, source, sink @@ -185,7 +187,6 @@ def sample_fixed_graph(num_features, num_nodes, feature_type='random', sositype= source = 2 input_node_features[source, 0] = 1 # Set source node feature input_node_features[sink, 1] = 1 # Set sink node feature - # Combine node features and input features combined_node_features = np.concatenate([node_features.T, input_node_features], axis=1) # Append position features if specified @@ -198,7 +199,7 @@ def sample_fixed_graph(num_features, num_nodes, feature_type='random', sositype= # Return based on feature_type if feature_type == 'positional_no_edges': - return node_features, edges, source, sink # No edge features + return node_features, edges, edge_features_tensor, source, sink # No edge features else: return node_features, edges, edge_features_tensor, source, sink diff --git a/neuralplayground/agents/domine_2023_extras_2/utils/plotting_utils.py b/neuralplayground/agents/domine_2023_extras_2/utils/plotting_utils.py index 32415f45..8de8136f 100644 --- a/neuralplayground/agents/domine_2023_extras_2/utils/plotting_utils.py +++ b/neuralplayground/agents/domine_2023_extras_2/utils/plotting_utils.py @@ -9,7 +9,7 @@ import networkx as nx import numpy as np import torch - +from scipy.ndimage import convolve1d def plot_2dgraphs(edges_list, node_features, edge_features_list, subplots_titles, path=None, colorscale='Plasma', size=5, show=True): """ @@ -120,7 +120,7 @@ def color_map(output): sm._A = [] return sm -def plot_curves(curves, path, title, legend_labels=None, x_label=None, y_label=None): +def plot_curves(curves, path, title, legend_labels=None, x_label=None, y_label=None, time_steps=100): fig, ax = plt.subplots(figsize=(8, 6)) ax.set_title(title) @@ -137,6 +137,9 @@ def plot_curves(curves, path, title, legend_labels=None, x_label=None, y_label=N # Map the values to colors using the chosen colormap colors = [colormap(value) for value in values] + # Create a simple kernel for convolution (moving average) + kernel = np.ones(time_steps) / time_steps + for i, curve in enumerate(curves): label = legend_labels[i] if legend_labels else None color = colors[i % len(colors)] @@ -145,6 +148,9 @@ def plot_curves(curves, path, title, legend_labels=None, x_label=None, y_label=N if isinstance(curve, torch.Tensor): curve = curve.detach().numpy() + # Apply convolution to smooth the curve over the specified time steps + curve = convolve1d(curve, kernel, mode='reflect') + ax.plot(curve, label=label, color=color) if legend_labels: @@ -152,52 +158,60 @@ def plot_curves(curves, path, title, legend_labels=None, x_label=None, y_label=N plt.savefig(path) plt.show() plt.close() +def plot_curves_2(curves, std_devs=None, path=None, title=None, legend_labels=None, x_label=None, y_label=None, time_steps=1): + fig, ax = plt.subplots(figsize=(8, 6)) + ax.set_title(title) -def plot_curves_2(curves, std_devs=None, path=None, title=None, legend_labels=None, x_label=None, y_label=None): - fig, ax = plt.subplots(figsize=(8, 6)) - ax.set_title(title) + if x_label: + ax.set_xlabel(x_label) + if y_label: + ax.set_ylabel(y_label) - if x_label: - ax.set_xlabel(x_label) - if y_label: - ax.set_ylabel(y_label) + colormap = plt.get_cmap("viridis") - colormap = plt.get_cmap("viridis") + # Use numpy linspace to create the values array + values = np.linspace(0, 1, len(curves)) - # Use numpy linspace to create the values array - values = np.linspace(0, 1, len(curves)) + # Map the values to colors using the chosen colormap + colors = [colormap(value) for value in values] - # Map the values to colors using the chosen colormap - colors = [colormap(value) for value in values] + # Create a simple kernel for convolution (moving average) + kernel = np.ones(time_steps) / time_steps - for i, curve in enumerate(curves): - label = legend_labels[i] if legend_labels else None - color = colors[i % len(colors)] + for i, curve in enumerate(curves): + label = legend_labels[i] if legend_labels else None + color = colors[i % len(colors)] - # Convert PyTorch tensors to NumPy arrays for plotting - if isinstance(curve, torch.Tensor): - curve = curve.detach().numpy() + # Convert PyTorch tensors to NumPy arrays for plotting + if isinstance(curve, torch.Tensor): + curve = curve.detach().numpy() - # Plot the average curve - ax.plot(curve, label=label, color=color) + # Apply convolution to smooth the curve over the specified time steps + curve = convolve1d(curve, kernel, mode='reflect') - # If std_devs are provided, plot the shaded region for the standard deviation - if std_devs is not None: - std_dev = std_devs[i] + # Plot the smoothed average curve + ax.plot(curve, label=label, color=color) - # Convert std_dev to numpy if it is a PyTorch tensor - if isinstance(std_dev, torch.Tensor): - std_dev = std_dev.detach().numpy() + # If std_devs are provided, plot the shaded region for the standard deviation + if std_devs is not None: + std_dev = std_devs[i] - # Shaded region (curve ± std deviation) - ax.fill_between(np.arange(len(curve)), np.asarray(curve) - np.asarray( std_dev), np.asarray(curve) + np.asarray(std_dev), color=color, alpha=0.3) + # Convert std_dev to numpy if it is a PyTorch tensor + if isinstance(std_dev, torch.Tensor): + std_dev = std_dev.detach().numpy() - if legend_labels: - ax.legend() + # Apply convolution to smooth the standard deviation over the specified time steps + std_dev = convolve1d(std_dev, kernel, mode='reflect') - # Save plot to file if path is provided - if path: - plt.savefig(path) + # Shaded region (curve ± std deviation) + ax.fill_between(np.arange(len(curve)), curve - std_dev, curve + std_dev, color=color, alpha=0.3) - plt.show() - plt.close() \ No newline at end of file + if legend_labels: + ax.legend() + + # Save plot to file if path is provided + if path: + plt.savefig(path) + + plt.show() + plt.close() \ No newline at end of file diff --git a/neuralplayground/agents/lenghts.py b/neuralplayground/agents/lenghts.py new file mode 100644 index 00000000..82d91e08 --- /dev/null +++ b/neuralplayground/agents/lenghts.py @@ -0,0 +1,557 @@ +import argparse +import os +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" +import math +import torch +import shutil +from datetime import datetime +from pathlib import Path +import torch.nn as nn +import torch.optim as optim +import wandb +import numpy as np +from neuralplayground.agents.agent_core import AgentCore +from neuralplayground.agents.domine_2023_extras_2.utils.plotting_utils import plot_curves, plot_curves_2, plot_2dgraphs +from neuralplayground.agents.domine_2023_extras_2.models.GCN_model import GCNModel, MLP ,GCNModel_2 +from neuralplayground.agents.domine_2023_extras_2.class_grid_run_config import GridConfig +from neuralplayground.agents.domine_2023_extras_2.utils.utils import set_device +from neuralplayground.agents.domine_2023_extras_2.processing.Graph_generation import sample_graph, sample_target, sample_omniglot_graph, sample_fixed_graph +from torchmetrics import Accuracy, Precision, AUROC, Recall, MatthewsCorrCoef +from torchmetrics.classification import BinaryAccuracy + +# from neuralplayground.agents.domine_2023_extras_2.evaluate import Evaluator +os.environ["KMP_DUPLICATE_LIB_OK"] = "True" + +class Domine2023(AgentCore): + def __init__(self, experiment_name="smaller size generalisation graph with no position feature", + train_on_shortest_path=True, wandb_on=False, seed=41, dataset = 'random', + num_hidden=100, num_layers=2, num_message_passing_steps=3, learning_rate=0.001, + num_training_steps=10, residual=True, batch_size=4, num_features=4, num_nodes_max=7, + batch_size_test=4, num_nodes_min_test=4, num_nodes_max_test=[7], plot=True, **mod_kwargs): + super(Domine2023, self).__init__() + + # General + np.random.seed(seed) + self.plot = plot + self.obs_history = [] + self.grad_history = [] + self.experiment_name = experiment_name + self.wandb_on = wandb_on + self.seed = seed + self.log_every = 500 + + # Network + self.num_hidden = num_hidden + self.num_layers = num_layers + self.num_message_passing_steps = num_message_passing_steps + self.learning_rate = learning_rate + self.num_training_steps = num_training_steps + self.batch_size = batch_size + self.residual = residual + + + # Task + self.dataset = dataset + self.num_features = num_features + self.num_nodes_max = num_nodes_max + self.num_nodes_max_test = num_nodes_max_test + + + def set_initial_seed(seed): + # Set the seed for NumPy + np.random.seed(seed) + + # Set the seed for PyTorch + torch.manual_seed(seed) + + # If using CUDA + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # If you are using multi-GPU + + # Optional: For deterministic behavior, you can use the following: + # This is usually needed for reproducibility with certain layers like convolution. + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + set_initial_seed(seed) + + self.batch_size_test = batch_size_test + self.arena_x_limits = mod_kwargs["arena_x_limits"] + self.arena_y_limits = mod_kwargs["arena_y_limits"] + save_path = mod_kwargs["save_path"] + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if self.dataset == 'random': + self.model = GCNModel(self.num_hidden, self.num_features + 2, self.num_layers, + self.num_message_passing_steps, self.residual + ).to(self.device) + elif self.dataset == 'positional': + self.model = GCNModel(self.num_hidden, self.num_features + 3, self.num_layers, + self.num_message_passing_steps, self.residual + ).to(self.device) + elif self.dataset == 'positional_no_edges': + self.model = GCNModel_2(self.num_hidden, self.num_features + 3, self.num_layers, + self.num_message_passing_steps, self.residual + ).to(self.device) + else: + num_features = 784 + self.model = GCNModel(self.num_hidden, num_features + 2, self.num_layers, + self.num_message_passing_steps, self.residual, + ).to(self.device) + + + self.auroc = AUROC(task="binary") + self.ACC = BinaryAccuracy() + self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate) + self.criterion = nn.CrossEntropyLoss() + if self.wandb_on: + dateTimeObj = datetime.now() + wandb.init(project="New", entity="graph-brain", + name=experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S")) + self.wandb_logs = {} + save_path = wandb.run.dir + os.mkdir(os.path.join(save_path,"results")) + self.save_path = os.path.join(save_path, "results") + + self.reset() + self.wandb_logs = { # This is thought of the state density + "batch_size": self.batch_size, + "num_node_min": self.num_nodes_max, # This is thought of the state density + "seed": self.seed, + "dataset": self.dataset, + "num_hidden": self.num_hidden, + "num_layers": self.num_layers, + "num_message_passing_steps": self.num_message_passing_steps, + "learning_rate": self.learning_rate, + "num_training_steps": self.num_training_steps, + "residual": self.residual, + } + if self.wandb_on: + wandb.log(self.wandb_logs) + else: + self.save_path= save_path + self.save_run_parameters() + + def save_run_parameters(self): + """Save configuration files and scripts.""" + files_to_copy = [ + ("run.py", "domine_2023_2.py"), + ("Graph_generation.py", "domine_2023_extras_2/processing/Graph_generation.py"), + ("utils.py", "domine_2023_extras_2/utils/utils.py"), + ("plotting_utils.py", "domine_2023_extras_2/utils/plotting_utils.py"), + ("config_run.yaml", "domine_2023_extras_2/config.yaml"), + ] + for file_name, source in files_to_copy: + shutil.copyfile(os.path.join(Path(os.getcwd()).resolve(), source), os.path.join(self.save_path, file_name)) + + def load_data(self, fixed, dataset, batch_size,num_nodes): + # Initialize lists to store batch data + node_features_batch, edges_batch, edge_features_batch, target_batch = [], [], [], [] + + # Determine the max nodes to use based on whether it is training or testing + + # Loop to generate a batch of data + for _ in range(batch_size): + if fixed: + node_features, edges, edge_features_tensor, source, sink = sample_fixed_graph(self.num_features, + num_nodes,feature_type= dataset) + else: + # Handle Omniglot dataset + if dataset == 'omniglot': + node_features, edges, edge_features_tensor, source, sink = sample_omniglot_graph(num_nodes) + # Handle Random dataset + else: + node_features, edges, edge_features_tensor, source, sink = sample_graph(self.num_features, + num_nodes,feature_type= dataset) + + + + # Sample the target based on source and sink + target = sample_target(source, sink) + + # Append each graph data to the batch list + node_features_batch.append(node_features) + edges_batch.append(edges) + edge_features_batch.append(edge_features_tensor) + target_batch.append(target) + + return node_features_batch, edges_batch, edge_features_batch, target_batch + + def compute_loss(self, outputs, targets): + loss = self.criterion(outputs, targets) + return loss + + def run_model(self, node, edges,edges_features): + outputs = self.model(node,edges,edges_features) + return outputs + + def update_step(self, node_batch, edges_batch, edges_features_batch, target_batch, train, batch_size=1): + if train: + self.model.train() + self.optimizer.zero_grad() + else: + self.model.eval() + + batch_losses = 0 + all_outputs = [] + + # Loop over the batch + for i in range(batch_size): + data = node_batch[i].to(self.device) + edges = edges_batch[i].to(self.device) + edges_features = edges_features_batch[i].to(self.device) + + # Forward pass + outputs = self.run_model(data, edges, edges_features) + all_outputs.append(outputs.view(-1)) + # Compute loss for this sample + loss = self.compute_loss(outputs, target_batch[i]) + #target = torch.randn(2, 2).softmax(dim=1) + #input = target * 0.99999 + #self.compute_loss(input, target) + + batch_losses += loss + + # Average loss over the batch + avg_loss = batch_losses / batch_size + #all_outputs = torch.stack(all_outputs) + # all_target_1 = torch.stack(target_batch).view(-1) + # all_target = torch.stack(target_batch) + # loss = self.compute_loss(all_outputs, all_target_1) + + if train: + avg_loss.backward() + self.optimizer.step() + + # Concatenate all outputs for evaluation over the batch + all_outputs = torch.stack(all_outputs) + all_target = torch.stack(target_batch) + # Evaluate using the full batch's predictions + acc = self.evaluate(all_outputs, all_target) + + return avg_loss.item(), acc + + def evaluate(self,outputs,targets): + with (torch.no_grad()): + # roc_auc_score(targets.cpu(), outputs.cpu()) + labels = targets.view(-1) # Outputs: [0, 1, 0, 1] + # Convert predicted probabilities to class labels using argmax + Softmax = nn.Softmax(dim=1) + outputs = Softmax(outputs) + predicted_labels = torch.argmax(outputs, dim=1) # Outputs: [0, 1, 0, 1] + # Initialize the BinaryAccuracy metric + acc = self.ACC(predicted_labels, labels) + return acc + + def log_training(self, train_loss, val_loss, train_acc, val_acc): + """Log training and validation metrics.""" + wandb_logs = { + "train_loss": train_loss, + "val_loss": val_loss, + "ACC_train": train_acc, + "ACC_val": val_acc + } + if self.wandb_on: + wandb.log(wandb_logs) + + def train(self): + # Load validation data + val_graphs = {num_node: [] for num_node in self.num_nodes_max_test} + for i in range(len(self.num_nodes_max_test)): + node_features_val, edges_val, edge_features_tensor_val, target_val = self.load_data(fixed=False, + dataset=self.dataset, + batch_size=self.batch_size, num_nodes= self.num_nodes_max_test[i]) + val_graphs[self.num_nodes_max_test[i]] = [node_features_val, edges_val, edge_features_tensor_val, target_val] + + #for i in len(self.num_nodes_max_test): + # This is an attemp + node_features_val_f, edges_val_f, edge_features_tensor_val_f, target_val_f = self.load_data(fixed=True, + dataset=self.dataset, + batch_size=self.batch_size, + num_nodes= + self.num_nodes_max_test[0], + ) + # need to save the fixed one node_featur + + for epoch in range(self.num_training_steps): + + node_features, edges, edge_features_tensor, target = self.load_data(fixed=False, + dataset=self.dataset, + batch_size=self.batch_size, num_nodes= self.num_nodes_max) + #Train on each batch + batch_losses, batch_acc = self.update_step(node_features, edges, + edge_features_tensor, target, train=True, batch_size=self.batch_size) + + #Aggregate batch results + train_losses = batch_losses + train_acc = batch_acc.detach().numpy() + # Average batch results over the batch size + + # Store results for plotting + self.losses_train.append(train_losses) + + self.ACCs_train.append(train_acc) + + # Validation + with torch.no_grad(): + for num_node in self.num_nodes_max_test: + node_features_val, edges_val, edge_features_tensor_val, target_val = val_graphs[num_node] + val_losses, val_acc = self.update_step(node_features_val, edges_val, + edge_features_tensor_val, target_val, + train=False,batch_size=self.batch_size) + self.losses_val[num_node].append(val_losses) + self.ACCs_val[num_node].append(val_acc) + + + # Log training details + #TODO: Need to update this + self.log_training(train_losses, val_losses, train_acc, val_acc) + + # Plot progress every epoch + if self.global_steps % self.log_every == 0: + print( + f"Epoch {epoch}: Train Loss = {train_losses}, Val Loss = {val_losses}, ACC Train = {train_acc}, ACC Val = {val_acc} ") + self.global_steps += 1 + print("Finished training") + + + if self.plot: + os.makedirs(os.path.join(self.save_path, "results"), exist_ok=True) + self.save_path = os.path.join(self.save_path, "results") + file_name = f"Losses_{seed}.pdf" + + # Combine the path and file name + list_of_lists = [value for value in self.losses_val.values()] + list_of_list_name = [f'loss_val_len{value}' for value in self.losses_val] + # Append losses_train to the list of lists + list_of_lists.append(self.losses_train) + list_of_list_name.append('loss_train') + + plot_curves( + list_of_lists , + os.path.join(self.save_path, file_name ), + "All_Losses", + legend_labels=list_of_list_name, + ) + + file_name = f"ACCs_val_{seed}.pdf" + plot_curves( [value for value in self.ACCs_val.values()], + os.path.join(self.save_path, file_name), + "ACC Val", + legend_labels=[f'ACC_val_len{value}' for value in self.losses_val], + ) + file_name = f"ACCs_train_{seed}.pdf" + + plot_curves( + [ + self.ACCs_train, + ], + os.path.join(self.save_path, file_name), + "ACC train", + legend_labels=["ACC train"], + ) + + + def sample_and_store(n): + # Initialize empty lists to store each sample's output + node_features_list = [] + edges_list = [] + edge_features_tensor_list = [] + target_list = [] + # Loop n times to sample data and store the outputs + for _ in range(n): + # Sample data by calling load_data + node_features, edges, edge_features_tensor, target = self.load_data(train=False, + dataset=self.dataset, + batch_size=self.batch_size) + # Append the results to the corresponding lists + node_features_list.append(node_features) + edges_list.append(edges) + edge_features_tensor_list.append(edge_features_tensor) + target_list.append(target) + return node_features_list, edges_list, edge_features_tensor_list, target_list + n=2 + #node_features_list, edges_list, edge_features_tensor_list, target_list = sample_and_store(n) + #plot_2dgraphs(edges_list, node_features_list, edge_features_tensor_list,['',''], os.path.join(self.save_path, "graph.pdf"), colorscale='Plasma',size=5,show=True) + return self.losses_train, self.ACCs_train, self.losses_val, self.ACCs_val + def reset(self): + self.obs_history = [] + self.grad_history = [] + self.global_steps = 0 + self.losses_train = [] + self.losses_val = {num_node: [] for num_node in self.num_nodes_max_test} + self.ACCs_val = {num_node: [] for num_node in self.num_nodes_max_test} + self.ACCs_train = [] + + return + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config_path", metavar="-C", default="domine_2023_extras_2/config.yaml", + help="path to base configuration file.") + args = parser.parse_args() + set_device() + config_class = GridConfig + config = config_class(args.config_path) + + arena_x_limits = [-100, 100] + arena_y_limits = [-100, 100] + + seeds = [41,42] + losses_train = {seed: [] for seed in seeds} + losses_val = {seed: [] for seed in seeds} + ACCs_train = {seed: [] for seed in seeds} + ACCs_val = {seed: [] for seed in seeds} + dateTimeObj = datetime.now() + save_path = os.path.join(Path(os.getcwd()).resolve(), "results") + os.mkdir( + os.path.join( + save_path, + config.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"), + ) + ) + save_path = os.path.join( + os.path.join( + save_path, + config.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"), + ) + ) + for seed in seeds: + agent = Domine2023( + experiment_name=config.experiment_name, + wandb_on=config.wandb_on, + seed=seed, + dataset=config.dataset, + num_hidden=config.num_hidden, + num_layers=config.num_layers, + num_message_passing_steps=config.num_message_passing_steps, + learning_rate=config.learning_rate, + num_training_steps=config.num_training_steps, + batch_size=config.batch_size, + num_features=config.num_features, + num_nodes_max=config.num_nodes_max, + num_nodes_min=config.num_nodes_min, + batch_size_test=config.batch_size_test, + num_nodes_min_test=config.num_nodes_min_test, + num_nodes_max_test=config.num_nodes_max_test, + arena_y_limits=arena_y_limits, + arena_x_limits=arena_x_limits, + residual=config.residual, + plot=config.plot, + save_path = save_path + ) + + + + losse_train, ACC_train, losse_val, ACC_val = agent.train() + losses_train[seed] = losse_train + losses_val[seed] = losse_val + ACCs_train[seed]= ACC_train + ACCs_val[seed] = ACC_val + + save_path = os.path.join(save_path, "results") + + num_training_steps = config.num_training_steps + # Initialize lists to store standard deviation results + std_losses_train = [] + std_accs_train = [] + + # Compute average and standard deviation for training loss + + avg_losses_train = [] + for epoch_idx in range(num_training_steps): + # Average the loss for this epoch over all seeds + avg_epoch_loss = sum(losses_train[seed][epoch_idx] for seed in seeds) / len(seeds) + avg_losses_train.append(avg_epoch_loss) + + # Compute standard deviation for this epoch + variance_loss = sum((losses_train[seed][epoch_idx] - avg_epoch_loss) ** 2 for seed in seeds) / len(seeds) + std_epoch_loss = math.sqrt(variance_loss) + std_losses_train.append(std_epoch_loss) + + # Compute average and standard deviation for training accuracy + avg_accs_train = [] + for epoch_idx in range(num_training_steps): + # Average the accuracy for this epoch over all seeds + avg_epoch_acc = sum(ACCs_train[seed][epoch_idx] for seed in seeds) / len(seeds) + avg_accs_train.append(avg_epoch_acc) + + # Compute standard deviation for this epoch + variance_acc = sum((ACCs_train[seed][epoch_idx] - avg_epoch_acc) ** 2 for seed in seeds) / len(seeds) + std_epoch_acc = math.sqrt(variance_acc) + std_accs_train.append(std_epoch_acc) + + + # Compute average and standard deviation for validation loss + avg_losses_val_len = [] + std_losses_val_len = [] + for i in config.num_nodes_max_test: + avg_losses_val = [] + std_losses_val = [] + for epoch_idx in range(num_training_steps): + avg_epoch_loss_val = sum(losses_val[seed][i][epoch_idx] for seed in seeds) / len(seeds) + avg_losses_val.append(avg_epoch_loss_val) + variance_loss_val = sum( + (losses_val[seed][i][epoch_idx] - avg_epoch_loss_val) ** 2 for seed in seeds) / len(seeds) + std_epoch_loss_val = math.sqrt(variance_loss_val) + std_losses_val.append(std_epoch_loss_val) + avg_losses_val_len.append(avg_losses_val) + std_losses_val_len.append(std_losses_val) + + #Compute average and standard deviation for validation accuracy + avg_accs_val_len = [] + std_accs_val_len = [] + for i in config.num_nodes_max_test: + avg_accs_val = [] + std_accs_val = [] + for epoch_idx in range(num_training_steps): + avg_epoch_acc_val = sum(ACCs_val[seed][i][epoch_idx] for seed in seeds) / len(seeds) + avg_accs_val.append(avg_epoch_acc_val) + + # Compute standard deviation for this epoch + variance_acc_val = sum((ACCs_val[seed][i][epoch_idx] - avg_epoch_acc_val) ** 2 for seed in seeds) / len(seeds) + std_epoch_acc_val = math.sqrt(variance_acc_val) + std_accs_val.append(std_epoch_acc_val) + avg_accs_val_len.append(avg_accs_val) + std_accs_val_len.append(std_accs_val) + + + + list_of_list_name = [f'loss_val_len{value}' for value in losses_val[seed]] + # Append losses_train to the list of lists + avg_losses_val_len.append(avg_losses_train) + list_of_list_name.append('loss_train') + std_losses_val_len.append(std_accs_train) + + plot_curves_2( + avg_losses_val_len,std_losses_val_len, + os.path.join(save_path, "Losses.pdf"), + "All_Losses", + legend_labels= list_of_list_name, + ) + plot_curves_2( + [ + avg_accs_train , + ],[std_accs_train], + os.path.join(save_path, "ACCs_train.pdf"), + "ACC Train", + legend_labels=["ACC Train"], + ) + + list_of_list_name = [f'loss_val_len{value}' for value in losses_val[seed]] + plot_curves_2( + avg_accs_val_len,std_accs_val_len, os.path.join(save_path, "ACCs_val.pdf"), + "ACC val", + legend_labels=list_of_list_name, + ) + print() + + #TODO: They all have different evaluation ( netwokr ) do we want ot eval ( for the average it should be ifne) + #TODO: Think about nice visualisaiton + #TODO: update the plotting for the other curves + # I need to check the logging of the results + # TODO. : plan a set of experiements to run, sudy how different initialisiton + #TODO: Get a different set of valisaion lenght for each run + + # TODO : the set of seed changes every run. so it is fine. The question is + #TODO: What is the the best way to have the dedges no features \ No newline at end of file