Skip to content

Commit

Permalink
neuralplayground/agents/domine_2023_extras/
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementineDomine committed Oct 2, 2023
1 parent 6384881 commit a74c55e
Showing 1 changed file with 24 additions and 26 deletions.
50 changes: 24 additions & 26 deletions neuralplayground/agents/domine_2023_extras/class_plotting_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# @title Make rng sequence generator
import matplotlib.pyplot as plt
import networkx as nx
from class_utils import convert_jraph_to_networkx_graph, get_activations_graph_n, get_node_pad
from neuralplayground.agents.domine_2023_extras.class_utils import convert_jraph_to_networkx_graph, get_activations_graph_n, get_node_pad

from neuralplayground.config import PLOT_CONFIG
config_vars = PLOT_CONFIG.GRAPH

def plot_input_target_output(inputs, targets, outputs, graph, n, edege_lables, save_path):
# minim 2 otherwise it breaks
Expand All @@ -15,7 +13,7 @@ def plot_input_target_output(inputs, targets, outputs, graph, n, edege_lables, s
nx_graph = convert_jraph_to_networkx_graph(graph, i)
pos = nx.spring_layout(nx_graph, iterations=100, seed=39775)
input = get_activations_graph_n(inputs, graph, i)
target = get_activations_graph_n(targets, graph, i)y
target = get_activations_graph_n(targets, graph, i)
output = get_activations_graph_n(outputs, graph, i)
axes[0, i].title.set_text("Input")
axes[1, i].title.set_text("Target")
Expand All @@ -33,16 +31,15 @@ def plot_input_target_output(inputs, targets, outputs, graph, n, edege_lables, s
nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=axes[1, i])
nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=axes[2, i])
# labels = nx.get_edge_attributes(nx_graph, 'weight')

nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=input, font_color=config_vars.FONT_COLOR, ax=axes[0, i])
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=target, font_color=config_vars.FONT_COLOR, ax=axes[1, i])
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=output, font_color=config_vars.FONT_COLOR, ax=axes[2, i])

nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[0, i])
nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=target, font_color="white",
ax=axes[1, i])
nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=output, font_color="white",
ax=axes[2, i])
for axes, row in zip(axes[:, 0], rows):
axes.set_ylabel(row, rotation=0, size="large")
plt.savefig(save_path)


def plot_message_passing_layers(inputs, activations, targets, outputs, graph, n, n_message_passing, edege_lables, save_path):
# minim 2 otherwise it breaks
fig, axes = plt.subplots(n_message_passing + 3, n)
Expand All @@ -64,39 +61,39 @@ def plot_message_passing_layers(inputs, activations, targets, outputs, graph, n,
axes[i, j].title.set_text("input")
input = get_activations_graph_n(inputs, graph, j)
nx.draw(
nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=input, font_color=config_vars.FONT_COLOR, ax=axes[i, j]
nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[i, j]
)
elif i == (n_message_passing + 1):
axes[i, j].title.set_text("target")
target = get_activations_graph_n(targets, graph, j)
nx.draw(
nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=target, font_color=config_vars.FONT_COLOR, ax=axes[i, j]
nx_graph, pos=pos, with_labels=True, node_size=200, node_color=target, font_color="white", ax=axes[i, j]
)

elif i == (n_message_passing):
axes[i, j].title.set_text("output")
output = get_activations_graph_n(outputs, graph, j)
nx.draw(
nx_graph, pos=pos, with_labels=True, node_size=config_vars.NODE_SIZE, node_color=output, font_color=config_vars.FONT_COLOR, ax=axes[i, j]
nx_graph, pos=pos, with_labels=True, node_size=200, node_color=output, font_color="white", ax=axes[i, j]
)
else:
activation = activations[i]
axes[i, j].title.set_text("graph_" + str(j) + "message_passing_" + str(i))
input = get_activations_graph_n(activation.nodes[:, j].tolist(), graph, j)
nx.draw(
nx_graph, pos=pos, with_labels=True, node_size=config_vars.NODE_SIZE, node_color=input, font_color=config_vars.FONT_COLOR, ax=axes[i, j]
nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[i, j]
)

plt.savefig(save_path)




def plot_graph_grid_activations(
node_colour,
graph,
save_path,
title,
edege_lables,
number_graph_batch=0,
node_colour,
graph,
save_path,
title,
edege_lables,
number_graph_batch=0,
):
nx_graph = convert_jraph_to_networkx_graph(graph, number_graph_batch)
output = get_activations_graph_n(node_colour, graph, number_graph_batch=0)
Expand All @@ -112,12 +109,12 @@ def plot_graph_grid_activations(
u = u + 1
labels = nx.get_edge_attributes(nx_graph, "weight")
nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=ax)
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=output, font_color=config_vars.FONT_COLOR, ax=ax)
nx.draw(nx_graph, pos=pos, with_labels=True, node_size=500, node_color=output, font_color="white", ax=ax)
plt.savefig(save_path)


def plot_message_passing_layers_units(
activations, targets, outputs, graph, number_hidden, n_message_passing, edege_lables, save_path
activations, targets, outputs, graph, number_hidden, n_message_passing, edege_lables, save_path
):
# minim 2 otherwise it breaks
fig, axes = plt.subplots(n_message_passing, number_hidden)
Expand All @@ -136,7 +133,8 @@ def plot_message_passing_layers_units(
axes[i, j].title.set_text("first_graph_unit_" + str(j) + "message_passing_" + str(i))
# We select the first graph only
input = get_activations_graph_n(activation.nodes[:, j].tolist(), graph, 0)
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS , node_size=config_vars.NODE_SIZE, node_color=input, font_color=config_vars.FONT_COLOR, ax=axes[i, j])
nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white",
ax=axes[i, j])
if edege_lables:
nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=axes[i, j])
plt.savefig(save_path)
Expand All @@ -147,4 +145,4 @@ def plot_xy(auc_roc, path, title):
ax = fig.add_subplot(111)
ax.title.set_text(title)
ax.plot(auc_roc)
plt.savefig(path)
plt.savefig(path)

0 comments on commit a74c55e

Please sign in to comment.