Skip to content

Commit

Permalink
Use scipy.spatial.distance.cdist
Browse files Browse the repository at this point in the history
  • Loading branch information
lochhh committed Aug 29, 2024
1 parent 5813ffa commit b110485
Show file tree
Hide file tree
Showing 3 changed files with 257 additions and 73 deletions.
152 changes: 139 additions & 13 deletions movement/analysis/kinematics.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
from typing import Literal

import xarray as xr
from scipy.spatial.distance import cdist as _cdist

from movement.utils.logging import log_error
from movement.utils.vector import compute_norm


def compute_displacement(data: xr.DataArray) -> xr.DataArray:
Expand Down Expand Up @@ -56,7 +56,7 @@ def compute_velocity(data: xr.DataArray) -> xr.DataArray:
See Also
--------
:py:meth:`xarray.DataArray.differentiate` : The underlying method used.
:meth:`xarray.DataArray.differentiate` : The underlying method used.
"""
return _compute_approximate_time_derivative(data, order=1)
Expand All @@ -82,12 +82,124 @@ def compute_acceleration(data: xr.DataArray) -> xr.DataArray:
See Also
--------
:py:meth:`xarray.DataArray.differentiate` : The underlying method used.
:meth:`xarray.DataArray.differentiate` : The underlying method used.
"""
return _compute_approximate_time_derivative(data, order=2)


def cdist(
    a: xr.DataArray,
    b: xr.DataArray,
    dim: Literal["individuals", "keypoints"],
    metric: str | None = "euclidean",
    **kwargs,
) -> xr.DataArray:
    """Compute distance between each pair of the two collections of inputs.

    This function is a wrapper around :func:`scipy.spatial.distance.cdist`
    and computes the pairwise distances between each pair of inputs, where
    the inputs are either ``individuals`` or ``keypoints``. The distances
    are computed using the specified metric.

    Parameters
    ----------
    a : xarray.DataArray
        The first input data containing position information of a
        single individual or keypoint, with ``space`` as a dimension.
    b : xarray.DataArray
        The second input data containing position information of a
        single individual or keypoint, with ``space`` as a dimension.
    dim : str
        The dimension to compute the distances for. Must be either
        ``'individuals'`` or ``'keypoints'``.
    metric : str, optional
        The distance metric to use. Must be one of the options supported
        by :func:`scipy.spatial.distance.cdist`, i.e.
        ``'braycurtis'``, ``'canberra'``, ``'chebyshev'``, ``'cityblock'``,
        ``'correlation'``, ``'cosine'``, ``'dice'``, ``'euclidean'``,
        ``'hamming'``, ``'jaccard'``, ``'jensenshannon'``, ``'kulczynski1'``,
        ``'mahalanobis'``, ``'matching'``, ``'minkowski'``,
        ``'rogerstanimoto'``, ``'russellrao'``, ``'seuclidean'``,
        ``'sokalmichener'``, ``'sokalsneath'``, ``'sqeuclidean'``, ``'yule'``.
        Defaults to ``'euclidean'``; ``None`` is treated as ``'euclidean'``.
    **kwargs : dict
        Additional keyword arguments to pass to
        :func:`scipy.spatial.distance.cdist`.

    Returns
    -------
    xarray.DataArray
        An xarray DataArray containing the computed distances between
        each pair of inputs. The two new dimensions are named after the
        ``dim`` labels of ``a`` and ``b`` respectively, and are indexed
        by the labels of the other dimension (``keypoints`` when
        ``dim='individuals'``, ``individuals`` when ``dim='keypoints'``).

    Examples
    --------
    Compute the Euclidean distance (default) between ``ind1`` and
    ``ind2`` (i.e. interindividual distance for all keypoints)
    using the ``position`` data variable in the Dataset ``ds``:

    >>> pos1 = ds.position.sel(individuals="ind1")
    >>> pos2 = ds.position.sel(individuals="ind2")
    >>> ind_dists = cdist(pos1, pos2, dim="individuals")

    Compute the Euclidean distance (default) between ``key1`` and
    ``key2`` (i.e. interkeypoint distance for all individuals)
    using the ``position`` data variable in the Dataset ``ds``:

    >>> pos1 = ds.position.sel(keypoints="key1")
    >>> pos2 = ds.position.sel(keypoints="key2")
    >>> kp_dists = cdist(pos1, pos2, dim="keypoints")

    Obtain the distance between ``key1`` of ``ind1`` and
    ``key2`` of ``ind2`` from ``ind_dists``
    (i.e. interindividual distance, different keypoints):

    >>> dist_ind1key1_ind2key2 = ind_dists.sel(ind1="key1", ind2="key2")

    Equivalently, the same distance can be obtained from ``kp_dists``:

    >>> dist_ind1key1_ind2key2 = kp_dists.sel(key1="ind1", key2="ind2")

    Obtain the distance between ``key1`` and ``key2`` of ``ind1``
    (i.e. interkeypoint distance within the same individual):

    >>> dist_ind1key1_ind1key2 = kp_dists.sel(key1="ind1", key2="ind1")

    Obtain the distance between ``key1`` of ``ind1`` and ``ind2``
    (i.e. interindividual distance, same keypoint)

    >>> dist_ind1key1_ind2key1 = ind_dists.sel(ind1="key1", ind2="key1")

    See Also
    --------
    scipy.spatial.distance.cdist : The underlying function used.

    """
    # What happens if the input data has more dims than expected?
    # What happens if the input data and dim are conflicting?
    # NOTE(review): neither case is validated here; ``a`` and ``b`` are
    # also expected to carry *different* ``dim`` labels, since those labels
    # become the names of the two output dimensions -- confirm with callers.

    # The dimension that is NOT being compared supplies the rows handed to
    # SciPy (one row of ``space`` coordinates per label along ``core_dim``).
    core_dim = "individuals" if dim == "keypoints" else "keypoints"

    # The signature admits ``None`` but SciPy only accepts a metric name or
    # a callable, so fall back to the documented default.
    if metric is None:
        metric = "euclidean"

    # ``.item()` only succeeds for a single-element coordinate, i.e. after
    # the caller has selected exactly one individual/keypoint.
    elem1 = getattr(a, dim).item()
    elem2 = getattr(b, dim).item()

    result = xr.apply_ufunc(
        _cdist,
        a,
        b,
        kwargs={"metric": metric, **kwargs},
        input_core_dims=[[core_dim, "space"], [core_dim, "space"]],
        output_core_dims=[[elem1, elem2]],
        vectorize=True,
    )
    # Label each output axis with the ``core_dim`` labels of the input it
    # was derived from: rows come from ``a``, columns from ``b``.
    return result.assign_coords(
        {
            elem1: getattr(a, core_dim).values,
            elem2: getattr(b, core_dim).values,
        }
    )


def compute_interindividual_distances(
data: xr.DataArray, pairs: dict[str, str | list[str]] | None = None
) -> xr.DataArray | dict[str, xr.DataArray]:
Expand Down Expand Up @@ -174,7 +286,7 @@ def _compute_approximate_time_derivative(
) -> xr.DataArray:
"""Compute the derivative using numerical differentiation.
This function uses :py:meth:`xarray.DataArray.differentiate`,
This function uses :meth:`xarray.DataArray.differentiate`,
which differentiates the array with the second order
accurate central differences.
Expand Down Expand Up @@ -209,14 +321,15 @@ def _compute_pairwise_distances(
data: xr.DataArray,
dim: Literal["individuals", "keypoints"],
pairs: dict[str, str | list[str]] | None = None,
metric: str | None = "euclidean",
**kwargs,
) -> xr.DataArray | dict[str, xr.DataArray]:
"""Compute pairwise distances between ``individuals`` or ``keypoints``.
This function computes the distances between pairs of ``keypoints``
(i.e. interkeypoint distances) or pairs of ``individuals`` (i.e.
interindividual distances). The distances are computed as the norm
of the difference in position between pairs of ``keypoints`` or
``individuals`` at each time point.
interindividual distances). The distances are computed using the
specified metric.
Parameters
----------
Expand All @@ -233,6 +346,19 @@ def _compute_pairwise_distances(
representing a keypoint or individual to compute the distance with.
If not provided, defaults to ``None`` and all possible combinations
of pairs are computed.
metric : str, optional
The distance metric to use. Must be one of the options supported
by :func:`scipy.spatial.distance.cdist`, i.e.
``'braycurtis'``, ``'canberra'``, ``'chebyshev'``, ``'cityblock'``,
``'correlation'``, ``'cosine'``, ``'dice'``, ``'euclidean'``,
``'hamming'``, ``'jaccard'``, ``'jensenshannon'``, ``'kulczynski1'``,
``'mahalanobis'``, ``'matching'``, ``'minkowski'``,
``'rogerstanimoto'``, ``'russellrao'``, ``'seuclidean'``,
``'sokalmichener'``, ``'sokalsneath'``, ``'sqeuclidean'``, ``'yule'``.
Defaults to ``'euclidean'``.
**kwargs : dict
Additional keyword arguments to pass to
:func:`scipy.spatial.distance.cdist`.
Returns
-------
Expand All @@ -246,7 +372,7 @@ def _compute_pairwise_distances(
See Also
--------
movement.utils.vector.compute_norm : Compute the norm of a vector.
:func:`scipy.spatial.distance.cdist` : The underlying function used.
"""
if dim not in ["individuals", "keypoints"]:
Expand All @@ -256,6 +382,7 @@ def _compute_pairwise_distances(
f"but got {dim}.",
)
pairwise_distances = {}

# Compute all possible pair combinations if not provided
if pairs is None:
paired_elements = list(
Expand All @@ -270,12 +397,11 @@ def _compute_pairwise_distances(
)
]
for elem1, elem2 in paired_elements:
distance = compute_norm(
data.sel({dim: elem1}) - data.sel({dim: elem2})
input1 = data.sel({dim: elem1})
input2 = data.sel({dim: elem2})
pairwise_distances[f"dist_{elem1}_{elem2}"] = cdist(
input1, input2, dim=dim, metric=metric, **kwargs
)
if dim in distance.coords:
distance = distance.drop_vars(dim)
pairwise_distances[f"dist_{elem1}_{elem2}"] = distance
# Return DataArray if result only has one key
if len(pairwise_distances) == 1:
return next(iter(pairwise_distances.values()))
Expand Down
46 changes: 31 additions & 15 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,19 +470,37 @@ def kinematic_property(request):


@pytest.fixture
def pairwise_distances_dataset(valid_poses_dataset):
"""Return a dataset in which the positions of either ``ind2`` or ``key2``
is offset by 1 unit (for testing pairwise distances computation).
def pairwise_distances_dataset():
"""Return a minimal poses dataset with 3 individuals
and 3 keypoints for pairwise distances computation.
"""

def _pairwise_distances_dataset(dim):
elem_name = f"{dim[:3]}2"
valid_poses_dataset.position.loc[{dim: elem_name}] = (
valid_poses_dataset.position.sel({dim: elem_name}) + 1
)
return valid_poses_dataset

return _pairwise_distances_dataset
time = np.arange(2)
space = ["x", "y"]
individuals = ["ind1", "ind2", "ind3"]
keypoints = ["key1", "key2", "key3"]
data = np.array(
[
[
[[1, 1], [0, 0], [1, 0]],
[[1, 0], [1, 1], [0, 0]],
[[0, 0], [1, 0], [1, 1]],
],
[
[[3, 6], [1, 4], [0, 4]],
[[0, 4], [3, 6], [1, 4]],
[[1, 4], [0, 4], [3, 6]],
],
]
)
return xr.Dataset(
data_vars={
"position": xr.DataArray(
data,
coords=[time, individuals, keypoints, space],
dims=["time", "individuals", "keypoints", "space"],
)
}
)


# ---------------- VIA tracks CSV file fixtures ----------------------------
Expand Down Expand Up @@ -732,6 +750,7 @@ def track_ids_not_unique_per_frame(
return file_path


# ----------------- Helpers fixture -----------------
class Helpers:
"""Generic helper methods for ``movement`` test modules."""

Expand All @@ -746,9 +765,6 @@ def count_consecutive_nans(da):
return (da.isnull().astype(int).diff("time") == 1).sum().item()


# ----------------- Helper fixture -----------------


@pytest.fixture
def helpers():
"""Return an instance of the ``Helpers`` class."""
Expand Down
Loading

0 comments on commit b110485

Please sign in to comment.