diff --git a/neuralplayground/agents/domine_2023_2.py b/neuralplayground/agents/domine_2023_2.py
index b83bb78..bd71e3e 100644
--- a/neuralplayground/agents/domine_2023_2.py
+++ b/neuralplayground/agents/domine_2023_2.py
@@ -1,6 +1,7 @@
 import argparse
 import os
 os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
+import math
 import torch
 import shutil
 from datetime import datetime
@@ -8,12 +9,13 @@
 import torch.nn as nn
 import torch.optim as optim
 import wandb
+import numpy as np
 from neuralplayground.agents.agent_core import AgentCore
 from neuralplayground.agents.domine_2023_extras_2.utils.plotting_utils import plot_curves, plot_curves_2, plot_2dgraphs
 from neuralplayground.agents.domine_2023_extras_2.models.GCN_model import GCNModel
 from neuralplayground.agents.domine_2023_extras_2.class_grid_run_config import GridConfig
 from neuralplayground.agents.domine_2023_extras_2.utils.utils import set_device
-from neuralplayground.agents.domine_2023_extras_2.processing.Graph_generation import sample_random_graph, sample_target, sample_omniglot_graph, sample_random_graph_position
+from neuralplayground.agents.domine_2023_extras_2.processing.Graph_generation import sample_graph, sample_target, sample_omniglot_graph, sample_fixed_graph
 from torchmetrics import Accuracy, Precision, AUROC, Recall, MatthewsCorrCoef
 from torchmetrics.classification import BinaryAccuracy
@@ -25,10 +27,11 @@ def __init__(self, experiment_name="smaller size generalisation graph with no po
                  train_on_shortest_path=True, resample=True, wandb_on=False, seed=41, dataset = 'random',
                  weighted=True, num_hidden=100, num_layers=2, num_message_passing_steps=3,
                  learning_rate=0.001, num_training_steps=10, residual=True, layer_norm=True,
                  batch_size=4, num_features=4, num_nodes_max=7,
-                 batch_size_test=4, num_nodes_min_test=4, num_nodes_max_test=[7], plot=True, **mod_kwargs):
+                 batch_size_test=4, num_nodes_min_test=4, num_nodes_max_test=[7], plot=True, **mod_kwargs):
         super(Domine2023, self).__init__()
         # General
+        np.random.seed(seed)
         self.plot = plot
         self.obs_history = []
         self.grad_history = []
@@ -54,9 +57,30 @@ def __init__(self, experiment_name="smaller size generalisation graph with no po
         self.num_nodes_max = num_nodes_max
         self.num_nodes_max_test = num_nodes_max_test
+
+        def set_initial_seed(seed):
+            # Set the seed for NumPy
+            np.random.seed(seed)
+
+            # Set the seed for PyTorch
+            torch.manual_seed(seed)
+
+            # If using CUDA
+            if torch.cuda.is_available():
+                torch.cuda.manual_seed(seed)
+                torch.cuda.manual_seed_all(seed)  # If you are using multi-GPU
+
+            # Optional: For deterministic behavior, you can use the following.
+            # This is usually needed for reproducibility with certain layers like convolution.
+            torch.backends.cudnn.deterministic = True
+            torch.backends.cudnn.benchmark = False
+
+        set_initial_seed(seed)
+
         self.batch_size_test = batch_size_test
         self.arena_x_limits = mod_kwargs["arena_x_limits"]
         self.arena_y_limits = mod_kwargs["arena_y_limits"]
+        save_path = mod_kwargs["save_path"]
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         if self.dataset == 'random':
             self.model = GCNModel(self.num_hidden, self.num_features + 2, self.num_layers,
@@ -104,20 +128,7 @@ def __init__(self, experiment_name="smaller size generalisation graph with no po
         if self.wandb_on:
             wandb.log(self.wandb_logs)
         else:
-            dateTimeObj = datetime.now()
-            save_path = os.path.join(Path(os.getcwd()).resolve(), "results")
-            os.mkdir(
-                os.path.join(
-                    save_path,
-                    self.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"),
-                )
-            )
-            self.save_path = os.path.join(
-                os.path.join(
-                    save_path,
-                    self.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"),
-                )
-            )
+            self.save_path = save_path
         self.save_run_parameters()

     def save_run_parameters(self):
@@ -132,25 +143,27 @@ def save_run_parameters(self):
         for file_name, source in files_to_copy:
             shutil.copyfile(os.path.join(Path(os.getcwd()).resolve(), source), os.path.join(self.save_path, file_name))

-    def load_data(self, train, dataset, batch_size,num_nodes,seed):
+    def load_data(self, fixed, dataset, batch_size, num_nodes):
         # Initialize lists to store batch data
         node_features_batch, edges_batch, edge_features_batch, target_batch = [], [], [], []
         # Determine the max nodes to use based on whether it is training or testing
-        # Loop to generate a batch of data
         for _ in range(batch_size):
+            if fixed:
+                node_features, edges, edge_features_tensor, source, sink = sample_fixed_graph(self.num_features,
+                                                                                              num_nodes, feature_type=dataset)
+            else:
             # Handle Omniglot dataset
-            if dataset == 'omniglot':
-                node_features, edges, edge_features_tensor, source, sink = sample_omniglot_graph(num_nodes,seed)
-            # Handle Random dataset
-            elif dataset == 'random':
-                node_features, edges, edge_features_tensor, source, sink = sample_random_graph(self.num_features,
-                                                                                               num_nodes,seed)
-            elif dataset == 'positional':
-                node_features, edges, edge_features_tensor, source, sink = sample_random_graph_position(self.num_features,
-                                                                                                        num_nodes,seed)
+                if dataset == 'omniglot':
+                    node_features, edges, edge_features_tensor, source, sink = sample_omniglot_graph(num_nodes)
+                # Handle Random dataset
+                else:
+                    node_features, edges, edge_features_tensor, source, sink = sample_graph(self.num_features,
+                                                                                            num_nodes, feature_type=dataset)
+
            # Sample the target based on source and sink
            target = sample_target(source, sink)
@@ -226,7 +239,6 @@ def evaluate(self,outputs,targets):
         outputs = Softmax(outputs)
         predicted_labels = torch.argmax(outputs, dim=1)  # Outputs: [0, 1, 0, 1]
         # Initialize the BinaryAccuracy metric
-        # TODO : My accuracy is not working, the way it should be i think due to the labes aot
         acc = self.ACC(predicted_labels, labels)
         return acc
@@ -243,15 +255,28 @@ def log_training(self, train_loss, val_loss, train_acc, val_acc):

     def train(self):
         # Load validation data
-        node_features_val, edges_val, edge_features_tensor_val, target_val = self.load_data(train=False,
+        val_graphs = {num_node: [] for num_node in self.num_nodes_max_test}
+        for i in range(len(self.num_nodes_max_test)):
+            node_features_val, edges_val, edge_features_tensor_val, target_val = self.load_data(fixed=False,
                                                                                                 dataset=self.dataset,
-                                                                                                batch_size=self.batch_size, num_nodes= self.num_nodes_max_test[0],seed=self.seed)
+                                                                                                batch_size=self.batch_size, num_nodes=self.num_nodes_max_test[i])
+            val_graphs[self.num_nodes_max_test[i]] = [node_features_val, edges_val, edge_features_tensor_val, target_val]
+
+        # for i in len(self.num_nodes_max_test):
+        # This is an attempt
+        node_features_val_f, edges_val_f, edge_features_tensor_val_f, target_val_f = self.load_data(fixed=True,
+                                                                                                    dataset=self.dataset,
+                                                                                                    batch_size=self.batch_size,
+                                                                                                    num_nodes=self.num_nodes_max_test[0],
+                                                                                                    )
+        # need to save the fixed node_features
         for epoch in range(self.num_training_steps):
-            train_losses, train_roc_auc, train_mcc = 0, 0, 0
-            node_features, edges, edge_features_tensor, target = self.load_data(train=True,
+
+            node_features, edges, edge_features_tensor, target = self.load_data(fixed=False,
                                                                                 dataset=self.dataset,
-                                                                                batch_size=self.batch_size, num_nodes= self.num_nodes_max, seed=seed)
+                                                                                batch_size=self.batch_size, num_nodes=self.num_nodes_max)
             # Train on each batch
             batch_losses, batch_acc = self.update_step(node_features, edges, edge_features_tensor, target, train=True,
                                                        batch_size=self.batch_size)
@@ -266,17 +291,19 @@
             self.ACCs_train.append(train_acc)
-
+            # Validation
             with torch.no_grad():
-
-                val_losses, val_acc = self.update_step(node_features_val, edges_val,
+                for num_node in self.num_nodes_max_test:
+                    node_features_val, edges_val, edge_features_tensor_val, target_val = val_graphs[num_node]
+                    val_losses, val_acc = self.update_step(node_features_val, edges_val,
                                                            edge_features_tensor_val, target_val,
                                                            train=False, batch_size=self.batch_size)
-                self.losses_val.append(val_losses)
+                    self.losses_val[num_node].append(val_losses)
+                    self.ACCs_val[num_node].append(val_acc)
-                self.ACCs_val.append(val_acc)

             # Log training details
+            # TODO: Need to update this
             self.log_training(train_losses, val_losses, train_acc, val_acc)

             # Plot progress every epoch
@@ -288,37 +315,39 @@

         if self.plot:
-            os.mkdir(os.path.join(self.save_path, "results"))
+            os.makedirs(os.path.join(self.save_path, "results"), exist_ok=True)
             self.save_path = os.path.join(self.save_path, "results")
             file_name = f"Losses_{seed}.pdf"
             # Combine the path and file name
+            list_of_lists = [value for value in self.losses_val.values()]
+            list_of_list_name = [f'loss_val_len{value}' for value in self.losses_val]
+            # Append losses_train to the list of lists
+            list_of_lists.append(self.losses_train)
+            list_of_list_name.append('loss_train')
             plot_curves(
-                [
-                    self.losses_train,
-                    self.losses_val],
+                list_of_lists,
                 os.path.join(self.save_path, file_name),
                 "All_Losses",
-                legend_labels=["loss", "loss tesft"],
+                legend_labels=list_of_list_name,
             )

-            file_name = f"ACCs_train_{seed}.pdf"
-            plot_curves(
-                [
-                    self.ACCs_train,
-                ],
+
+            file_name = f"ACCs_val_{seed}.pdf"
+            plot_curves(
+                [value for value in self.ACCs_val.values()],
                 os.path.join(self.save_path, file_name),
-                "ACC Train",
-                legend_labels=["ACC Train"],
+                "ACC Val",
+                legend_labels=[f'ACC_val_len{value}' for value in self.losses_val],
             )

-            file_name = f"ACCs_val_{seed}.pdf"
+            file_name = f"ACCs_train_{seed}.pdf"
+            plot_curves(
                 [
-                    self.ACCs_val,
+                    self.ACCs_train,
                 ],
                 os.path.join(self.save_path, file_name),
-                "ACC val",
-                legend_labels=["ACC val"],
+                "ACC train",
+                legend_labels=["ACC train"],
             )
@@ -349,9 +378,10 @@ def reset(self):
         self.grad_history = []
         self.global_steps = 0
         self.losses_train = []
-        self.losses_val = []
+        self.losses_val = {num_node: [] for num_node in self.num_nodes_max_test}
+        self.ACCs_val = {num_node: [] for num_node in self.num_nodes_max_test}
         self.ACCs_train = []
-        self.ACCs_val = []
+
         return
@@ -367,12 +397,25 @@
     arena_x_limits = [-100, 100]
     arena_y_limits = [-100, 100]

-    seeds = [41, 42, 43, 45, 46 ]
+    seeds = [41, 42]
     losses_train = {seed: [] for seed in seeds}
     losses_val = {seed: [] for seed in seeds}
     ACCs_train = {seed: [] for seed in seeds}
     ACCs_val = {seed: [] for seed in seeds}
-
+    dateTimeObj = datetime.now()
+    save_path = os.path.join(Path(os.getcwd()).resolve(), "results")
+    os.mkdir(
+        os.path.join(
+            save_path,
+            config.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"),
+        )
+    )
+    save_path = os.path.join(
+        os.path.join(
+            save_path,
+            config.experiment_name + dateTimeObj.strftime("%d%b_%H_%M_%S"),
+        )
+    )
     for seed in seeds:
         agent = Domine2023(
             experiment_name=config.experiment_name,
@@ -398,26 +441,25 @@
             residual=config.residual,
             layer_norm=config.layer_norm,
             plot=config.plot,
+            save_path=save_path,
         )
+
         losse_train, ACC_train, losse_val, ACC_val = agent.train()
         losses_train[seed] = losse_train
         losses_val[seed] = losse_val
         ACCs_train[seed] = ACC_train
         ACCs_val[seed] = ACC_val

+    save_path = os.path.join(save_path, "results")
     num_training_steps = config.num_training_steps
-    import math
-
-    num_training_steps = config.num_training_steps
-
     # Initialize lists to store standard deviation results
     std_losses_train = []
     std_accs_train = []
-    std_losses_val = []
-    std_accs_val = []

+    # Compute average and standard deviation for training loss
+    avg_losses_train = []
     for epoch_idx in range(num_training_steps):
         # Average the loss for this epoch over all seeds
@@ -441,36 +483,53 @@
         std_epoch_acc = math.sqrt(variance_acc)
         std_accs_train.append(std_epoch_acc)

-    # Compute average and standard deviation for validation loss
-    avg_losses_val = []
-    for epoch_idx in range(num_training_steps):
-        avg_epoch_loss_val = sum(losses_val[seed][epoch_idx] for seed in seeds) / len(seeds)
-        avg_losses_val.append(avg_epoch_loss_val)
-
-        # Compute standard deviation for this epoch
-        variance_loss_val = sum((losses_val[seed][epoch_idx] - avg_epoch_loss_val) ** 2 for seed in seeds) / len(seeds)
-        std_epoch_loss_val = math.sqrt(variance_loss_val)
-        std_losses_val.append(std_epoch_loss_val)
-
-    # Compute average and standard deviation for validation accuracy
-    avg_accs_val = []
-    for epoch_idx in range(num_training_steps):
-        avg_epoch_acc_val = sum(ACCs_val[seed][epoch_idx] for seed in seeds) / len(seeds)
-        avg_accs_val.append(avg_epoch_acc_val)
-        # Compute standard deviation for this epoch
-        variance_acc_val = sum((ACCs_val[seed][epoch_idx] - avg_epoch_acc_val) ** 2 for seed in seeds) / len(seeds)
-        std_epoch_acc_val = math.sqrt(variance_acc_val)
-        std_accs_val.append(std_epoch_acc_val)
+    # Compute average and standard deviation for validation loss
+    avg_losses_val_len = []
+    std_losses_val_len = []
+    for i in config.num_nodes_max_test:
+        avg_losses_val = []
+        std_losses_val = []
+        for epoch_idx in range(num_training_steps):
+            avg_epoch_loss_val = sum(losses_val[seed][i][epoch_idx] for seed in seeds) / len(seeds)
+            avg_losses_val.append(avg_epoch_loss_val)
+            variance_loss_val = sum(
+                (losses_val[seed][i][epoch_idx] - avg_epoch_loss_val) ** 2 for seed in seeds) / len(seeds)
+            std_epoch_loss_val = math.sqrt(variance_loss_val)
+            std_losses_val.append(std_epoch_loss_val)
+        avg_losses_val_len.append(avg_losses_val)
+        std_losses_val_len.append(std_losses_val)
+
+    # Compute average and standard deviation for validation accuracy
+    avg_accs_val_len = []
+    std_accs_val_len = []
+    for i in config.num_nodes_max_test:
+        avg_accs_val = []
+        std_accs_val = []
+        for epoch_idx in range(num_training_steps):
+            avg_epoch_acc_val = sum(ACCs_val[seed][i][epoch_idx] for seed in seeds) / len(seeds)
+            avg_accs_val.append(avg_epoch_acc_val)
+
+            # Compute standard deviation for this epoch
+            variance_acc_val = sum((ACCs_val[seed][i][epoch_idx] - avg_epoch_acc_val) ** 2 for seed in seeds) / len(seeds)
+            std_epoch_acc_val = math.sqrt(variance_acc_val)
+            std_accs_val.append(std_epoch_acc_val)
+        avg_accs_val_len.append(avg_accs_val)
+        std_accs_val_len.append(std_accs_val)
+
+    list_of_list_name = [f'loss_val_len{value}' for value in losses_val[seed]]
+    # Append losses_train to the list of lists
+    avg_losses_val_len.append(avg_losses_train)
+    list_of_list_name.append('loss_train')
+    std_losses_val_len.append(std_losses_train)

-    save_path = os.path.join(Path(os.getcwd()).resolve(), "results")
     plot_curves_2(
-        [
-            avg_losses_train,
-            avg_losses_val],[std_accs_train,std_accs_val],
+        avg_losses_val_len, std_losses_val_len,
         os.path.join(save_path, "Losses.pdf"),
         "All_Losses",
-        legend_labels=["loss", "loss tesft"],
+        legend_labels=list_of_list_name,
     )
     plot_curves_2(
         [
@@ -481,19 +540,18 @@
         legend_labels=["ACC Train"],
     )

+    list_of_list_name = [f'ACC_val_len{value}' for value in losses_val[seed]]
     plot_curves_2(
-        [
-            avg_accs_val,
-        ],[std_accs_val], os.path.join(save_path, "ACCs_val.pdf"),
+        avg_accs_val_len, std_accs_val_len, os.path.join(save_path, "ACCs_val.pdf"),
         "ACC val",
-        legend_labels=["ACC val"],
+        legend_labels=list_of_list_name,
     )
     print()

-    #TODO: They all have different evaluation ( netwokr ) do we want ot eva
-    #TODO: How to actually get the seed to be fixed for each run
+    #TODO: They all have different evaluation (network) - do we want to eval? (for the average it should be fine)
     #TODO: Think about nice visualisation
-    # I need to check the saving and the logging of the results
+    #TODO: Update the plotting for the other curves
+    # I need to check the logging of the results
     # TODO: plan a set of experiments to run, study how different initialisations
     #TODO: Get a different set of validation lengths for each run

diff --git a/neuralplayground/agents/domine_2023_extras_2/config.yaml b/neuralplayground/agents/domine_2023_extras_2/config.yaml
index 8dc7ebb..21a38fc 100644
--- a/neuralplayground/agents/domine_2023_extras_2/config.yaml
+++ b/neuralplayground/agents/domine_2023_extras_2/config.yaml
@@ -3,14 +3,14 @@
 resample: False # @param
 wandb_on: False
 seed: 45
 plot: True
-dataset: 'random' # 'random' or 'omniglot' or position
+dataset: 'positional' # 'random', 'omniglot', 'positional' or 'positional_no_edges'
 weighted: False

 num_hidden: 15 # @param
 num_layers: 1 # @param
-num_message_passing_steps: 4 # @param
-learning_rate: 0.001 # @param
-num_training_steps: 3000 # @param
+num_message_passing_steps: 0 # @param
+learning_rate: 0.005 # @param
+num_training_steps: 80 # @param
 residual: True
 layer_norm: False
@@ -21,5 +21,5 @@
 num_nodes_min: 5
 num_features: 1
 batch_size_test: 2
-num_nodes_max_test: [40]
+num_nodes_max_test: [6, 7, 8, 9]  # min 2
 num_nodes_min_test: 10
\ No newline at end of file
diff --git a/neuralplayground/agents/domine_2023_extras_2/models/GCN_model.py b/neuralplayground/agents/domine_2023_extras_2/models/GCN_model.py
index 83b49ca..4418d32 100644
--- a/neuralplayground/agents/domine_2023_extras_2/models/GCN_model.py
+++ b/neuralplayground/agents/domine_2023_extras_2/models/GCN_model.py
@@ -23,7 +23,7 @@ def __init__(self, num_hidden, num_feature, num_layers, num_message_passing_step
     def forward(self, node, edges, edges_attr):
         x, edge_index = node, edges
         x = self.conv_1(x, edge_index, edges_attr)
-
+        x = torch.relu(x)
         for i, conv in enumerate(self.conv_layers):
             x_res = x
             x = conv(x, edge_index, edges_attr)
diff --git a/neuralplayground/agents/domine_2023_extras_2/processing/Graph_generation.py b/neuralplayground/agents/domine_2023_extras_2/processing/Graph_generation.py
index d887bd8..0dd2fc4 100644
--- a/neuralplayground/agents/domine_2023_extras_2/processing/Graph_generation.py
+++ b/neuralplayground/agents/domine_2023_extras_2/processing/Graph_generation.py
@@ -5,7 +5,7 @@
 from torchvision import transforms
 from neuralplayground.agents.domine_2023_extras.class_utils import rng_sequence_from_rng

-def create_random_matrix(rows, cols,seed, low=0, high=1):
+def create_random_matrix(rows, cols, low=0, high=1):
     """
     Generates a random matrix with the specified dimensions.
     Parameters:
@@ -16,6 +16,7 @@
     Returns:
         numpy.ndarray: A matrix of shape (rows, cols) with random values.
""" + return np.random.uniform(low, high, (rows, cols)) def get_omniglot_items(n): @@ -114,23 +115,11 @@ def generate_source_and_sink(num_nodes): sink = np.random.randint(0, num_nodes) return source, sink -def sample_random_graph(num_features, num_nodes,seed): - # This is a graph with edges feature and random features - node_features = torch.tensor(create_random_matrix(num_nodes,num_features,seed)) - edges , edge_features_tensor = create_line_graph_edge_list_with_features(num_nodes) - input_node_features = np.zeros((int(num_nodes), 2)) - sink, source = generate_source_and_sink(num_nodes) - input_node_features[source, 0] = 1 # Set source node feature - input_node_features[sink, 1] = 1 # Set sink node feature - # Concatenate the feature matrices along the feature dimension (axis=1) - combined_node_features = np.concatenate([node_features, input_node_features], axis=1) - # Convert combined node features back to a tensor - node_features = torch.tensor(combined_node_features, dtype=torch.float32) - return node_features, edges, edge_features_tensor, source, sink + #TODO: we need to merge this two potentially -def sample_omniglot_graph(num_nodes,seed): +def sample_omniglot_graph(num_nodes): # This is a graph with edges feature and omniglot features node_features = torch.tensor(get_omniglot_items(num_nodes)) edges , edge_features_tensor = create_line_graph_edge_list_with_features(num_nodes) @@ -144,34 +133,74 @@ def sample_omniglot_graph(num_nodes,seed): node_features = torch.tensor(combined_node_features, dtype=torch.float32) return node_features, edges, edge_features_tensor, source, sink -def sample_random_graph_position(num_features, num_nodes,seed): - # This is a graph with edges feature and position features - node_features = torch.tensor(create_random_matrix(num_nodes,num_features)) - edges , edge_features_tensor = create_line_graph_edge_list_with_features(num_nodes) - input_node_features = np.zeros((int(num_nodes), 2)) + +def sample_graph(num_features, num_nodes, feature_type='random'): + """ + Generate a sample graph with different feature types: 'random', 'positional', or 'positional_no_edges'. + + Parameters: + - num_features: Number of features for each node. + - num_nodes: Number of nodes in the graph. + - feature_type: Type of features to include ('random', 'positional', or 'positional_no_edges'). + + Returns: + - node_features: Tensor of node features. + - edges: Edge list tensor. + - edge_features_tensor (optional): Tensor of edge features if feature_type is 'random' or 'positional'. + - source: Source node. + - sink: Sink node. 
+ """ + # Generate base node features and input node features + node_features = torch.tensor(create_random_matrix(num_nodes, num_features)) + edges, edge_features_tensor = create_line_graph_edge_list_with_features(num_nodes) + input_node_features = np.zeros((num_nodes, 2)) sink, source = generate_source_and_sink(num_nodes) input_node_features[source, 0] = 1 # Set source node feature input_node_features[sink, 1] = 1 # Set sink node feature - # Concatenate the feature matrices along the feature dimension (axis=1) + + # Combine node features and input features combined_node_features = np.concatenate([node_features, input_node_features], axis=1) - position = torch.tensor([np.arange(0, num_nodes)]) - combined_node_features_pos = np.concatenate([combined_node_features, position.T], axis=1) - # Convert combined node features back to a tensor - node_features = torch.tensor(combined_node_features_pos , dtype=torch.float32) - return node_features, edges, edge_features_tensor, source, sink -def sample_random_graph_position_no_edges(num_features, num_nodes,seed): - # This is a graph with no edges feature but position features - node_features = torch.tensor(create_random_matrix(num_nodes,num_features)) - edges , edge_features_tensor = create_line_graph_edge_list_with_features(num_nodes) - input_node_features = np.zeros((int(num_nodes), 2)) - sink, source = generate_source_and_sink(num_nodes) + # Append position features if specified + if feature_type == 'positional' or feature_type == 'positional_no_edges': + position = torch.tensor(np.arange(num_nodes)).unsqueeze(1) # Shape: (num_nodes, 1) + combined_node_features = np.concatenate([combined_node_features, position], axis=1) + + # Convert combined node features to a tensor + node_features = torch.tensor(combined_node_features, dtype=torch.float32) + + # Return based on feature_type + if feature_type == 'positional_no_edges': + return node_features, edges, source, sink # No edge features + else: + return node_features, edges, edge_features_tensor, source, sink + +def sample_fixed_graph(num_features, num_nodes, feature_type='random', sositype='random'): + #Generate base node features and input node features + node_features = torch.tensor([[0.54657073, 0.96430735, 0.06389329, 0.38357556, 0.96802482, + 0.12043292]]) + edges, edge_features_tensor = create_line_graph_edge_list_with_features(num_nodes) + input_node_features = np.zeros((num_nodes, 2)) + sink = 1 + source = 2 input_node_features[source, 0] = 1 # Set source node feature input_node_features[sink, 1] = 1 # Set sink node feature - # Concatenate the feature matrices along the feature dimension (axis=1) - combined_node_features = np.concatenate([node_features, input_node_features], axis=1) - # Convert combined node features back to a tensor + + # Combine node features and input features + combined_node_features = np.concatenate([node_features.T, input_node_features], axis=1) + # Append position features if specified + if feature_type == 'positional' or feature_type == 'positional_no_edges': + position = torch.tensor(np.arange(num_nodes)).unsqueeze(1) # Shape: (num_nodes, 1) + combined_node_features = np.concatenate([combined_node_features, position], axis=1) + + # Convert combined node features to a tensor node_features = torch.tensor(combined_node_features, dtype=torch.float32) - return node_features, edges, source, sink + + # Return based on feature_type + if feature_type == 'positional_no_edges': + return node_features, edges, source, sink # No edge features + else: + return node_features, edges, 
+        return node_features, edges, edge_features_tensor, source, sink
+    # TODO: we need to merge this into one function because this is ugly
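
Note on the merge TODO above: one possible shape for unifying sample_graph and sample_fixed_graph. This is a sketch only, assuming the existing helpers (create_random_matrix, create_line_graph_edge_list_with_features, generate_source_and_sink) keep their current signatures; the fixed_features and fixed_source_sink parameters are hypothetical names, not part of this diff:

    def sample_graph(num_features, num_nodes, feature_type='random',
                     fixed_features=None, fixed_source_sink=None):
        # Base node features: a caller-supplied fixed matrix, or a fresh random one.
        if fixed_features is not None:
            base = np.asarray(fixed_features).reshape(num_nodes, num_features)
        else:
            base = create_random_matrix(num_nodes, num_features)
        edges, edge_features_tensor = create_line_graph_edge_list_with_features(num_nodes)

        # Source/sink: a fixed pair for the frozen validation graph, otherwise sampled.
        if fixed_source_sink is not None:
            source, sink = fixed_source_sink
        else:
            sink, source = generate_source_and_sink(num_nodes)
        indicators = np.zeros((num_nodes, 2))
        indicators[source, 0] = 1  # source indicator column
        indicators[sink, 1] = 1  # sink indicator column
        combined = np.concatenate([base, indicators], axis=1)

        # Optional positional feature column.
        if feature_type in ('positional', 'positional_no_edges'):
            combined = np.concatenate([combined, np.arange(num_nodes).reshape(-1, 1)], axis=1)
        node_features = torch.tensor(combined, dtype=torch.float32)

        if feature_type == 'positional_no_edges':
            return node_features, edges, source, sink  # no edge features
        return node_features, edges, edge_features_tensor, source, sink

The fixed validation graph in train() would then be a call like sample_graph(..., fixed_features=..., fixed_source_sink=(2, 1)) instead of a separate sample_fixed_graph.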
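Side note on the averaging block in domine_2023_2.py: the hand-rolled per-epoch mean and standard deviation over seeds are the population statistics (division by len(seeds)). Since numpy is now imported, the same curves could be computed more compactly; a sketch, assuming each per-seed curve has num_training_steps entries:

    import numpy as np

    def mean_std_over_seeds(per_seed_curves):
        # Stack into shape (num_seeds, num_epochs) and reduce over the seed axis.
        curves = np.stack([np.asarray(c, dtype=float) for c in per_seed_curves])
        return curves.mean(axis=0), curves.std(axis=0)  # ddof=0 matches the manual loop

    # e.g. for one validation length i:
    # avg_losses_val, std_losses_val = mean_std_over_seeds(
    #     [losses_val[seed][i] for seed in seeds])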