Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementineDomine committed Oct 1, 2023
1 parent ec531b4 commit 058271c
Show file tree
Hide file tree
Showing 152 changed files with 1,270 additions and 1,250 deletions.
124 changes: 62 additions & 62 deletions neuralplayground/agents/class.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@
from class_plotting_utils import (
plot_graph_grid_activations,
plot_input_target_output,
plot_xy,
plot_message_passing_layers,

plot_xy,
)
from sklearn.metrics import matthews_corrcoef, roc_auc_score
from class_utils import rng_sequence_from_rng, set_device
from sklearn.metrics import matthews_corrcoef, roc_auc_score

# @title Graph net functions
parser = argparse.ArgumentParser()
Expand All @@ -35,29 +34,29 @@
help="path to base configuration file.",
)

class Stachenfel2023(AgentCore):

class Stachenfel2023(AgentCore):
def __init__(self, config_path, config, **mod_kwargs):
self.train_on_shortest_path = config.train_on_shortest_path
# @param
super().__init__()
self.experiment_name =config.experiment_name
self.train_on_shortest_path =config.train_on_shortest_path
self.resample = config.resample # @param
self.wandb_on =config.wandb_on
self.seed =config.seed
self.experiment_name = config.experiment_name
self.train_on_shortest_path = config.train_on_shortest_path
self.resample = config.resample # @param
self.wandb_on = config.wandb_on
self.seed = config.seed

self.feature_position = config.feature_position
self.weighted = config.weighted

self.num_hidden = config.num_hidden# @param
self.num_layers = config.num_layers# @param
self.num_message_passing_steps = config.num_training_steps # @param
self.learning_rate = config.learning_rate # @param
self.num_training_steps = config.num_training_steps # @param
self.num_hidden = config.num_hidden # @param
self.num_layers = config.num_layers # @param
self.num_message_passing_steps = config.num_training_steps # @param
self.learning_rate = config.learning_rate # @param
self.num_training_steps = config.num_training_steps # @param

self.batch_size = config.batch_size
self.nx_min= config.nx_min
self.nx_min = config.nx_min
self.nx_max = config.nx_max
self.arena_x_limits = mod_kwargs["arena_x_limits"]
self.arena_y_limits = mod_kwargs["arena_y_limits"]
Expand All @@ -66,29 +65,29 @@ def __init__(self, config_path, config, **mod_kwargs):

# This can be tought of the brain making different rep of different granularity
# Could be explained during sleep
self.batch_size_test= config.batch_size_test
self.nx_min_test = config.nx_min_test #This is thought of the state density
self.nx_max_test = config.nx_max_test #This is thought of the state density
self.batch_size_test = config.batch_size_test
self.nx_min_test = config.nx_min_test # This is thought of the state density
self.nx_max_test = config.nx_max_test # This is thought of the state density
self.batch_size = config.batch_size
self.nx_min = config.nx_min #This is thought of the state density
self.nx_max = config.nx_max #This is thought of the state density
self.nx_min = config.nx_min # This is thought of the state density
self.nx_max = config.nx_max # This is thought of the state density

#TODO: Make sure that for different graph this changes with the environement
#self.ny_min_test = config.ny_min_test # This is thought of the state density
#self.ny_max_test = config.ny_max_test # This is thought of the state density
#self.ny_min = con
# TODO: Make sure that for different graph this changes with the environement
# self.ny_min_test = config.ny_min_test # This is thought of the state density
# self.ny_max_test = config.ny_max_test # This is thought of the state density
# self.ny_min = con
# fig.ny_min # This is thought of the state density
#self.ny_max = config.ny_max # This is thought of the state density
# self.ny_max = config.ny_max # This is thought of the state density

#self.resolution_x_min_test = int(self.nx_min * self.room_width)
#self.resolution_x_max_test = int(self.nx_max * self.room_depth)
#self.resolution_x_min = int(self.nx_min_test * self.room_width)
#self.resolution_x_max = int(self.nx_max_test * self.room_depth)
# self.resolution_x_min_test = int(self.nx_min * self.room_width)
# self.resolution_x_max_test = int(self.nx_max * self.room_depth)
# self.resolution_x_min = int(self.nx_min_test * self.room_width)
# self.resolution_x_max = int(self.nx_max_test * self.room_depth)

#self.resolution_y_min_test = int(self.nx_min * self.room_width)
#self.resolution_y_max_test = int(self.nx_max * self.room_depth)
#self.resolution_y_min = int(self.nx_min_test * self.room_width)
#self.resolution_y_max = int(self.nx_max_test * self.room_depth)
# self.resolution_y_min_test = int(self.nx_min * self.room_width)
# self.resolution_y_max_test = int(self.nx_max * self.room_depth)
# self.resolution_y_min = int(self.nx_min_test * self.room_width)
# self.resolution_y_max = int(self.nx_max_test * self.room_depth)

self.log_every = config.num_training_steps // 10

Expand All @@ -98,8 +97,9 @@ def __init__(self, config_path, config, **mod_kwargs):
self.edege_lables = False
if self.wandb_on:
dateTimeObj = datetime.now()
wandb.init(project="graph-brain", entity="graph-brain",
name="Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M"))
wandb.init(
project="graph-brain", entity="graph-brain", name="Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M")
)
self.wandb_logs = {}
save_path = wandb.run.dir
os.mkdir(os.path.join(save_path, "results"))
Expand Down Expand Up @@ -145,37 +145,33 @@ def __init__(self, config_path, config, **mod_kwargs):
graph, targets = sample_padded_grid_batch_shortest_path(
self.rng, self.batch_size, self.feature_position, self.weighted, self.nx_min, self.nx_max
)
targets = graph.nodes
self.params = self.net_hk.init(self.rng, graph)
self.optimizer = optax.adam(self.learning_rate)
self.opt_state = self.optimizer.init(self.params)

self.reset()

def reset(self):
#TODO: Actually reset the network
self.global_test=0
# TODO: Actually reset the network
self.global_test = 0
self.losses = []
self.losses_test = []
self.roc_aucs_train = []
self.MCCs_train = []
self.MCCs_test = []
self.roc_aucs_test = []


def compute_loss(self, params, model, inputs, targets):
# not jitted because it will get jitted in jax.value_and_grad
outputs = model.apply(params, inputs)
return jnp.mean((outputs[0].nodes - targets) ** 2) # using MSE


def update_step(self,grads, opt_state, params,optimizer):
def update_step(self, grads, opt_state, params, optimizer):
updates, opt_state = optimizer.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
return params


def evaluate(self,model, params, inputs, target):
def evaluate(self, model, params, inputs, target):
outputs = model.apply(params, inputs)
roc_auc = roc_auc_score(np.squeeze(target), np.squeeze(outputs[0].nodes))
MCC = matthews_corrcoef(np.squeeze(target), round(np.squeeze(outputs[0].nodes)))
Expand All @@ -200,7 +196,9 @@ def update(self):
)
targets = graph.nodes
# Train
loss, grads = jax.value_and_grad(self.compute_loss)(self.params, self.net_hk, graph, targets) # jits inside of value_and_grad
loss, grads = jax.value_and_grad(self.compute_loss)(
self.params, self.net_hk, graph, targets
) # jits inside of value_and_grad
self.params = self.update_step(grads, self.opt_state, self.params, self.optimizer)
self.losses.append(loss)
outputs_train, roc_auc_train, MCC_train = self.evaluate(self.net_hk, self.params, graph, targets)
Expand All @@ -217,7 +215,7 @@ def update(self):
wandb_logs = {"loss": loss, "losses_test": loss_test, "roc_auc_test": roc_auc_test, "roc_auc": roc_auc_train}
if self.wandb_on:
wandb.log(wandb_logs)
self.global_steps = self.global_steps+1
self.global_steps = self.global_steps + 1
if self.global_steps % self.log_every == 0:
print(f"Training step {n}: loss = {loss}")
return
Expand All @@ -226,8 +224,7 @@ def print_and_plot(self):
# EVALUATE
rng = next(self.rng_seq)
graph_test, target_test = sample_padded_grid_batch_shortest_path(
rng, self.batch_size_test, self.feature_position, self.weighted, self.nx_min_test,
self.nx_max_test
rng, self.batch_size_test, self.feature_position, self.weighted, self.nx_min_test, self.nx_max_test
)
outputs, roc_auc, MCC = self.evaluate(self.net_hk, self.params, graph_test, target_test)
print("roc_auc_score")
Expand Down Expand Up @@ -292,7 +289,9 @@ def print_and_plot(self):
"Inputs node assigments",
self.edege_lables,
)
plot_graph_grid_activations(target_test.sum(-1), graph_test, os.path.join(self.save_path, "Target.pdf"), "Target", self.edege_lables)
plot_graph_grid_activations(
target_test.sum(-1), graph_test, os.path.join(self.save_path, "Target.pdf"), "Target", self.edege_lables
)

plot_graph_grid_activations(
outputs[0].nodes.tolist(),
Expand All @@ -317,7 +316,7 @@ def print_and_plot(self):


if __name__ == "__main__":
from neuralplayground.arenas import BatchEnvironment, Simple2D, DiscreteObjectEnvironment
from neuralplayground.arenas import Simple2D

args = parser.parse_args()
set_device()
Expand All @@ -329,23 +328,24 @@ def print_and_plot(self):
# Init environment
arena_x_limits = [-100, 100]
arena_y_limits = [-100, 100]
env = Simple2D(time_step_size=time_step_size,
agent_step_size=agent_step_size,
arena_x_limits= arena_x_limits,
arena_y_limits=arena_y_limits)

agent= Stachenfel2023(config_path=args.config_path, config=config,arena_y_limits=arena_y_limits,arena_x_limits=arena_x_limits)
env = Simple2D(
time_step_size=time_step_size,
agent_step_size=agent_step_size,
arena_x_limits=arena_x_limits,
arena_y_limits=arena_y_limits,
)

agent = Stachenfel2023(
config_path=args.config_path, config=config, arena_y_limits=arena_y_limits, arena_x_limits=arena_x_limits
)
for n in range(config.num_training_steps):
agent.update()

agent.print_and_plot()


#TODO: Make it work with the config Run manadger
# TODO: Make it work with the config Run manadger
# The other alternative is to see that we have multiple env that we resample every time
#TODO: Make juste an env type (so that is accomodates for not only 2 d env// different transmats)
#TODO: Make The plotting in the general plotting utilse
#TODO: Solve the jitt issue



# TODO: Make juste an env type (so that is accomodates for not only 2 d env// different transmats)
# TODO: Make The plotting in the general plotting utilse
# TODO: Solve the jitt issue
1 change: 0 additions & 1 deletion neuralplayground/agents/class_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,3 @@ nx_max: 7
batch_size_test: 4
nx_min_test: 4
nx_max_test: 7

2 changes: 1 addition & 1 deletion neuralplayground/agents/class_grid_run_config.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Dict, Union

from config_manager import base_configuration
from class_config_template import ConfigTemplate
from config_manager import base_configuration


class GridConfig(base_configuration.BaseConfiguration):
Expand Down
3 changes: 1 addition & 2 deletions neuralplayground/agents/class_plotting_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# @title Make rng sequence generator
import matplotlib.pyplot as plt
import networkx as nx
from class_utils import convert_jraph_to_networkx_graph,get_activations_graph_n,get_node_pad
from class_utils import convert_jraph_to_networkx_graph, get_activations_graph_n, get_node_pad


def plot_input_target_output(inputs, targets, outputs, graph, n, edege_lables, save_path):
Expand Down Expand Up @@ -140,7 +140,6 @@ def plot_message_passing_layers_units(
plt.savefig(save_path)



def plot_xy(auc_roc, path, title):
fig = plt.figure(figsize=(5, 3))
ax = fig.add_subplot(111)
Expand Down
2 changes: 1 addition & 1 deletion neuralplayground/agents/class_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
plot_message_passing_layers,
plot_xy,
)
from sklearn.metrics import matthews_corrcoef, roc_auc_score
from class_utils import rng_sequence_from_rng, set_device
from sklearn.metrics import matthews_corrcoef, roc_auc_score

# @title Graph net functions
parser = argparse.ArgumentParser()
Expand Down
5 changes: 4 additions & 1 deletion neuralplayground/agents/class_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def set_device():
else:
print("GPU is enabled in this notebook.")


def get_activations_graph_n(
node_colour,
graph,
Expand All @@ -65,10 +66,12 @@ def get_activations_graph_n(
output = node_colour[node_padd : node_padd + graph.n_node[number_graph_batch]]
return output


# maybe actually change the node pad to a node padd function


def get_node_pad(graph, i):
node_padd = 0
for j in range(i):
node_padd = node_padd + graph.n_node[j]
return node_padd
return node_padd
Loading

0 comments on commit 058271c

Please sign in to comment.