Refactor filtering module to take DataArrays as input (#209)
* Draft dataarray accessor

* Move dataarray accessor methods to `filtering`

* Add dataarray functions, test equality

* Add tests

* Add integration test

* Remove filters taking Dataset as input

* Reorganise filtering module

* Update filter and smooth examples

* Replace `window_length` with `window`

* Format assert string

* Remove old code accidentally reintroduced during rebase

* Update docstrings

* Add filtering methods to the `move` accessor

* Add example to docstring

* Remove obsolete and unused function imports

* Move util functions to reports.py and logging.py

* Apply suggestions from code review

Co-authored-by: Niko Sirmpilatze <[email protected]>

* Update docstrings

* Add missing docstring

* Add `move` accessor examples in docstrings

* Remove `position` check in kinematics wrapper

* Change `interpolate_over_time` to operate on num of observations

* Add test for different `max_gap` values

* Update `filter_and_interpolate.py` example

* Fix `filtering_wrapper` bug

* Update filter examples

* Use dictionary `update` in `smooth` example

* Move `logger` assignment to top of file

* Add `update` example to "getting started"

* Cover both dataarray and dataset in `test_log_to_attrs`

* Test that ``log`` contains the filtering method applied

* Use :py:meth: syntax for xarray.DataArray.squeeze() in examples

* Update `reports.py` docstrings

* Handle missing `individuals` and `keypoints` dims in NaN-reports

* Return str in `report_nan_values`

* Clean up examples

* Convert filtering multiple data variables tip to section

* Use `update()` in `filter_and_interpolate` example

---------

Co-authored-by: Niko Sirmpilatze <[email protected]>
lochhh and niksirbi authored Jul 17, 2024
1 parent b98cac0 commit b27d7a2
Showing 17 changed files with 941 additions and 542 deletions.
1 change: 1 addition & 0 deletions docs/source/api_index.rst
@@ -19,6 +19,7 @@ Information on specific functions, classes, and methods.
movement.analysis.kinematics
movement.utils.vector
movement.utils.logging
movement.utils.reports
movement.sample_data
movement.validators.files
movement.validators.datasets
6 changes: 5 additions & 1 deletion docs/source/conf.py
@@ -47,6 +47,7 @@
"sphinx_design",
"sphinx_gallery.gen_gallery",
"sphinx_sitemap",
"sphinx.ext.autosectionlabel",
]

# Configure the myst parser to enable cool markdown features
@@ -76,6 +77,9 @@
autosummary_generate = True
autodoc_default_flags = ["members", "inherited-members"]

# Prefix section labels with the document name
autosectionlabel_prefix_document = True

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
@@ -104,7 +108,7 @@
"binderhub_url": "https://mybinder.org",
"dependencies": ["environment.yml"],
},
'remove_config_comments': True,
"remove_config_comments": True,
# do not render config params set as # sphinx_gallery_config [= value]
}

7 changes: 7 additions & 0 deletions docs/source/getting_started/movement_dataset.md
@@ -150,3 +150,10 @@ Custom **attributes** can also be added to the dataset:
ds.attrs["my_custom_attribute"] = "my_custom_value"
# henceforth accessible as ds.my_custom_attribute
```

To update existing **data variables** in-place, e.g. `position`
and `velocity`:

```python
ds.update({"position": position, "velocity": velocity_filtered})
```
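The in-place semantics of `update()` can be sketched on a toy dataset (the variable names and values below are made up for illustration, not taken from a `movement` dataset):

```python
import numpy as np
import xarray as xr

# Build a minimal dataset with one data variable and a custom attribute
ds = xr.Dataset({"position": ("time", np.array([0.0, 1.0, 2.0]))})
ds.attrs["source"] = "toy"

# update() replaces the variable's values in-place;
# dataset-level attributes are left untouched
ds.update({"position": ds["position"] * 2})

print(ds["position"].values)  # [0. 2. 4.]
print(ds.attrs["source"])     # toy
```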
12 changes: 7 additions & 5 deletions examples/compute_kinematics.py
@@ -104,8 +104,9 @@

# %%
# We can also easily plot the components of the position vector against time
# using ``xarray``'s built-in plotting methods. We use ``squeeze()`` to
# remove the dimension of length 1 from the data (the keypoints dimension).
# using ``xarray``'s built-in plotting methods. We use
# :py:meth:`xarray.DataArray.squeeze` to
# remove the dimension of length 1 from the data (the ``keypoints`` dimension).
position.squeeze().plot.line(x="time", row="individuals", aspect=2, size=2.5)
plt.gcf().show()

@@ -130,7 +131,7 @@

# %%
# Notice that we could also compute the displacement (and all the other
# kinematic variables) using the kinematics module:
# kinematic variables) using the :py:mod:`movement.analysis.kinematics` module:

# %%
import movement.analysis.kinematics as kin
@@ -282,8 +283,9 @@

# %%
# We can plot the components of the velocity vector against time
# using ``xarray``'s built-in plotting methods. We use ``squeeze()`` to
# remove the dimension of length 1 from the data (the keypoints dimension).
# using ``xarray``'s built-in plotting methods. We use
# :py:meth:`xarray.DataArray.squeeze` to
# remove the dimension of length 1 from the data (the ``keypoints`` dimension).

velocity.squeeze().plot.line(x="time", row="individuals", aspect=2, size=2.5)
plt.gcf().show()
151 changes: 113 additions & 38 deletions examples/filter_and_interpolate.py
@@ -9,7 +9,6 @@
# Imports
# -------
from movement import sample_data
from movement.filtering import filter_by_confidence, interpolate_over_time

# %%
# Load a sample dataset
@@ -19,16 +18,21 @@
print(ds)

# %%
# We can see that this dataset contains the 2D pose tracks and confidence
# scores for a single wasp, generated with DeepLabCut. There are 2 keypoints:
# "head" and "stinger".
# We see that the dataset contains the 2D pose tracks and confidence scores
# for a single wasp, generated with DeepLabCut. The wasp is tracked at two
# keypoints: "head" and "stinger", in a video recorded at 40 fps that
# lasts approximately 27 seconds.

# %%
# Visualise the pose tracks
# -------------------------
# Since the data contains only a single wasp, we use
# :py:meth:`xarray.DataArray.squeeze` to remove
# the dimension of length 1 from the data (the ``individuals`` dimension).

position = ds.position.sel(individuals="individual_0")
position.plot.line(x="time", row="keypoints", hue="space", aspect=2, size=2.5)
ds.position.squeeze().plot.line(
x="time", row="keypoints", hue="space", aspect=2, size=2.5
)

# %%
# We can see that the pose tracks contain some implausible "jumps", such
@@ -46,70 +50,113 @@
# estimation frameworks, and their ranges can vary. Therefore,
# it's always a good idea to inspect the actual confidence values in the data.
#
# Let's first look at a histogram of the confidence scores.
ds.confidence.plot.hist(bins=20)
# Let's first look at a histogram of the confidence scores. As before, we use
# :py:meth:`xarray.DataArray.squeeze` to remove the ``individuals`` dimension
# from the data.

ds.confidence.squeeze().plot.hist(bins=20)

# %%
# Based on the above histogram, we can confirm that the confidence scores
# indeed range between 0 and 1, with most values closer to 1. Now let's see how
# they evolve over time.

confidence = ds.confidence.sel(individuals="individual_0")
confidence.plot.line(x="time", row="keypoints", aspect=2, size=2.5)
ds.confidence.squeeze().plot.line(
x="time", row="keypoints", aspect=2, size=2.5
)

# %%
# Encouragingly, some of the drops in confidence scores do seem to correspond
# to the implausible jumps and spikes we had seen in the position.
# We can use that to our advantage.


# %%
# Filter out points with low confidence
# -------------------------------------
# We can filter out points with confidence scores below a certain threshold.
# Here, we use ``threshold=0.6``. Points in the ``position`` data variable
# with confidence scores below this threshold will be converted to NaN.
# The ``print_report`` argument, which is True by default, reports the number
# of NaN values in the dataset before and after the filtering operation.
# Using the
# :py:meth:`filter_by_confidence()\
# <movement.move_accessor.MovementDataset.filtering_wrapper>`
# method of the ``move`` accessor,
# we can filter out points with confidence scores below a certain threshold.
# The default ``threshold=0.6`` will be used when ``threshold`` is not
# provided.
# This method will also report the number of NaN values in the dataset before
# and after the filtering operation by default (``print_report=True``).
# We will use :py:meth:`xarray.Dataset.update` to update ``ds`` in-place
# with the filtered ``position``.

ds.update({"position": ds.move.filter_by_confidence()})

ds_filtered = filter_by_confidence(ds, threshold=0.6, print_report=True)
# %%
# .. note::
# The ``move`` accessor :py:meth:`filter_by_confidence()\
# <movement.move_accessor.MovementDataset.filtering_wrapper>`
# method is a convenience method that applies
# :py:func:`movement.filtering.filter_by_confidence`,
# which takes ``position`` and ``confidence`` as arguments.
# The equivalent function call using the
# :py:mod:`movement.filtering` module would be:
#
# .. code-block:: python
#
# from movement.filtering import filter_by_confidence
#
# ds.update({"position": filter_by_confidence(position, confidence)})

# %%
# We can see that the filtering operation has introduced NaN values in the
# ``position`` data variable. Let's visualise the filtered data.

position_filtered = ds_filtered.position.sel(individuals="individual_0")
position_filtered.plot.line(
ds.position.squeeze().plot.line(
x="time", row="keypoints", hue="space", aspect=2, size=2.5
)

# %%
# Here we can see that gaps have appeared in the pose tracks, some of which
# are over the implausible jumps and spikes we had seen earlier. Moreover,
# most gaps seem to be brief, lasting < 1 second.
# Here we can see that gaps (consecutive NaNs) have appeared in the
# pose tracks, some of which are over the implausible jumps and spikes we had
# seen earlier. Moreover, most gaps seem to be brief,
# lasting < 1 second (or 40 frames).
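The claim that gaps are brief can be checked numerically. A minimal sketch with plain NumPy, using a toy 1-D array as a stand-in for one coordinate of ``position``, measures the length of each run of consecutive NaNs:

```python
import numpy as np

# Toy track with two NaN gaps: one of length 2, one of length 1
track = np.array([0.0, np.nan, np.nan, 3.0, np.nan, 5.0])

isnan = np.isnan(track).astype(int)
# Pad with zeros so gaps at the edges are also detected,
# then find the indices where the NaN mask switches on/off
edges = np.flatnonzero(np.diff(np.concatenate(([0], isnan, [0]))))
gap_lengths = edges[1::2] - edges[::2]

print(gap_lengths)  # [2 1]
```

Applied to a real track, `gap_lengths.max()` tells you whether a given `max_gap` (e.g. 40 frames) will bridge every gap.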

# %%
# Interpolate over missing values
# -------------------------------
# We can interpolate over the gaps we've introduced in the pose tracks.
# Here we use the default linear interpolation method and ``max_gap=1``,
# meaning that we will only interpolate over gaps of 1 second or shorter.
# Setting ``max_gap=None`` would interpolate over all gaps, regardless of
# their length, which should be used with caution as it can introduce
# Using the
# :py:meth:`interpolate_over_time()\
# <movement.move_accessor.MovementDataset.filtering_wrapper>`
# method of the ``move`` accessor,
# we can interpolate over the gaps we've introduced in the pose tracks.
# Here we use the default linear interpolation method (``method="linear"``)
# and interpolate over gaps of 40 frames or less (``max_gap=40``).
# The default ``max_gap=None`` would interpolate over all gaps, regardless of
# their length, but this should be used with caution as it can introduce
# spurious data. The ``print_report`` argument acts as described above.

ds_interpolated = interpolate_over_time(
ds_filtered, method="linear", max_gap=1, print_report=True
)
ds.update({"position": ds.move.interpolate_over_time(max_gap=40)})

# %%
# .. note::
# The ``move`` accessor :py:meth:`interpolate_over_time()\
# <movement.move_accessor.MovementDataset.filtering_wrapper>`
# is also a convenience method that applies
# :py:func:`movement.filtering.interpolate_over_time`
# to the ``position`` data variable.
# The equivalent function call using the
# :py:mod:`movement.filtering` module would be:
#
# .. code-block:: python
#
# from movement.filtering import interpolate_over_time
#
# ds.update({"position": interpolate_over_time(
# position_filtered, max_gap=40
# )})

# %%
# We see that all NaN values have disappeared, meaning that all gaps were
# indeed shorter than 1 second. Let's visualise the interpolated pose tracks
# indeed shorter than 40 frames.
# Let's visualise the interpolated pose tracks.

position_interpolated = ds_interpolated.position.sel(
individuals="individual_0"
)
position_interpolated.plot.line(
ds.position.squeeze().plot.line(
x="time", row="keypoints", hue="space", aspect=2, size=2.5
)

@@ -119,9 +166,37 @@
# So far, we've processed the pose tracks first by filtering out points with
# low confidence scores, and then by interpolating over missing values.
# The order of these operations and the parameters with which they were
# performed are saved in the ``log`` attribute of the dataset.
# performed are saved in the ``log`` attribute of the ``position`` data array.
# This is useful for keeping track of the processing steps that have been
# applied to the data.
# applied to the data. Let's inspect the log entries.

for log_entry in ds_interpolated.log:
for log_entry in ds.position.log:
print(log_entry)

# %%
# Filtering multiple data variables
# ---------------------------------
# All :py:mod:`movement.filtering` functions are available via the
# ``move`` accessor. These ``move`` accessor methods operate on the
# ``position`` data variable in the dataset ``ds`` by default.
# There is also an additional argument ``data_vars`` that allows us to
# specify which data variables in ``ds`` to filter.
# When multiple data variable names are specified in ``data_vars``,
# the method returns a dictionary with the data variable names as keys
# and the filtered DataArrays as values; otherwise, it returns a single
# filtered DataArray.
# This is useful when we want to apply the same filtering operation to
# multiple data variables in ``ds`` at the same time.
#
# For instance, to filter both ``position`` and ``velocity`` data variables
# in ``ds``, based on the confidence scores, we can specify
# ``data_vars=["position", "velocity"]`` in the method call.
# As the filtered data variables are returned as a dictionary, we can once
# again use :py:meth:`xarray.Dataset.update` to update ``ds`` in-place
# with the filtered data variables.

ds["velocity"] = ds.move.compute_velocity()
filtered_data_dict = ds.move.filter_by_confidence(
data_vars=["position", "velocity"]
)
ds.update(filtered_data_dict)
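The dictionary-update pattern above does not depend on `movement` itself; a generic sketch with plain xarray (toy data, and a simple `where()` mask standing in for confidence-based filtering) shows the same shape: transform several variables, collect the results in a dict, and merge them back in one call.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {
        "position": ("time", np.array([1.0, 2.0, 3.0])),
        "velocity": ("time", np.array([0.5, 0.5, 0.5])),
    }
)

# Apply the same masking operation to both variables,
# keyed by variable name as in the data_vars example above
filtered = {
    name: ds[name].where(ds[name] > 1.0)
    for name in ["position", "velocity"]
}
ds.update(filtered)  # position becomes [nan, 2., 3.]; velocity all-NaN
```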