Skip to content

Commit

Permalink
new task representation
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementineDomine committed Nov 2, 2023
1 parent 3a99009 commit 0a8f2d9
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 155 deletions.
83 changes: 11 additions & 72 deletions neuralplayground/agents/domine_2023.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,18 +95,16 @@ def __init__( # autogenerated
self.room_depth = np.diff(self.arena_y_limits)[0]
self.agent_step_size = 0



self.log_every = num_training_steps // 10
if self.weighted:
self.edge_lables = True
else:
self.edge_lables = False
self.edge_lables = True

if self.wandb_on:
dateTimeObj = datetime.now()
wandb.init(
project="graph-brain",
project="graph-test",
entity="graph-brain",
name="Grid_shortest_path" + dateTimeObj.strftime("%d%b_%H_%M_%S"),
)
Expand Down Expand Up @@ -282,13 +280,13 @@ def update(self):
rng = next(self.rng_seq)
if self.train_on_shortest_path:
graph_test, target_test = sample_padded_grid_batch_shortest_path(
rng,
self.batch_size_test,
self.feature_position,
self.weighted,
self.nx_min_test,
self.nx_max_test,
)
rng,
self.batch_size,
self.feature_position,
self.weighted,
self.nx_min,
self.nx_max,
)
rng = next(self.rng_seq)

if self.resample:
Expand Down Expand Up @@ -362,7 +360,6 @@ def update(self):
self.params, self.graph, target_wse, False, indices_train
)


self.MCCs_train_wse.append(MCC_train_wse)

# Test
Expand All @@ -384,10 +381,6 @@ def update(self):

# Log
wandb_logs = {
# "loss_test": loss_test,
# "loss_test_wse": loss_test_wse,
# "loss_train": loss,
# "loss_train_wse": loss_wse,

"log_loss_test": np.log(loss_test),
"log_loss_test_wse": np.log(loss_test_wse),
Expand Down Expand Up @@ -653,36 +646,6 @@ def print_and_plot(self):
"in_out_targ_train_wse",
)

# graph_test, target_test = sample_padded_grid_batch_shortest_path(
# rng, self.batch_size_test, self.feature_position, self.weighted, self.nx_min_test, self.nx_max_test
# )
# graph_test= self.graph
# target_test = self.targets

# plot_message_passing_layers_units(outputs[1], target_test.sum(-1), outputs[0].nodes.tolist(),graph_test,config.num_hidden,config.num_message_passing_steps,edege_lables,os.path.join(save_path, 'message_passing_hidden_unit.pdf'))

# Plot each seperatly
# plot_graph_grid_activations(
# outputs[0].nodes.tolist(),
# graph_test,
# os.path.join(self.save_path, "outputs_test.pdf"),
# "Predicted Node Assignments with GCN test",
# self.edge_lables,
# )
# plot_graph_grid_activations(
# list(graph_test.nodes.sum(-1)),
# graph_test,
# os.path.join(self.save_path, "Inputs_test.pdf"),
# "Inputs node assigments test ",
# self.edge_lables,
# )
# plot_graph_grid_activations(
# target_test.sum(-1), graph_test, os.path.join(self.save_path, "Target_test.pdf"), "Target_test", self.edge_lables
# )

# PLOTTING ACTIVATION OF THE FIRST 2 GRAPH OF THE BATCHe


plot_input_target_output(
list(self.graph.nodes.sum(-1)),
self.targets.sum(-1),
Expand All @@ -708,25 +671,6 @@ def print_and_plot(self):
os.path.join(self.save_path, "message_passing_graph_train.pdf"),
"message_passing_graph_train",
)
# plot_message_passing_layers_units(outputs[1], target_test.sum(-1), outputs[0].nodes.tolist(),graph_test,config.num_hidden,config.num_message_passing_steps,edege_lables,os.path.join(save_path, 'message_passing_hidden_unit.pdf'))

# Plot each seperatly
# plot_graph_grid_activations(
# outputs[0].nodes.tolist(),
# graph_test,
# os.path.join(self.save_path, "outputs_train.pdf"),
# "Predicted Node Assignments with GCN",
# self.edge_lables,
# )
# plot_graph_grid_activations(
# list(graph_test.nodes.sum(-1)),
# graph_test,
# os.path.join(self.save_path, "Inputs_train.pdf"),
# "Inputs node assigments",
# self.edge_lables,
# )
# plot_graph_grid_activations(
# target_test.sum(-1), graph_test, os.path.join(self.save_path, "Target_train.pdf"), "Target", self.edge_lables

print('End')

Expand All @@ -751,12 +695,7 @@ def print_and_plot(self):
# Init environment
arena_x_limits = [-100, 100]
arena_y_limits = [-100, 100]
# env = Simple2D
# time_step_size=time_step_size,
# agent_step_size=agent_step_size,
# arena_x_limits=arena_x_limits,
# arena_y_limits=arena_y_limits,
# )


agent = Domine2023(
experiment_name=config.experiment_name,
Expand All @@ -783,7 +722,7 @@ def print_and_plot(self):

for n in range(config.num_training_steps):
agent.update()
agent.print_and_plot()
agent.print_and_plot()

# TODO: Run manadger (not possible for now), to get a seperated code we would juste need to change the paths and config this would mean get rid of the comfig
# The other alternative is to see that we have multiple env that we resample every time
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,10 @@ def sample_padded_grid_batch_shortest_path(
for n_x, n_y in zip(n_xs, n_ys):
nx_graph = get_grid_adjacency(n_x, n_y)
weights = add_weighted_edge(int(nx_graph.number_of_edges()), 1)
l=0

r=0
for i, j in nx_graph.edges:
l=l+1
nx_graph[i][j]["weight"] = weights[l]

nx_graph = nx.DiGraph(nx_graph)
r=r+1
nx_graph[i][j]["weight"] = weights[r]

i_start_1 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_x)
i_start_2 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_y)
Expand Down Expand Up @@ -85,43 +82,47 @@ def sample_padded_grid_batch_shortest_path(
input_node_features = input_node_features.at[node_number_end, 0].set(
1
) # set end node feature

(
senders,
receivers,
node_positions,
edge_displacements,
n_node,
n_edge,
global_context,
) = grid_networkx_to_graphstuple(nx_graph)

if feature_position:

input_node_features = jnp.concatenate(
(input_node_features, node_positions), axis=1
)

(
senders,
receivers,
node_positions,
edge_displacements,
n_node,
n_edge,
global_context,
) = grid_networkx_to_graphstuple(nx_graph)


nx_graph = nx.DiGraph(nx_graph)
if weighted:
edges = jnp.array(
edges_features = jnp.array(
[nx_graph[s][r]["weight"] for s, r in nx_graph.edges]
)
edge_displacement = jnp.concatenate((edge_displacements, edges), axis=1)
graph = jraph.GraphsTuple(
nodes=input_node_features,
senders=senders,
receivers=receivers,
edges=edge_displacement,
edges=edges_features,
n_node=jnp.array([n_node], dtype=int),
n_edge=jnp.array([n_edge], dtype=int),
globals=global_context,
)

else:
#TODO:Clementine: Chamge this line
edge_displacement=abs(np.sum(edge_displacements,1)).reshape(-1, 1)
graph = jraph.GraphsTuple(
nodes=input_node_features,
senders=senders,
edges= edge_displacement,
receivers=receivers,
edges=edge_displacements,
n_node=jnp.array([n_node], dtype=int),
n_edge=jnp.array([n_edge], dtype=int),
globals=global_context,
Expand Down Expand Up @@ -183,7 +184,7 @@ def grid_networkx_to_graphstuple(nx_graph):
def add_weighted_edge(n_edge, sigma_on_edge_weight_noise):
weights = jnp.zeros((n_edge, 1))
for k in range(n_edge):
weight = np.max([sigma_on_edge_weight_noise * np.random.rand() + 1.0, 0.5])
weight = round(np.max([sigma_on_edge_weight_noise * np.random.rand() + 1.0, 0.5]),2)
weights = weights.at[k, 0].set(weight)
# edge_displacement = edge_displacement.at[k,l].set(edge_displacement[k][l] + weight) # weights=sigma_on_edge_weight_noise * np.random.rand() Because nedd postiove and add as features and need ot be used by the neural networks :)
return weights
18 changes: 9 additions & 9 deletions neuralplayground/agents/domine_2023_extras/class_config.yaml
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
experiment_name: 'Learn identity no resample with Feature position '
train_on_shortest_path: True # make sure it works when this is the case
train_on_shortest_path: True # make sure it works when this is the case
resample: True # @param
wandb_on: False
wandb_on: True
seed: 42

feature_position: False # make sure it works when this is the case
weighted: True

num_hidden: 300 # @param
num_layers: 2 # @param
num_message_passing_steps: 5 # @param
learning_rate: 0.0001 # @param
num_training_steps: 10 # @param
num_message_passing_steps: 3 # @param
learning_rate: 0.000001 # @param
num_training_steps: 300 # @param

# Env Stuff
batch_size: 2
nx_min: 2
nx_max: 3
nx_min: 5
nx_max: 6

batch_size_test: 2
nx_min_test: 2
nx_max_test: 3
nx_min_test: 5
nx_max_test: 6
Loading

0 comments on commit 0a8f2d9

Please sign in to comment.