Skip to content

Commit

Permalink
wenrnel tetrode, get_grid score, title figure , table figure and con…
Browse files Browse the repository at this point in the history
…fig update
  • Loading branch information
ClementineDomine committed Jul 17, 2023
1 parent 5a9fdc6 commit c00b760
Show file tree
Hide file tree
Showing 9 changed files with 549 additions and 718 deletions.
866 changes: 226 additions & 640 deletions examples/comparisons_examples/comparison_examples_score.ipynb

Large diffs are not rendered by default.

43 changes: 29 additions & 14 deletions neuralplayground/agents/stachenfeld_2018.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import matplotlib.pyplot as plt
import numpy as np

from neuralplayground.comparison import GridScorer
from neuralplayground.plotting.plot_utils import make_plot_rate_map

from .agent_core import AgentCore
Expand Down Expand Up @@ -349,9 +350,7 @@ def update_successor_rep_td_full(self, n_episode: int = None, t_episode: int = N
successor representation matrix
"""

random_state = np.random.RandomState(1234)

t_elapsed = 0
srmat0 = np.eye(self.n_state)
srmat_full = srmat0.copy()
Expand All @@ -374,6 +373,33 @@ def update_successor_rep_td_full(self, n_episode: int = None, t_episode: int = N
self.srmat_full_td = srmat_full
return self.srmat_full_td

def get_rate_map_matrix(
self,
sr_matrix=None,
eigen_vector: int = 10,
):
if sr_matrix is None:
sr_matrix = self.successor_rep_solution()
evals, evecs = np.linalg.eig(sr_matrix)
r_out_im = evecs[:, eigen_vector].reshape((self.resolution_width, self.resolution_depth)).real
return r_out_im

def get_grid_score(self, plot=False, eigen_vector=10):
"""
Get the grid score of the network
Returns
-------
grid_score : float
Grid score of the network
"""
r_out_im = self.get_rate_map_matrix(eigen_vector=eigen_vector)
GridScorer_SR = GridScorer(self.resolution_width)
score = GridScorer_SR.get_scores(np.asarray(r_out_im))
if plot:
GridScorer_SR.plot_sac(score[0])
return score[1]

def plot_transition(self, save_path: str = None, ax: mpl.axes.Axes = None):
"""
Plot the input matrix and compare it to the transition matrix from the rectangular
Expand All @@ -396,21 +422,10 @@ def plot_transition(self, save_path: str = None, ax: mpl.axes.Axes = None):
plt.close("all")
return ax

def get_rate_map_matrix(
self,
sr_matrix,
eigen_vector: int = 0,
):
if sr_matrix is None:
sr_matrix = self.successor_rep_solution()
evals, evecs = np.linalg.eig(sr_matrix)
r_out_im = evecs[:, eigen_vector].reshape((self.resolution_width, self.resolution_depth)).real
return r_out_im

def plot_rate_map(
self,
sr_matrix=None,
eigen_vectors: Union[int, list, tuple] = 0,
eigen_vectors: Union[int, list, tuple] = 10,
ax: mpl.axes.Axes = None,
save_path: str = None,
):
Expand Down
45 changes: 32 additions & 13 deletions neuralplayground/agents/weber_2018.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from scipy.stats import multivariate_normal
from tqdm import tqdm

from neuralplayground.comparison import GridScorer
from neuralplayground.plotting.plot_utils import make_plot_rate_map

from .agent_core import AgentCore
Expand Down Expand Up @@ -380,6 +381,37 @@ def full_update(self, exc_normalization: bool = True):
for i in range(self.xy_combinations.shape[0]):
self.update(exc_normalization=exc_normalization, pos=xy_array[i, :])

def get_rate_map_matrix(
self,
):
"""
Get the ratemap matrix of the network
Returns
-------
ratemap_matrix : ndarray
(self.resolution_width, self.resolution_depth) with the ratemap matrix
"""
r_out_im = self.get_full_output_rate()
r_out_im = r_out_im.reshape((self.resolution_width, self.resolution_depth))
return r_out_im

def get_grid_score(self, plot=False):
"""
Get the grid score of the network
Returns
-------
grid_score : float
Grid score of the network
"""
r_out_im = self.get_rate_map_matrix()
GridScorer_Webber = GridScorer(self.resolution_width)
score = GridScorer_Webber.get_scores(np.asarray(r_out_im))
if plot:
GridScorer_Webber.plot_sac(score[0])
return score[1]

def plot_rate_map(self, save_path: str = None, ax: mpl.axes.Axes = None):
"""
Plot current rates and an example of inhibitory and excitatory neuron
Expand Down Expand Up @@ -434,16 +466,3 @@ def plot_all_rates(self, save_path: str = None, ax: mpl.axes.Axes = None):
plt.close("all")
else:
return ax

def get_ratemap_matrix(self):
"""
Get the ratemap matrix of the network
Returns
-------
ratemap_matrix : ndarray
(self.resolution_width, self.resolution_depth) with the ratemap matrix
"""
r_out_im = self.get_full_output_rate()
r_out_im = r_out_im.reshape((self.resolution_width, self.resolution_depth))
return r_out_im
1 change: 1 addition & 0 deletions neuralplayground/arenas/hafting_2008.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def show_data(self, full_dataframe: bool = False):
List of available experiment, columns with rat_id, recording session and recorded variables
"""
self.experiment.show_data(full_dataframe=full_dataframe)
return self.experiment.show_data(full_dataframe=full_dataframe)

def plot_recording_tetr(
self,
Expand Down
25 changes: 19 additions & 6 deletions neuralplayground/config/default_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,30 @@ plot_config:
scatter_alpha: 1
scatter_marker: "o"
scatter_marker_size: 1
label_fontsize: 24
label_fontsize: 16
tick_label_fontsize: 12
colorbar_label_fontsize: 24
title_fontsize: 24
colorbar_label_fontsize: 16
title_fontsize: 16
grid: False

ratemap_plot:
ratemap_colormap: "jet"
bin_size: 2.0
label_fontsize: 24
label_fontsize: 18
tick_label_fontsize: 12
colorbar_label_fontsize: 24
title_fontsize: 24
colorbar_label_fontsize: 16
title_fontsize: 16
grid: False

agent_comparison_plot:
fontsize: 10

table_plot:
table_fontsize: 7
col_width: 30
row_height: 0.625
header_color: "C0"
row_colors: ["#f1f1f2", "w"]
edge_color: "w"
bbox: [0, 0, 1, 1]
header_columns: 0
92 changes: 79 additions & 13 deletions neuralplayground/config/plot_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,33 @@ class TrajectoryConfig(NPGConfig):
FIGURE_SIZE: tuple
Figure size
TRAJECTORY_COLORMAP: str
Colormap for agent trajectory plots
Colormap for trajectory plots
TRAJECTORY_ALPHA: float
Alpha value for agent trajectory plots
Alpha value for the trajectory plot
EXTERNAL_WALL_COLOR: str
Color of external walls
Color for the external wall of the arena
EXTERNAL_WALL_THICKNESS: float
Thickness of external walls
Thickness of the external wall of the arena
CUSTOM_WALL_COLOR: str
Color of custom walls
Color for the custom wall of the arena
CUSTOM_WALL_THICKNESS: float
Thickness of custom walls
Thickness of the custom wall of the arena
SCATTER_ALPHA: float
Alpha value for scatter dots in the agent trajectory plot
Alpha value for the scatter plot
SCATTER_MARKER: str
Marker for scatter dots in the agent trajectory plot
Marker for the scatter plot
SCATTER_MARKER_SIZE: float
Marker size for scatter dots in the agent trajectory plot
Size of the marker for the scatter plot
LABEL_FONTSIZE: float
Fontsize of labels in the plot
TICK_LABEL_FONTSIZE: float
Fontsize of tick labels in the plot
PLOT_EVERY_POINTS: int
Time steps skipped to make the plot to reduce cluttering
GRID: bool
Boolean value to plot grid in the background of trajectory plots
Whether to show grid in the plot
TITLE_FONTSIZE: float
Fontsize of the title in the plot
COLORBAR_LABEL_FONTSIZE: float
Fontsize of the colorbar label in the plot
"""

def __init__(self, **kwargs):
Expand Down Expand Up @@ -82,9 +84,73 @@ def __init__(self, **kwargs):
self.GRID = kwargs["grid"]


class AgentCompatisonConfig(NPGConfig):
"""Config object for ratemap plots
Attributes
----------
TABLE_FONTSIZE: float
Fontsize of table
"""

def __init__(self, **kwargs):
self.FONTSIZE = kwargs["fontsize"]


class TableConfig(NPGConfig):
"""Config object for ratemap plots
Attributes
----------
ROW_HEIGHT: float
Height of rows in the table
COL_WIDTH: float
Width of columns in the table
TABLE_FONTSIZE: float
Fontsize of table
ROW_COLOR: str
Color of rows in the table
HEADER_COLOR: str
Color of header in the table
EDGE_COLOR: str
Color of edges in the table
HEADER_COLLUMNS: list
List of header columns in the table
BBOX: tuple
Bounding box of the table
"""

def __init__(self, **kwargs):
self.ROW_HEIGHT = kwargs["row_height"]
self.COL_WIDTH = kwargs["col_width"]
self.TABLE_FONTSIZE = kwargs["table_fontsize"]
self.ROW_COLOR = kwargs["row_colors"]
self.HEADER_COLOR = kwargs["header_color"]
self.EDGE_COLOR = kwargs["edge_color"]
self.HEADER_COLUMNS = kwargs["header_columns"]
self.BBOX = kwargs["bbox"]


class PlotsConfig(NPGConfig):
"""Config object for plots, all plots are config are stored in this object"""
"""Config object for plots, all plots are config are stored in this object
Attributes
----------
TRAJECTORY: TrajectoryConfig
Config object for trajectory plots
RATEMAP: RateMapConfig
Config object for ratemap plots
AGENT_COMPARISON: AgentCompatisonConfig
Config object for agent comparison plots
TABLE: TableConfig
Config object for table plots
"""

def __init__(self, plot_config: dict):
self.TRAJECTORY = TrajectoryConfig(**plot_config["trajectory_plot"])
self.RATEMAP = RateMapConfig(**plot_config["ratemap_plot"])
self.AGENT_COMPARISON = AgentCompatisonConfig(**plot_config["agent_comparison_plot"])
self.TABLE = TableConfig(**plot_config["table_plot"])
28 changes: 24 additions & 4 deletions neuralplayground/experiments/hafting_2008_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from neuralplayground.datasets import fetch_data_path
from neuralplayground.utils import clean_data, get_2D_ratemap
from neuralplayground.plotting.plot_utils import make_plot_trajectories , make_plot_rate_map
from neuralplayground.comparison import GridScorer

from .experiment_core import Experiment

Expand Down Expand Up @@ -267,6 +268,23 @@ def get_tetrode_data(self, session_data: str = None, tetrode_id: str = None):

return time_array, test_spikes, x, y

def get_grid_score(self, plot=False):
"""
Get the grid score of the network
Returns
-------
grid_score : float
Grid score of the network
"""
r_out_im ,x_bin, y_bin = self.recording_tetr( )
# , tetrode_id="T6C1")
GridScorer_exp = GridScorer(x_bin.size - 1)
score = GridScorer_exp.get_scores(np.asarray(r_out_im))
if plot:
GridScorer_exp.plot_sac(score[0])
return score[1]

def plot_recording_tetr(
self,
recording_index: Union[int, tuple, list] = None,
Expand Down Expand Up @@ -331,12 +349,16 @@ def plot_recording_tetr(
# Generate axis in case ax is None
if ax is None:
f, ax = plt.subplots(1, 1, figsize=(10, 8))

# Compute ratemap matrices from data
session_data, rev_vars, rat_info = self.get_recording_data(recording_index)
if tetrode_id is None:
tetrode_id = self._find_tetrode(rev_vars)


h, binx, biny = self.recording_tetr(recording_index, save_path, tetrode_id, bin_size)

# Use auxiliary function to make the plot
ax = make_plot_rate_map(h, ax, tetrode_id,"width","depth","Firing rate")
ax = make_plot_rate_map(h, ax, 'rat: '+str(rat_info['rat_id'])+' sess: '+str(rat_info['sess'])+' tetrode: '+tetrode_id,"width","depth","Firing rate")
if save_path is None:
return h, binx, biny
else:
Expand Down Expand Up @@ -462,6 +484,4 @@ def recording_tetr(self, recording_index: Union[int, tuple, list] = None,
h, binx, biny = get_2D_ratemap(time_array, test_spikes, x, y, x_size=int(arena_width / bin_size),
y_size=int(arena_depth / bin_size), filter_result=True)


# Return ratemap values, x bin limits and y bin limits
return h, binx, biny
Loading

0 comments on commit c00b760

Please sign in to comment.