Skip to content

Commit

Permalink
further integration
Browse files Browse the repository at this point in the history
  • Loading branch information
ClementineDomine committed Oct 2, 2023
1 parent 058271c commit 6384881
Show file tree
Hide file tree
Showing 575 changed files with 1,057 additions and 40,613 deletions.
1 change: 1 addition & 0 deletions neuralplayground/agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
from .agent_core import AgentCore, RandomAgent, LevyFlightAgent
from .stachenfeld_2018 import Stachenfeld2018
from .weber_2018 import Weber2018
from .domine_2023 import Domine2023

# from .whittington_2020 import Whittington2020
2 changes: 1 addition & 1 deletion neuralplayground/agents/agent_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def save_agent(self, save_path: str, raw_object: bool = True):
if raw_object:
pickle.dump(self, open(os.path.join(save_path), "wb"), pickle.HIGHEST_PROTOCOL)
else:
pickle.dump(self.__dict__, open(os.path.join(save_path), "wb"), pickle.HIGHEST_PROTOCOL)
pickle.dump(self, open(os.path.join(save_path), "wb"), pickle.HIGHEST_PROTOCOL)

def restore_agent(self, restore_path: str):
"""Restore saved environment
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import jraph
import networkx as nx
import numpy as np
from class_utils import rng_sequence_from_rng
from neuralplayground.agents.domine_2023_extras.class_utils import rng_sequence_from_rng


def get_grid_adjacency(n_x, n_y, atol=1e-1):
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, Union

from class_config_template import ConfigTemplate
from neuralplayground.agents.domine_2023_extras.class_config_template import ConfigTemplate
from config_manager import base_configuration


Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import networkx as nx
from class_utils import convert_jraph_to_networkx_graph, get_activations_graph_n, get_node_pad

from neuralplayground.config import PLOT_CONFIG
config_vars = PLOT_CONFIG.GRAPH

def plot_input_target_output(inputs, targets, outputs, graph, n, edege_lables, save_path):
# minim 2 otherwise it breaks
Expand All @@ -13,7 +15,7 @@ def plot_input_target_output(inputs, targets, outputs, graph, n, edege_lables, s
nx_graph = convert_jraph_to_networkx_graph(graph, i)
pos = nx.spring_layout(nx_graph, iterations=100, seed=39775)
input = get_activations_graph_n(inputs, graph, i)
target = get_activations_graph_n(targets, graph, i)
target = get_activations_graph_n(targets, graph, i)y
output = get_activations_graph_n(outputs, graph, i)
axes[0, i].title.set_text("Input")
axes[1, i].title.set_text("Target")
Expand All @@ -32,9 +34,9 @@ def plot_input_target_output(inputs, targets, outputs, graph, n, edege_lables, s
nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=axes[2, i])
# labels = nx.get_edge_attributes(nx_graph, 'weight')

nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[0, i])
nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=target, font_color="white", ax=axes[1, i])
nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=output, font_color="white", ax=axes[2, i])
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=input, font_color=config_vars.FONT_COLOR, ax=axes[0, i])
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=target, font_color=config_vars.FONT_COLOR, ax=axes[1, i])
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=output, font_color=config_vars.FONT_COLOR, ax=axes[2, i])

for axes, row in zip(axes[:, 0], rows):
axes.set_ylabel(row, rotation=0, size="large")
Expand Down Expand Up @@ -62,27 +64,27 @@ def plot_message_passing_layers(inputs, activations, targets, outputs, graph, n,
axes[i, j].title.set_text("input")
input = get_activations_graph_n(inputs, graph, j)
nx.draw(
nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[i, j]
nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=input, font_color=config_vars.FONT_COLOR, ax=axes[i, j]
)
elif i == (n_message_passing + 1):
axes[i, j].title.set_text("target")
target = get_activations_graph_n(targets, graph, j)
nx.draw(
nx_graph, pos=pos, with_labels=True, node_size=200, node_color=target, font_color="white", ax=axes[i, j]
nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=target, font_color=config_vars.FONT_COLOR, ax=axes[i, j]
)

elif i == (n_message_passing):
axes[i, j].title.set_text("output")
output = get_activations_graph_n(outputs, graph, j)
nx.draw(
nx_graph, pos=pos, with_labels=True, node_size=200, node_color=output, font_color="white", ax=axes[i, j]
nx_graph, pos=pos, with_labels=True, node_size=config_vars.NODE_SIZE, node_color=output, font_color=config_vars.FONT_COLOR, ax=axes[i, j]
)
else:
activation = activations[i]
axes[i, j].title.set_text("graph_" + str(j) + "message_passing_" + str(i))
input = get_activations_graph_n(activation.nodes[:, j].tolist(), graph, j)
nx.draw(
nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[i, j]
nx_graph, pos=pos, with_labels=True, node_size=config_vars.NODE_SIZE, node_color=input, font_color=config_vars.FONT_COLOR, ax=axes[i, j]
)

plt.savefig(save_path)
Expand Down Expand Up @@ -110,7 +112,7 @@ def plot_graph_grid_activations(
u = u + 1
labels = nx.get_edge_attributes(nx_graph, "weight")
nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=ax)
nx.draw(nx_graph, pos=pos, with_labels=True, node_size=500, node_color=output, font_color="white", ax=ax)
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=output, font_color=config_vars.FONT_COLOR, ax=ax)
plt.savefig(save_path)


Expand All @@ -134,7 +136,7 @@ def plot_message_passing_layers_units(
axes[i, j].title.set_text("first_graph_unit_" + str(j) + "message_passing_" + str(i))
# We select the first graph only
input = get_activations_graph_n(activation.nodes[:, j].tolist(), graph, 0)
nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[i, j])
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS , node_size=config_vars.NODE_SIZE, node_color=input, font_color=config_vars.FONT_COLOR, ax=axes[i, j])
if edege_lables:
nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=axes[i, j])
plt.savefig(save_path)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
parser.add_argument(
"--config_path",
metavar="-C",
default="config.yaml",
default="class_config.yaml",
help="path to base configuration file.",
)

Expand Down
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion neuralplayground/agents/readme.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
readme
readme
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import networkx as nx
from class_utils import convert_jraph_to_networkx_graph, get_activations_graph_n, get_node_pad

from neuralplayground.config import PLOT_CONFIG
config_vars = PLOT_CONFIG.GRAPH

def plot_input_target_output(inputs, targets, outputs, graph, n, edege_lables, save_path):
# minim 2 otherwise it breaks
Expand Down Expand Up @@ -32,10 +34,9 @@ def plot_input_target_output(inputs, targets, outputs, graph, n, edege_lables, s
nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=axes[2, i])
# labels = nx.get_edge_attributes(nx_graph, 'weight')

nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[0, i])
nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=target, font_color="white", ax=axes[1, i])

nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=output, font_color="white", ax=axes[2, i])
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=input, font_color=config_vars.FONT_COLOR, ax=axes[0, i])
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=target, font_color=config_vars.FONT_COLOR, ax=axes[1, i])
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=output, font_color=config_vars.FONT_COLOR, ax=axes[2, i])

for axes, row in zip(axes[:, 0], rows):
axes.set_ylabel(row, rotation=0, size="large")
Expand Down Expand Up @@ -63,27 +64,27 @@ def plot_message_passing_layers(inputs, activations, targets, outputs, graph, n,
axes[i, j].title.set_text("input")
input = get_activations_graph_n(inputs, graph, j)
nx.draw(
nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[i, j]
nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=input, font_color=config_vars.FONT_COLOR, ax=axes[i, j]
)
elif i == (n_message_passing + 1):
axes[i, j].title.set_text("target")
target = get_activations_graph_n(targets, graph, j)
nx.draw(
nx_graph, pos=pos, with_labels=True, node_size=200, node_color=target, font_color="white", ax=axes[i, j]
nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=target, font_color=config_vars.FONT_COLOR, ax=axes[i, j]
)

elif i == (n_message_passing):
axes[i, j].title.set_text("output")
output = get_activations_graph_n(outputs, graph, j)
nx.draw(
nx_graph, pos=pos, with_labels=True, node_size=200, node_color=output, font_color="white", ax=axes[i, j]
nx_graph, pos=pos, with_labels=True, node_size=config_vars.NODE_SIZE, node_color=output, font_color=config_vars.FONT_COLOR, ax=axes[i, j]
)
else:
activation = activations[i]
axes[i, j].title.set_text("graph_" + str(j) + "message_passing_" + str(i))
input = get_activations_graph_n(activation.nodes[:, j].tolist(), graph, j)
nx.draw(
nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[i, j]
nx_graph, pos=pos, with_labels=True, node_size=config_vars.NODE_SIZE, node_color=input, font_color=config_vars.FONT_COLOR, ax=axes[i, j]
)

plt.savefig(save_path)
Expand Down Expand Up @@ -111,7 +112,7 @@ def plot_graph_grid_activations(
u = u + 1
labels = nx.get_edge_attributes(nx_graph, "weight")
nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=ax)
nx.draw(nx_graph, pos=pos, with_labels=True, node_size=500, node_color=output, font_color="white", ax=ax)
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS, node_size=config_vars.NODE_SIZE, node_color=output, font_color=config_vars.FONT_COLOR, ax=ax)
plt.savefig(save_path)


Expand All @@ -135,7 +136,7 @@ def plot_message_passing_layers_units(
axes[i, j].title.set_text("first_graph_unit_" + str(j) + "message_passing_" + str(i))
# We select the first graph only
input = get_activations_graph_n(activation.nodes[:, j].tolist(), graph, 0)
nx.draw(nx_graph, pos=pos, with_labels=True, node_size=200, node_color=input, font_color="white", ax=axes[i, j])
nx.draw(nx_graph, pos=pos, with_labels=config_vars.WITH_LABELS , node_size=config_vars.NODE_SIZE, node_color=input, font_color=config_vars.FONT_COLOR, ax=axes[i, j])
if edege_lables:
nx.draw_networkx_edge_labels(nx_graph, pos=pos, edge_labels=labels, ax=axes[i, j])
plt.savefig(save_path)
Expand Down
Loading

0 comments on commit 6384881

Please sign in to comment.