Skip to content

Commit

Permalink
Merge pull request #443 from MilagrosMarin/datajoint_pipeline
Browse files Browse the repository at this point in the history
Enhancements and refactoring: Ruff checks resolved
  • Loading branch information
ttngu207 authored Nov 20, 2024
2 parents ef7b816 + 886b6e0 commit 9be1f8e
Show file tree
Hide file tree
Showing 64 changed files with 907 additions and 416 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,7 @@ dj_local_conf.json
log*.txt
scratch/
scratch*.py
**/*.nfs*
**/*.nfs*

# Test
.coverage
1 change: 0 additions & 1 deletion aeon/README.md
Original file line number Diff line number Diff line change
@@ -1 +0,0 @@
#
6 changes: 5 additions & 1 deletion aeon/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Top-level package for aeon."""

from importlib.metadata import PackageNotFoundError, version

try:
Expand All @@ -10,4 +12,6 @@
del version, PackageNotFoundError

# Set functions available directly under the 'aeon' top-level namespace
from aeon.io.api import load as load # noqa: PLC0414
from aeon.io.api import load

__all__ = ["load"]
1 change: 1 addition & 0 deletions aeon/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Utilities for analyzing data."""
38 changes: 23 additions & 15 deletions aeon/analysis/block_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,19 @@
patch_markers_linestyles = ["solid", "dash", "dot", "dashdot", "longdashdot"]


def gen_hex_grad(hex_col, vals, min_l=0.3):
def gen_hex_grad(hex_col, vals, min_lightness=0.3):
"""Generates an array of hex color values based on a gradient defined by unit-normalized values."""
# Convert hex to rgb to hls
h, l, s = rgb_to_hls(*[int(hex_col.lstrip("#")[i : i + 2], 16) / 255 for i in (0, 2, 4)]) # noqa: E741
hue, lightness, saturation = rgb_to_hls(
*[int(hex_col.lstrip("#")[i : i + 2], 16) / 255 for i in (0, 2, 4)]
)
grad = np.empty(shape=(len(vals),), dtype="<U10") # init grad
for i, val in enumerate(vals):
cur_l = (l * val) + (min_l * (1 - val)) # get cur lightness relative to `hex_col`
cur_l = max(min(cur_l, l), min_l) # set min, max bounds
cur_rgb_col = hls_to_rgb(h, cur_l, s) # convert to rgb
cur_lightness = (lightness * val) + (
min_lightness * (1 - val)
) # get cur lightness relative to `hex_col`
cur_lightness = max(min(cur_lightness, lightness), min_lightness) # set min, max bounds
cur_rgb_col = hls_to_rgb(hue, cur_lightness, saturation) # convert to rgb
cur_hex_col = "#{:02x}{:02x}{:02x}".format(
*tuple(int(c * 255) for c in cur_rgb_col)
) # convert to hex
Expand All @@ -55,19 +59,23 @@ def conv2d(arr, kernel):

def gen_subject_colors_dict(subject_names):
"""Generates a dictionary of subject colors based on a list of subjects."""
return {s: c for s, c in zip(subject_names, subject_colors)}
return dict(zip(subject_names, subject_colors, strict=False))


def gen_patch_style_dict(patch_names):
"""Based on a list of patches, generates a dictionary of:
- patch_colors_dict: patch name to color
- patch_markers_dict: patch name to marker
- patch_symbols_dict: patch name to symbol
- patch_linestyles_dict: patch name to linestyle
"""Generates a dictionary of patch styles given a list of patch_names.
The dictionary contains dictionaries which map patch names to their respective styles.
Below are the keys for each nested dictionary and their contents:
- colors: patch name to color
- markers: patch name to marker
- symbols: patch name to symbol
- linestyles: patch name to linestyle
"""
return {
"colors": {p: c for p, c in zip(patch_names, patch_colors)},
"markers": {p: m for p, m in zip(patch_names, patch_markers)},
"symbols": {p: s for p, s in zip(patch_names, patch_markers_symbols)},
"linestyles": {p: ls for p, ls in zip(patch_names, patch_markers_linestyles)},
"colors": dict(zip(patch_names, patch_colors, strict=False)),
"markers": dict(zip(patch_names, patch_markers, strict=False)),
"symbols": dict(zip(patch_names, patch_markers_symbols, strict=False)),
"linestyles": dict(zip(patch_names, patch_markers_linestyles, strict=False)),
}
30 changes: 14 additions & 16 deletions aeon/analysis/movies.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Helper functions for processing video data."""

import math

import cv2
Expand All @@ -13,9 +15,8 @@ def gridframes(frames, width, height, shape: None | int | tuple[int, int] = None
:param list frames: A list of frames to include in the grid layout.
:param int width: The width of the output grid image, in pixels.
:param int height: The height of the output grid image, in pixels.
:param optional shape:
Either the number of frames to include, or the number of rows and columns
in the output grid image layout.
:param optional shape: Either the number of frames to include,
or the number of rows and columns in the output grid image layout.
:return: A new image containing the arrangement of the frames in a grid.
"""
if shape is None:
Expand Down Expand Up @@ -67,13 +68,12 @@ def groupframes(frames, n, fun):
def triggerclip(data, events, before=None, after=None):
"""Split video data around the specified sequence of event timestamps.
:param DataFrame data:
A pandas DataFrame where each row specifies video acquisition path and frame number.
:param DataFrame data: A pandas DataFrame where each row specifies
video acquisition path and frame number.
:param iterable events: A sequence of timestamps to extract.
:param Timedelta before: The left offset from each timestamp used to clip the data.
:param Timedelta after: The right offset from each timestamp used to clip the data.
:return:
A pandas DataFrame containing the frames, clip and sequence numbers for each event timestamp.
:return: A pandas DataFrame containing the frames, clip and sequence numbers for each event timestamp.
"""
if before is None:
before = pd.Timedelta(0)
Expand All @@ -100,9 +100,8 @@ def triggerclip(data, events, before=None, after=None):
def collatemovie(clipdata, fun):
"""Collates a set of video clips into a single movie using the specified aggregation function.
:param DataFrame clipdata:
A pandas DataFrame where each row specifies video path, frame number, clip and sequence number.
This DataFrame can be obtained from the output of the triggerclip function.
:param DataFrame clipdata: A pandas DataFrame where each row specifies video path, frame number,
clip and sequence number. This DataFrame can be obtained from the output of the triggerclip function.
:param callable fun: The aggregation function used to process the frames in each clip.
:return: The sequence of processed frames representing the collated movie.
"""
Expand All @@ -114,14 +113,13 @@ def collatemovie(clipdata, fun):
def gridmovie(clipdata, width, height, shape=None):
"""Collates a set of video clips into a grid movie with the specified pixel dimensions and grid layout.
:param DataFrame clipdata:
A pandas DataFrame where each row specifies video path, frame number, clip and sequence number.
This DataFrame can be obtained from the output of the triggerclip function.
:param DataFrame clipdata: A pandas DataFrame where each row specifies video path, frame number,
clip and sequence number.
This DataFrame can be obtained from the output of the triggerclip function.
:param int width: The width of the output grid movie, in pixels.
:param int height: The height of the output grid movie, in pixels.
:param optional shape:
Either the number of frames to include, or the number of rows and columns
in the output grid movie layout.
:param optional shape: Either the number of frames to include,
or the number of rows and columns in the output grid movie layout.
:return: The sequence of processed frames representing the collated grid movie.
"""
return collatemovie(clipdata, lambda g: gridframes(g, width, height, shape))
10 changes: 6 additions & 4 deletions aeon/analysis/plotting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Helper functions for plotting data."""

import math

import matplotlib.pyplot as plt
Expand Down Expand Up @@ -71,7 +73,7 @@ def rateplot(
:param datetime, optional end: The right bound of the time range for the continuous rate.
:param datetime, optional smooth: The size of the smoothing kernel applied to the rate output.
:param DateOffset, Timedelta or str, optional smooth:
The size of the smoothing kernel applied to the continuous rate output.
The size of the smoothing kernel applied to the continuous rate output.
:param bool, optional center: Specifies whether to center the convolution kernels.
:param Axes, optional ax: The Axes on which to draw the rate plot and raster.
"""
Expand Down Expand Up @@ -117,11 +119,11 @@ def colorline(
:param array-like x, y: The horizontal / vertical coordinates of the data points.
:param array-like, optional z:
The dynamic variable used to color each data point by indexing the color map.
The dynamic variable used to color each data point by indexing the color map.
:param str or ~matplotlib.colors.Colormap, optional cmap:
The colormap used to map normalized data values to RGBA colors.
The colormap used to map normalized data values to RGBA colors.
:param matplotlib.colors.Normalize, optional norm:
The normalizing object used to scale data to the range [0, 1] for indexing the color map.
The normalizing object used to scale data to the range [0, 1] for indexing the color map.
:param Axes, optional ax: The Axes on which to draw the colored line.
"""
if ax is None:
Expand Down
18 changes: 15 additions & 3 deletions aeon/analysis/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Helper functions for data analysis and visualization."""

import numpy as np
import pandas as pd

Expand Down Expand Up @@ -60,7 +62,6 @@ def visits(data, onset="Enter", offset="Exit"):
missing_data = data.duplicated(subset=time_offset, keep="last")
if missing_data.any():
data.loc[missing_data, ["duration"] + [name for name in data.columns if rsuffix in name]] = pd.NA

# rename columns and sort data
data.rename({time_onset: lonset, id_onset: "id", time_offset: loffset}, axis=1, inplace=True)
data = data[["id"] + [name for name in data.columns if "_" in name] + [lonset, loffset, "duration"]]
Expand All @@ -83,7 +84,7 @@ def rate(events, window, frequency, weight=1, start=None, end=None, smooth=None,
:param datetime, optional end: The right bound of the time range for the continuous rate.
:param datetime, optional smooth: The size of the smoothing kernel applied to the rate output.
:param DateOffset, Timedelta or str, optional smooth:
The size of the smoothing kernel applied to the continuous rate output.
The size of the smoothing kernel applied to the continuous rate output.
:param bool, optional center: Specifies whether to center the convolution kernels.
:return: A Series containing the continuous event rate over time.
"""
Expand All @@ -101,7 +102,18 @@ def rate(events, window, frequency, weight=1, start=None, end=None, smooth=None,
def get_events_rates(
events, window_len_sec, frequency, unit_len_sec=60, start=None, end=None, smooth=None, center=False
):
"""Computes the event rate from a sequence of events over a specified window."""
"""Computes the event rate from a sequence of events over a specified window.
:param Series events: The discrete sequence of events, with timestamps in seconds as index.
:param int window_len_sec: The length of the window over which the event rate is estimated.
:param DateOffset, Timedelta or str frequency: The sampling frequency for the continuous rate.
:param int, optional unit_len_sec: The length of one sample point. Default is 60 seconds.
:param datetime, optional start: The left bound of the time range for the continuous rate.
:param datetime, optional end: The right bound of the time range for the continuous rate.
:param int, optional smooth: The size of the smoothing kernel applied to the continuous rate output.
:param bool, optional center: Specifies whether to center the convolution kernels.
:return: A Series containing the continuous event rate over time.
"""
# events is an array with the time (in seconds) of event occurence
# window_len_sec is the size of the window over which the event rate is estimated
# unit_len_sec is the length of one sample point
Expand Down
15 changes: 12 additions & 3 deletions aeon/dj_pipeline/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""DataJoint pipeline for Aeon."""

import hashlib
import logging
import os
import uuid

import datajoint as dj

logger = dj.logger

_default_database_prefix = os.getenv("DJ_DB_PREFIX") or "aeon_"
_default_repository_config = {"ceph_aeon": "/ceph/aeon"}

Expand Down Expand Up @@ -53,7 +57,12 @@ def fetch_stream(query, drop_pk=True, round_microseconds=True):
df.rename(columns={"timestamps": "time"}, inplace=True)
df.set_index("time", inplace=True)
df.sort_index(inplace=True)
df = df.convert_dtypes(convert_string=False, convert_integer=False, convert_boolean=False, convert_floating=False)
df = df.convert_dtypes(
convert_string=False,
convert_integer=False,
convert_boolean=False,
convert_floating=False
)
if not df.empty and round_microseconds:
logging.warning("Rounding timestamps to microseconds is now enabled by default."
" To disable, set round_microseconds=False.")
Expand All @@ -68,5 +77,5 @@ def fetch_stream(query, drop_pk=True, round_microseconds=True):
from .utils import streams_maker

streams = dj.VirtualModule("streams", streams_maker.schema_name)
except:
pass
except Exception as e:
logger.debug(f"Could not import streams module: {e}")
Loading

0 comments on commit 9be1f8e

Please sign in to comment.