Skip to content

Commit

Permalink
Merge pull request #148 from jo-mueller/add-feature-histogram
Browse files Browse the repository at this point in the history
added a Feature Histogram Widget
  • Loading branch information
dstansby authored Aug 25, 2023
2 parents 3a8261a + e1ccfb1 commit dc1f3bb
Show file tree
Hide file tree
Showing 10 changed files with 231 additions and 12 deletions.
Binary file added baseline/test_feature_histogram2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
6 changes: 5 additions & 1 deletion docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
Changelog
=========
1.0.3
1.1.0
-----
Additions
~~~~~~~~~
- Added a widget to draw a histogram of features.

Changes
~~~~~~~
- The slice widget is now limited to slicing along the x/y dimensions. Support
Expand Down
1 change: 1 addition & 0 deletions docs/user_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ These widgets plot the data stored in the ``.features`` attribute of individual
Currently available are:

- 2D scatter plots of two features against each other.
- Histograms of individual features.

To use these:

Expand Down
9 changes: 9 additions & 0 deletions src/napari_matplotlib/features.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from napari.layers import Labels, Points, Shapes, Tracks, Vectors

FEATURES_LAYER_TYPES = (
Labels,
Points,
Shapes,
Tracks,
Vectors,
)
117 changes: 114 additions & 3 deletions src/napari_matplotlib/histogram.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from typing import Optional
from typing import Any, List, Optional, Tuple

import napari
import numpy as np
from qtpy.QtWidgets import QWidget
import numpy.typing as npt
from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget

from .base import SingleAxesWidget
from .features import FEATURES_LAYER_TYPES
from .util import Interval

__all__ = ["HistogramWidget"]
__all__ = ["HistogramWidget", "FeaturesHistogramWidget"]

_COLORS = {"r": "tab:red", "g": "tab:green", "b": "tab:blue"}

Expand Down Expand Up @@ -61,3 +63,112 @@ def draw(self) -> None:
self.axes.hist(data.ravel(), bins=bins, label=layer.name)

self.axes.legend()


class FeaturesHistogramWidget(SingleAxesWidget):
"""
Display a histogram of selected feature attached to selected layer.
"""

n_layers_input = Interval(1, 1)
# All layers that have a .features attributes
input_layer_types = FEATURES_LAYER_TYPES

def __init__(
self,
napari_viewer: napari.viewer.Viewer,
parent: Optional[QWidget] = None,
):
super().__init__(napari_viewer, parent=parent)

self.layout().addLayout(QVBoxLayout())
self._key_selection_widget = QComboBox()
self.layout().addWidget(QLabel("Key:"))
self.layout().addWidget(self._key_selection_widget)

self._key_selection_widget.currentTextChanged.connect(
self._set_axis_keys
)

self._update_layers(None)

@property
def x_axis_key(self) -> Optional[str]:
"""Key to access x axis data from the FeaturesTable"""
return self._x_axis_key

@x_axis_key.setter
def x_axis_key(self, key: Optional[str]) -> None:
self._x_axis_key = key
self._draw()

def _set_axis_keys(self, x_axis_key: str) -> None:
"""Set both axis keys and then redraw the plot"""
self._x_axis_key = x_axis_key
self._draw()

def _get_valid_axis_keys(self) -> List[str]:
"""
Get the valid axis keys from the layer FeatureTable.
Returns
-------
axis_keys : List[str]
The valid axis keys in the FeatureTable. If the table is empty
or there isn't a table, returns an empty list.
"""
if len(self.layers) == 0 or not (hasattr(self.layers[0], "features")):
return []
else:
return self.layers[0].features.keys()

def _get_data(self) -> Tuple[Optional[npt.NDArray[Any]], str]:
"""Get the plot data.
Returns
-------
data : List[np.ndarray]
List contains X and Y columns from the FeatureTable. Returns
an empty array if nothing to plot.
x_axis_name : str
The title to display on the x axis. Returns
an empty string if nothing to plot.
"""
if not hasattr(self.layers[0], "features"):
# if the selected layer doesn't have a featuretable,
# skip draw
return None, ""

feature_table = self.layers[0].features

if (len(feature_table) == 0) or (self.x_axis_key is None):
return None, ""

data = feature_table[self.x_axis_key]
x_axis_name = self.x_axis_key.replace("_", " ")

return data, x_axis_name

def on_update_layers(self) -> None:
"""
Called when the layer selection changes by ``self.update_layers()``.
"""
# reset the axis keys
self._x_axis_key = None

# Clear combobox
self._key_selection_widget.clear()
self._key_selection_widget.addItems(self._get_valid_axis_keys())

def draw(self) -> None:
"""Clear the axes and histogram the currently selected layer/slice."""
data, x_axis_name = self._get_data()

if data is None:
return

self.axes.hist(data, bins=50, edgecolor="white", linewidth=0.3)

# set ax labels
self.axes.set_xlabel(x_axis_name)
self.axes.set_ylabel("Counts [#]")
7 changes: 7 additions & 0 deletions src/napari_matplotlib/napari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@ contributions:
python_name: napari_matplotlib:FeaturesScatterWidget
title: Make a scatter plot of layer features

- id: napari-matplotlib.features_histogram
python_name: napari_matplotlib:FeaturesHistogramWidget
title: Plot feature histograms

- id: napari-matplotlib.slice
python_name: napari_matplotlib:SliceWidget
title: Plot a 1D slice
Expand All @@ -28,5 +32,8 @@ contributions:
- command: napari-matplotlib.features_scatter
display_name: FeaturesScatter

- command: napari-matplotlib.features_histogram
display_name: FeaturesHistogram

- command: napari-matplotlib.slice
display_name: 1D slice
9 changes: 2 additions & 7 deletions src/napari_matplotlib/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from qtpy.QtWidgets import QComboBox, QLabel, QVBoxLayout, QWidget

from .base import SingleAxesWidget
from .features import FEATURES_LAYER_TYPES
from .util import Interval

__all__ = ["ScatterBaseWidget", "ScatterWidget", "FeaturesScatterWidget"]
Expand Down Expand Up @@ -94,13 +95,7 @@ class FeaturesScatterWidget(ScatterBaseWidget):

n_layers_input = Interval(1, 1)
# All layers that have a .features attributes
input_layer_types = (
napari.layers.Labels,
napari.layers.Points,
napari.layers.Shapes,
napari.layers.Tracks,
napari.layers.Vectors,
)
input_layer_types = FEATURES_LAYER_TYPES

def __init__(
self,
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
94 changes: 93 additions & 1 deletion src/napari_matplotlib/tests/test_histogram.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from copy import deepcopy

import numpy as np
import pytest

from napari_matplotlib import HistogramWidget
from napari_matplotlib import FeaturesHistogramWidget, HistogramWidget
from napari_matplotlib.tests.helpers import (
assert_figures_equal,
assert_figures_not_equal,
)


@pytest.mark.mpl_image_compare
Expand All @@ -28,3 +33,90 @@ def test_histogram_3D(make_napari_viewer, brain_data):
# Need to return a copy, as original figure is too eagerley garbage
# collected by the widget
return deepcopy(fig)


def test_feature_histogram(make_napari_viewer):
n_points = 1000
random_points = np.random.random((n_points, 3)) * 10
feature1 = np.random.random(n_points)
feature2 = np.random.normal(size=n_points)

viewer = make_napari_viewer()
viewer.add_points(
random_points,
properties={"feature1": feature1, "feature2": feature2},
name="points1",
)
viewer.add_points(
random_points,
properties={"feature1": feature1, "feature2": feature2},
name="points2",
)

widget = FeaturesHistogramWidget(viewer)
viewer.window.add_dock_widget(widget)

# Check whether changing the selected key changes the plot
widget._set_axis_keys("feature1")
fig1 = deepcopy(widget.figure)

widget._set_axis_keys("feature2")
assert_figures_not_equal(widget.figure, fig1)

# check whether selecting a different layer produces the same plot
viewer.layers.selection.clear()
viewer.layers.selection.add(viewer.layers[1])
assert_figures_equal(widget.figure, fig1)


@pytest.mark.mpl_image_compare
def test_feature_histogram2(make_napari_viewer):
import numpy as np

np.random.seed(0)
n_points = 1000
random_points = np.random.random((n_points, 3)) * 10
feature1 = np.random.random(n_points)
feature2 = np.random.normal(size=n_points)

viewer = make_napari_viewer()
viewer.add_points(
random_points,
properties={"feature1": feature1, "feature2": feature2},
name="points1",
)
viewer.add_points(
random_points,
properties={"feature1": feature1, "feature2": feature2},
name="points2",
)

widget = FeaturesHistogramWidget(viewer)
viewer.window.add_dock_widget(widget)
widget._set_axis_keys("feature1")

fig = FeaturesHistogramWidget(viewer).figure
return deepcopy(fig)


def test_change_layer(make_napari_viewer, brain_data, astronaut_data):
viewer = make_napari_viewer()
widget = HistogramWidget(viewer)

viewer.add_image(brain_data[0], **brain_data[1])
viewer.add_image(astronaut_data[0], **astronaut_data[1])

# Select first layer
viewer.layers.selection.clear()
viewer.layers.selection.add(viewer.layers[0])
fig1 = deepcopy(widget.figure)

# Re-selecting first layer should produce identical plot
viewer.layers.selection.clear()
viewer.layers.selection.add(viewer.layers[0])
assert_figures_equal(widget.figure, fig1)

# Plotting the second layer should produce a different plot
viewer.layers.selection.clear()
viewer.layers.selection.add(viewer.layers[1])
assert_figures_not_equal(widget.figure, fig1)

0 comments on commit dc1f3bb

Please sign in to comment.