Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementineDomine committed Oct 22, 2023
1 parent 712f9b6 commit 623462d
Show file tree
Hide file tree
Showing 10 changed files with 856 additions and 323 deletions.
423 changes: 306 additions & 117 deletions neuralplayground/agents/domine_2023.py

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@ def get_grid_adjacency(n_x, n_y, atol=1e-1):


def sample_padded_grid_batch_shortest_path(
rng, batch_size, feature_position, weighted, nx_min, nx_max, ny_min=None, ny_max=None
rng,
batch_size,
feature_position,
weighted,
nx_min,
nx_max,
ny_min=None,
ny_max=None,
):
rng_seq = rng_sequence_from_rng(rng)
"""Sample a batch of grid graphs with variable sizes.
Expand Down Expand Up @@ -43,37 +50,53 @@ def sample_padded_grid_batch_shortest_path(
for n_x, n_y in zip(n_xs, n_ys):
nx_graph = get_grid_adjacency(n_x, n_y)

senders, receivers, node_positions, edge_displacements, n_node, n_edge, global_context = grid_networkx_to_graphstuple(
nx_graph
)
(
senders,
receivers,
node_positions,
edge_displacements,
n_node,
n_edge,
global_context,
) = grid_networkx_to_graphstuple(nx_graph)

weights = add_weighted_edge(edge_displacements, n_edge, 1)
for i, j in nx_graph.edges():
nx_graph[i][j]['weight'] = weights[i]
nx_graph[i][j]["weight"] = weights[i]

i_start_1 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval= n_x)
i_start_1 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_x)
i_start_2 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_y)
i_end_1 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval= n_x)
i_end_1 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_x)
i_end_2 = jax.random.randint(next(rng_seq), shape=(1,), minval=0, maxval=n_y)

start = tuple(np.concatenate( (i_start_1,i_start_2), axis=0 ))
end = tuple(np.concatenate( (i_end_1,i_end_2), axis=0 ))
start = tuple(np.concatenate((i_start_1, i_start_2), axis=0))
end = tuple(np.concatenate((i_end_1, i_end_2), axis=0))

nodes_on_shortest_path_indexes_not_weighted = nx.shortest_path(nx_graph, start, end)
nodes_on_shortest_path_indexes = nx.shortest_path(nx_graph, start, end, weight='weight')
nodes_on_shortest_path_indexes_not_weighted = nx.shortest_path(
nx_graph, start, end
)
nodes_on_shortest_path_indexes = nx.shortest_path(
nx_graph, start, end, weight="weight"
)
# make it a node feature of the input graph if a node is a start/end node
input_node_features = jnp.zeros((n_node, 1))

node_number_start = (i_start_1) * n_y + (i_start_2)
node_number_end = (i_end_1) * n_y + (i_end_2)

input_node_features = input_node_features.at[node_number_start, 0].set(1) # set start node feature
input_node_features = input_node_features.at[node_number_end, 0].set(1) # set end node feature
input_node_features = input_node_features.at[node_number_start, 0].set(
1
) # set start node feature
input_node_features = input_node_features.at[node_number_end, 0].set(
1
) # set end node feature
if feature_position:
input_node_features = jnp.concatenate((input_node_features, node_positions), axis=1)
input_node_features = jnp.concatenate(
(input_node_features, node_positions), axis=1
)

if weighted:
edge_displacement = jnp.concatenate((edge_displacements,weights), axis=1)
edge_displacement = jnp.concatenate((edge_displacements, weights), axis=1)
graph = jraph.GraphsTuple(
nodes=input_node_features,
senders=senders,
Expand All @@ -98,16 +121,21 @@ def sample_padded_grid_batch_shortest_path(
graphs.append(graph)
nodes_on_shortest_labels = jnp.zeros((n_node, 1))
for i in nodes_on_shortest_path_indexes:
l=np.argwhere(np.all((node_positions - np.asarray(i)) == 0, axis=1))
nodes_on_shortest_labels = nodes_on_shortest_labels.at[l[0,0]].set(1)
l = np.argwhere(np.all((node_positions - np.asarray(i)) == 0, axis=1))
nodes_on_shortest_labels = nodes_on_shortest_labels.at[l[0, 0]].set(1)
target.append(nodes_on_shortest_labels) # set start node feature
targets = jnp.concatenate(target)
target_pad = jnp.zeros(((max_n - len(targets)), 1))
padded_target = jnp.concatenate((targets, target_pad), axis=0)
graph_batch = jraph.batch(graphs)
padded_graph_batch = jraph.pad_with_graphs(graph_batch, n_node=max_n, n_edge=max_e, n_graph=len(graphs) + 1)
padded_graph_batch = jraph.pad_with_graphs(
graph_batch, n_node=max_n, n_edge=max_e, n_graph=len(graphs) + 1
)

return padded_graph_batch, jnp.asarray(padded_target),
return (
padded_graph_batch,
jnp.asarray(padded_target),
)


def grid_networkx_to_graphstuple(nx_graph):
Expand All @@ -116,7 +144,9 @@ def grid_networkx_to_graphstuple(nx_graph):
node_positions = jnp.array(nx_graph.nodes)
node_to_inds = {n: i for i, n in enumerate(nx_graph.nodes)}
senders_receivers = [(node_to_inds[s], node_to_inds[r]) for s, r in nx_graph.edges]
edge_displacements = jnp.array([np.array(r) - np.array(s) for s, r in nx_graph.edges])
edge_displacements = jnp.array(
[np.array(r) - np.array(s) for s, r in nx_graph.edges]
)
senders, receivers = zip(*senders_receivers)
n_node = node_positions.shape[0]
n_edge = edge_displacements.shape[0]
Expand All @@ -131,7 +161,9 @@ def grid_networkx_to_graphstuple(nx_graph):
)


def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple, number_graph_batch) -> nx.Graph:
def convert_jraph_to_networkx_graph(
jraph_graph: jraph.GraphsTuple, number_graph_batch
) -> nx.Graph:
nodes, edges, receivers, senders, _, _, _ = jraph_graph
node_padd = 0
edges_padd = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ num_hidden: 500 # @param
num_layers: 2 # @param
num_message_passing_steps: 4 # @param
learning_rate: 0.001 # @param
num_training_steps: 300 # @param
num_training_steps: 10 # @param


# Env Stuff
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Dict, Union

from neuralplayground.agents.domine_2023_extras.class_config_template import ConfigTemplate
from neuralplayground.agents.domine_2023_extras.class_config_template import (
ConfigTemplate,
)
from config_manager import base_configuration


Expand Down
10 changes: 7 additions & 3 deletions neuralplayground/agents/domine_2023_extras/class_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def _forward(x):

# Map features to desired feature size.
x = jraph.GraphMapFeatures(
embed_edge_fn=hk.Linear(output_size=num_hidden), embed_node_fn=hk.Linear(output_size=num_hidden)
embed_edge_fn=hk.Linear(output_size=num_hidden),
embed_node_fn=hk.Linear(output_size=num_hidden),
)(x)

# Apply rounds of message passing.
Expand All @@ -29,7 +30,8 @@ def _forward(x):

# Map features to desired feature size.
x = jraph.GraphMapFeatures(
embed_edge_fn=hk.Linear(output_size=edge_output_size), embed_node_fn=hk.Linear(output_size=node_output_size)
embed_edge_fn=hk.Linear(output_size=edge_output_size),
embed_node_fn=hk.Linear(output_size=node_output_size),
)(x)

return x, message_passing
Expand All @@ -40,5 +42,7 @@ def _forward(x):
def message_passing_layer(x, edge_mlp_sizes, node_mlp_sizes):
update_edge_fn = jraph.concatenated_args(hk.nets.MLP(output_sizes=edge_mlp_sizes))
update_node_fn = jraph.concatenated_args(hk.nets.MLP(output_sizes=node_mlp_sizes))
x = jraph.GraphNetwork(update_edge_fn=update_edge_fn, update_node_fn=update_node_fn)(x)
x = jraph.GraphNetwork(
update_edge_fn=update_edge_fn, update_node_fn=update_node_fn
)(x)
return x
Loading

0 comments on commit 623462d

Please sign in to comment.