Skip to content

Commit

Permalink
task
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementineDomine committed Oct 15, 2024
1 parent a664702 commit 2085fe7
Show file tree
Hide file tree
Showing 7 changed files with 305 additions and 395 deletions.
131 changes: 71 additions & 60 deletions neuralplayground/agents/domine_2023_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
import torch.optim as optim
import wandb
import numpy as np
from sklearn.metrics import matthews_corrcoef, roc_auc_score
from neuralplayground.agents.agent_core import AgentCore
from neuralplayground.agents.domine_2023_extras_2.utils.plotting_utils import plot_curves
from neuralplayground.agents.domine_2023_extras_2.utils.plotting_utils import plot_curves, plot_2dgraphs
from neuralplayground.agents.domine_2023_extras_2.models.GCN_model import GCNModel
from neuralplayground.agents.domine_2023_extras_2.class_grid_run_config import GridConfig
from neuralplayground.agents.domine_2023_extras_2.utils.utils import set_device
from neuralplayground.agents.domine_2023_extras_2.processing.Graph_generation import sample_graph
from neuralplayground.agents.domine_2023_extras_2.processing.Graph_generation import sample_graph, sample_target
from torchmetrics import Accuracy, Precision, AUROC, Recall, MatthewsCorrCoef

# from neuralplayground.agents.domine_2023_extras_2.evaluate import Evaluator
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"

Expand All @@ -26,8 +26,8 @@ class Domine2023(AgentCore):
def __init__(self, experiment_name="smaller size generalisation graph with no position feature",
train_on_shortest_path=True, resample=True, wandb_on=False, seed=41, feature_position=False,
weighted=True, num_hidden=100, num_layers=2, num_message_passing_steps=3, learning_rate=0.001,
num_training_steps=10, residual=True, layer_norm=True, batch_size=4, nx_min=4, nx_max=7,
batch_size_test=4, nx_min_test=4, nx_max_test=7, plot=True, **mod_kwargs):
num_training_steps=10, residual=True, layer_norm=True, batch_size=4, num_features=4, num_nodes_max=7,
batch_size_test=4, num_nodes_min_test=4, num_nodes_max_test=7, plot=True, **mod_kwargs):
super(Domine2023, self).__init__()

# General
Expand All @@ -53,19 +53,19 @@ def __init__(self, experiment_name="smaller size generalisation graph with no po
#self.resample = resample
self.feature_position = feature_position
self.weighted = weighted
self.nx_min = nx_min
self.nx_max = nx_max
self.num_features = num_features
self.num_nodes_max = num_nodes_max
self.num_nodes_max_test = num_nodes_max_test

self.batch_size_test = batch_size_test
self.nx_min_test = nx_min_test
self.nx_max_test = nx_max_test
self.arena_x_limits = mod_kwargs["arena_x_limits"]
self.arena_y_limits = mod_kwargs["arena_y_limits"]

self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.model = GCNModel(self.num_hidden, self.num_layers, self.num_message_passing_steps, self.residual,
self.model = GCNModel(self.num_hidden, self.num_features+2, self.num_layers, self.num_message_passing_steps, self.residual,
self.layer_norm).to(self.device)
self.auroc = AUROC(task="binary")
self.MCC = MatthewsCorrCoef(task='binary')

self.optimizer = optim.Adam(self.model.parameters(), lr=self.learning_rate)
self.criterion = nn.MSELoss()

Expand All @@ -79,12 +79,9 @@ def __init__(self, experiment_name="smaller size generalisation graph with no po
self.save_path = os.path.join(save_path, "results")

self.reset()
self.wandb_logs = {
"nx_min_test": self.nx_min_test, # This is thought of the state density
"nx_max_test": self.nx_max_test, # This is thought of the state density
self.wandb_logs = { # This is thought of the state density
"batch_size": self.batch_size,
"nx_min": self.nx_min, # This is thought of the state density
"nx_max": self.nx_max,
"num_node_min": self.num_nodes_max, # This is thought of the state density
"seed": self.seed,
"feature_position": self.feature_position,
"weighted": self.weighted,
Expand All @@ -96,7 +93,6 @@ def __init__(self, experiment_name="smaller size generalisation graph with no po
"residual": self.residual,
"layer_norm": self.layer_norm,
}

if self.wandb_on:
wandb.log(self.wandb_logs)
else:
Expand Down Expand Up @@ -128,45 +124,45 @@ def save_run_parameters(self):
for file_name, source in files_to_copy:
shutil.copyfile(os.path.join(Path(os.getcwd()).resolve(), source), os.path.join(self.save_path, file_name))


def load_data(self,train):
if train:
node, adj = sample_graph(train=True)
node_features, edges, edge_features_tensor, source, sink = sample_graph(self.num_features, self.num_nodes_max)
target = sample_target(source, sink)
else:
node, adj = sample_graph(train=False)
return node, adj

node_features, edges, edge_features_tensor, source, sink = sample_graph(self.num_features, self.num_nodes_max_test)
target = sample_target(source, sink)
return node_features, edges, edge_features_tensor, target

def compute_loss(self, outputs, targets):
loss = self.criterion(outputs, targets)
return loss

def run_model(self, node, edges):
outputs = self.model(node,edges)
def run_model(self, node, edges,edges_features):
outputs = self.model(node,edges,edges_features)
return outputs

def update_step(self,node, edges,target,train):
def update_step(self,node, edges,edges_features,target,train):
data = node.to(self.device)
edges = edges.to(self.device)
edges_features = edges_features.to(self.device)
if train:
self.model.train()
self.optimizer.zero_grad()
else:
self.model.eval()
outputs = self.run_model(data,edges)
outputs = self.run_model(data,edges, edges_features)
loss = self.compute_loss(outputs,target)
if train:
loss.backward()
self.optimizer.step()
roc_auc, mcc = self.evaluate(outputs,target)
return loss,roc_auc, mcc
return loss, roc_auc, mcc

def evaluate(self,outputs,targets):
with (torch.no_grad()):
roc_auc = self.auroc(outputs.to(self.device), targets.to(self.device))
# roc_auc_score(targets.cpu(), outputs.cpu())
# mcc = MatthewsCorrCoef(outputs.cpu().round(), targets.cpu().round())
mcc = 1
mcc = self.MCC(outputs, targets)
return roc_auc, mcc

def log_training(self, train_loss, val_loss, train_roc_auc, val_roc_auc, train_mcc, val_mcc):
Expand All @@ -183,17 +179,17 @@ def log_training(self, train_loss, val_loss, train_roc_auc, val_roc_auc, train_m
wandb.log(wandb_logs)

def train(self):
node_features_val, edges_val, edge_features_tensor_val, target_val = self.load_data(train=False)
for epoch in range(self.num_training_steps):

nodes, edges = self.load_data(train=True)
target = nodes
train_losses, train_roc_auc, train_mcc = self.update_step(nodes, edges, target ,train=True)
#node_features, edges, edge_features_tensor, target = self.load_data(train=True)
node_features, edges, edge_features_tensor, target = self.load_data(train=True)
train_losses, train_roc_auc, train_mcc = self.update_step(node_features, edges, edge_features_tensor, target ,train=True)
self.losses_train.append(train_losses.detach().numpy() )
self.MCCs_train.append(train_mcc)
self.roc_aucs_train.append(train_roc_auc.detach().numpy() )
nodes_val, edges_val = self.load_data(train=False)
#node_features_val, edges_val, edge_features_tensor_val, target_val = self.load_data(train=False)
with torch.no_grad():
val_losses, val_roc_auc, val_mcc = self.update_step(nodes_val,edges_val,target, train=False)
val_losses, val_roc_auc, val_mcc = self.update_step(node_features_val,edges_val,edge_features_tensor_val,target_val, train=False)
self.losses_val.append(val_losses.detach().numpy() )
self.MCCs_val.append(val_mcc)
self.roc_aucs_val.append(val_roc_auc.detach().numpy() )
Expand All @@ -204,49 +200,67 @@ def train(self):
f"Training step {self.global_steps}: log_loss = {np.log(train_losses.detach().numpy() )} , log_loss_test = {np.log(val_losses.detach().numpy() )}, roc_auc_test = {val_roc_auc}, roc_auc_train = {train_roc_auc}"
)
print("Finished training")
train_losses, train_roc_auc, train_mcc = self.update_step(node_features, edges, edge_features_tensor, target,
train=True)
if self.plot:
os.mkdir(os.path.join(self.save_path, "results"))
self.save_path = os.path.join(self.save_path, "results")
plot_curves(
[
self.losses_train,
self.losses_val],
os.path.join(self.save_path, "Losses.pdf"),
"All_Losses",
legend_labels=["loss", "loss test"],
legend_labels=["loss", "loss tesft"],
)
plot_curves(
[
self.MCCs_train,
],
os.path.join(self.save_path, "MCCs_train.pdf"),
"All_Losses",
"MCC Train",
legend_labels=["MCC Train"],
)
plot_curves(
[
self.roc_aucs_train,
],
os.path.join(self.save_path, "MCCs_train.pdf"),
"All_Losses",
legend_labels=["MCC Train"],
os.path.join(self.save_path, "AUROC_train.pdf"),
"AUROC Train",
legend_labels=["AUROC Train"],
)
plot_curves(
[
self.MCCs_val,
],
os.path.join(self.save_path, "MCCs_train.pdf"),
"All_Losses",
legend_labels=["MCC Train"],
os.path.join(self.save_path, "MCCs_val.pdf"),
"MCC val",
legend_labels=["MCC val"],
)
plot_curves(
[
self.roc_aucs_train,
],
os.path.join(self.save_path, "MCCs_train.pdf"),
"All_Losses",
legend_labels=["MCC Train"],
)
return

def sample_and_store(n):
# Initialize empty lists to store each sample's output
node_features_list = []
edges_list = []
edge_features_tensor_list = []
target_list = []
# Loop n times to sample data and store the outputs
for _ in range(n):
# Sample data by calling load_data
node_features, edges, edge_features_tensor, target = self.load_data(train=True)

# Append the results to the corresponding lists
node_features_list.append(node_features)
edges_list.append(edges)
edge_features_tensor_list.append(edge_features_tensor)
target_list.append(target)
# Return the lists of sampled data
return node_features_list, edges_list, edge_features_tensor_list, target_list

n=2
node_features_list, edges_list, edge_features_tensor_list, target_list = sample_and_store(n)
plot_2dgraphs(edges_list, node_features_list, edge_features_tensor_list,['',''], os.path.join(self.save_path, "graph.pdf"), colorscale='Plasma',size=5,show=True)
return
def reset(self):
self.obs_history = []
self.grad_history = []
Expand Down Expand Up @@ -274,7 +288,6 @@ def reset(self):

agent = Domine2023(
experiment_name=config.experiment_name,
train_on_shortest_path=config.train_on_shortest_path,
resample=config.resample,
wandb_on=config.wandb_on,
seed=config.seed,
Expand All @@ -286,19 +299,17 @@ def reset(self):
learning_rate=config.learning_rate,
num_training_steps=config.num_training_steps,
batch_size=config.batch_size,
nx_min=config.nx_min,
nx_max=config.nx_max,
num_features=config.num_features,
num_nodes_max=config.num_nodes_max,
num_nodes_min=config.num_nodes_min,
batch_size_test=config.batch_size_test,
nx_min_test=config.nx_min_test,
nx_max_test=config.nx_max_test,
num_nodes_min_test=config.num_nodes_min_test,
num_nodes_max_test=config.num_nodes_max_test,
arena_y_limits=arena_y_limits,
arena_x_limits=arena_x_limits,
residual=config.residual,
layer_norm=config.layer_norm,
grid=config.grid,
plot=config.plot,
dist_cutoff=config.dist_cutoff,
n_std_dist_cutoff=config.n_std_dist_cutoff,
)

agent.train()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ layer_norm: True

# Env Stuff
batch_size: 2
nx_min: 5
nx_max: 6
num_nodes_max: 5
num_nodes_min: 5
num_features: 6

batch_size_test: 4
nx_min_test: 5
nx_max_test: 6
num_nodes_max_test: 5
num_nodes_min_test: 5
Loading

0 comments on commit 2085fe7

Please sign in to comment.