diff --git a/.gitignore b/.gitignore index 2d1a3823..2599b3ba 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,7 @@ dj_local_conf.json log*.txt scratch/ scratch*.py -**/*.nfs* \ No newline at end of file +**/*.nfs* + +# Test +.coverage \ No newline at end of file diff --git a/aeon/README.md b/aeon/README.md index 792d6005..e69de29b 100644 --- a/aeon/README.md +++ b/aeon/README.md @@ -1 +0,0 @@ -# diff --git a/aeon/__init__.py b/aeon/__init__.py index 2a691c53..f5cb7fe7 100644 --- a/aeon/__init__.py +++ b/aeon/__init__.py @@ -1,3 +1,5 @@ +"""Top-level package for aeon.""" + from importlib.metadata import PackageNotFoundError, version try: @@ -10,4 +12,6 @@ del version, PackageNotFoundError # Set functions available directly under the 'aeon' top-level namespace -from aeon.io.api import load as load # noqa: PLC0414 +from aeon.io.api import load + +__all__ = ["load"] diff --git a/aeon/analysis/__init__.py b/aeon/analysis/__init__.py index e69de29b..52f0038f 100644 --- a/aeon/analysis/__init__.py +++ b/aeon/analysis/__init__.py @@ -0,0 +1 @@ +"""Utilities for analyzing data.""" diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index 87c548e2..378c4713 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -26,15 +26,19 @@ patch_markers_linestyles = ["solid", "dash", "dot", "dashdot", "longdashdot"] -def gen_hex_grad(hex_col, vals, min_l=0.3): +def gen_hex_grad(hex_col, vals, min_lightness=0.3): """Generates an array of hex color values based on a gradient defined by unit-normalized values.""" # Convert hex to rgb to hls - h, l, s = rgb_to_hls(*[int(hex_col.lstrip("#")[i : i + 2], 16) / 255 for i in (0, 2, 4)]) # noqa: E741 + hue, lightness, saturation = rgb_to_hls( + *[int(hex_col.lstrip("#")[i : i + 2], 16) / 255 for i in (0, 2, 4)] + ) grad = np.empty(shape=(len(vals),), dtype=" master --- bonsai_workflow: varchar(36) - commit: varchar(64) # e.g. git commit hash of aeon_experiment used to generated this particular epoch + commit: varchar(64) # e.g. git commit hash of aeon_experiment used to generate this epoch source='': varchar(16) # e.g. 
aeon_experiment or aeon_acquisition (or others) metadata: longblob metadata_file_path: varchar(255) # path of the file, relative to the experiment repository @@ -318,6 +327,7 @@ class ActiveRegion(dj.Part): """ def make(self, key): + """Ingest metadata into EpochConfig.""" from aeon.dj_pipeline.utils import streams_maker from aeon.dj_pipeline.utils.load_metadata import ( extract_epoch_config, @@ -387,6 +397,7 @@ class File(dj.Part): @classmethod def ingest_chunks(cls, experiment_name): + """Ingest chunks for the specified ``experiment_name``.""" device_name = _ref_device_mapping.get(experiment_name, "CameraTop") all_chunks, raw_data_dirs = _get_all_chunks(experiment_name, device_name) @@ -534,6 +545,7 @@ class SubjectWeight(dj.Part): """ def make(self, key): + """Ingest environment data into Environment table.""" chunk_start, chunk_end = (Chunk & key).fetch1("chunk_start", "chunk_end") # Populate the part table @@ -597,6 +609,7 @@ class Name(dj.Part): """ def make(self, key): + """Ingest active configuration data into EnvironmentActiveConfiguration table.""" chunk_start, chunk_end = (Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = Experiment.get_data_directories(key) devices_schema = getattr( @@ -626,6 +639,7 @@ def make(self, key): def _get_all_chunks(experiment_name, device_name): + """Get all chunks for the specified ``experiment_name`` and ``device_name``.""" directory_types = ["quality-control", "raw"] raw_data_dirs = { dir_type: Experiment.get_data_directory( @@ -647,6 +661,7 @@ def _get_all_chunks(experiment_name, device_name): def _match_experiment_directory(experiment_name, path, directories): + """Match the path to the experiment directory.""" for k, v in directories.items(): raw_data_dir = v if pathlib.Path(raw_data_dir) in list(path.parents): diff --git a/aeon/dj_pipeline/analysis/__init__.py b/aeon/dj_pipeline/analysis/__init__.py index e69de29b..52f0038f 100644 --- a/aeon/dj_pipeline/analysis/__init__.py +++ b/aeon/dj_pipeline/analysis/__init__.py @@ -0,0 +1 @@ +"""Utilities for analyzing data.""" diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 2bc73823..2a233d8b 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -1,7 +1,9 @@ +"""Module for block analysis.""" + import itertools import json from collections import defaultdict -from datetime import datetime +from datetime import UTC, datetime import datajoint as dj import numpy as np @@ -19,17 +21,8 @@ gen_subject_colors_dict, subject_colors, ) -from aeon.dj_pipeline import ( - acquisition, - fetch_stream, - get_schema_name, - streams, - tracking, -) -from aeon.dj_pipeline.analysis.visit import ( - filter_out_maintenance_periods, - get_maintenance_periods, -) +from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, tracking +from aeon.dj_pipeline.analysis.visit import filter_out_maintenance_periods, get_maintenance_periods from aeon.io import api as io_api schema = dj.schema(get_schema_name("block_analysis")) @@ -128,8 +121,10 @@ class BlockAnalysis(dj.Computed): @property def key_source(self): - # Ensure that the chunk ingestion has caught up with this block before processing - # (there exists a chunk that ends after the block end time) + """Ensures chunk ingestion is complete before processing the block. + + This is done by checking that there exists a chunk that ends after the block end time. 
+ """ ks = Block.aggr(acquisition.Chunk, latest_chunk_end="MAX(chunk_end)") ks = ks * Block & "latest_chunk_end >= block_end" & "block_end IS NOT NULL" return ks @@ -164,7 +159,12 @@ class Subject(dj.Part): """ def make(self, key): - """Restrict, fetch and aggregate data from different streams to produce intermediate data products at a per-block level (for different patches and different subjects). + """Collates data from various streams to produce per-block intermediate data products. + + The intermediate data products consist of data for each ``Patch`` + and each ``Subject`` within the ``Block``. + The steps to restrict, fetch, and aggregate data from various streams are as follows: + 1. Query data for all chunks within the block. 2. Fetch streams, filter by maintenance period. 3. Fetch subject position data (SLEAP). @@ -186,7 +186,9 @@ def make(self, key): for streams_table in streams_tables: if len(streams_table & chunk_keys) < len(streams_table.key_source & chunk_keys): raise ValueError( - f"BlockAnalysis Not Ready - {streams_table.__name__} not yet fully ingested for block: {key}. Skipping (to retry later)..." + f"BlockAnalysis Not Ready - {streams_table.__name__}" + f"not yet fully ingested for block: {key}." + f"Skipping (to retry later)..." ) # Check if SLEAPTracking is ready, if not, see if BlobPosition can be used instead @@ -194,7 +196,9 @@ def make(self, key): if len(tracking.SLEAPTracking & chunk_keys) < len(tracking.SLEAPTracking.key_source & chunk_keys): if len(tracking.BlobPosition & chunk_keys) < len(tracking.BlobPosition.key_source & chunk_keys): raise ValueError( - f"BlockAnalysis Not Ready - SLEAPTracking (and BlobPosition) not yet fully ingested for block: {key}. Skipping (to retry later)..." + "BlockAnalysis Not Ready - " + f"SLEAPTracking (and BlobPosition) not yet fully ingested for block: {key}. " + "Skipping (to retry later)..." ) else: use_blob_position = True @@ -215,7 +219,7 @@ def make(self, key): patch_keys, patch_names = patch_query.fetch("KEY", "underground_feeder_name") block_patch_entries = [] - for patch_key, patch_name in zip(patch_keys, patch_names): + for patch_key, patch_name in zip(patch_keys, patch_names, strict=False): # pellet delivery and patch threshold data depletion_state_df = fetch_stream( streams.UndergroundFeederDepletionState & patch_key & chunk_restriction @@ -257,12 +261,17 @@ def make(self, key): if not depletion_state_df.empty: if len(depletion_state_df.rate.unique()) > 1: - # multiple patch rates per block is unexpected, log a note and pick the first rate to move forward + # multiple patch rates per block is unexpected + # log a note and pick the first rate to move forward AnalysisNote.insert1( { - "note_timestamp": datetime.utcnow(), + "note_timestamp": datetime.now(UTC), "note_type": "Multiple patch rates", - "note": f"Found multiple patch rates for block {key} - patch: {patch_name} - rates: {depletion_state_df.rate.unique()}", + "note": ( + f"Found multiple patch rates for block {key} " + f"- patch: {patch_name} " + f"- rates: {depletion_state_df.rate.unique()}" + ), } ) @@ -309,7 +318,8 @@ def make(self, key): if use_blob_position and len(subject_names) > 1: raise ValueError( - f"Without SLEAPTracking, BlobPosition can only handle single-subject block. Found {len(subject_names)} subjects." + f"Without SLEAPTracking, BlobPosition can only handle a single-subject block. " + f"Found {len(subject_names)} subjects." 
) block_subject_entries = [] @@ -329,7 +339,9 @@ def make(self, key): pos_df = fetch_stream(pos_query)[block_start:block_end] pos_df["likelihood"] = np.nan # keep only rows with area between 0 and 1000 - likely artifacts otherwise - pos_df = pos_df[(pos_df.area > 0) & (pos_df.area < 1000)] + MIN_AREA = 0 + MAX_AREA = 1000 + pos_df = pos_df[(pos_df.area > MIN_AREA) & (pos_df.area < MAX_AREA)] else: pos_query = ( streams.SpinnakerVideoSource @@ -408,7 +420,7 @@ class Patch(dj.Part): -> BlockAnalysis.Patch -> BlockAnalysis.Subject --- - in_patch_timestamps: longblob # timestamps in which a particular subject is spending time at a particular patch + in_patch_timestamps: longblob # timestamps when a subject is at a specific patch in_patch_time: float # total seconds spent in this patch for this block pellet_count: int pellet_timestamps: longblob @@ -434,6 +446,7 @@ class Preference(dj.Part): key_source = BlockAnalysis & BlockAnalysis.Patch & BlockAnalysis.Subject def make(self, key): + """Compute preference scores for each subject at each patch within a block.""" block_patches = (BlockAnalysis.Patch & key).fetch(as_dict=True) block_subjects = (BlockAnalysis.Subject & key).fetch(as_dict=True) subject_names = [s["subject_name"] for s in block_subjects] @@ -460,15 +473,19 @@ def make(self, key): # Ensure wheel_timestamps are of the same length across all patches wheel_lens = [len(p["wheel_timestamps"]) for p in block_patches] + MAX_WHEEL_DIFF = 10 + if len(set(wheel_lens)) > 1: max_diff = max(wheel_lens) - min(wheel_lens) - if max_diff > 10: + if max_diff > MAX_WHEEL_DIFF: # if diff is more than 10 samples, raise error, this is unexpected, some patches crash? - raise ValueError(f"Wheel data lengths are not consistent across patches ({max_diff} samples diff)") + raise ValueError( + f"Inconsistent wheel data lengths across patches ({max_diff} samples diff)" + ) + min_wheel_len = min(wheel_lens) for p in block_patches: - p["wheel_timestamps"] = p["wheel_timestamps"][: min(wheel_lens)] - p["wheel_cumsum_distance_travelled"] = p["wheel_cumsum_distance_travelled"][: min(wheel_lens)] - + p["wheel_timestamps"] = p["wheel_timestamps"][:min_wheel_len] + p["wheel_cumsum_distance_travelled"] = p["wheel_cumsum_distance_travelled"][:min_wheel_len] self.insert1(key) in_patch_radius = 130 # pixels @@ -606,11 +623,14 @@ def make(self, key): all_cum_time = np.sum( [all_subj_patch_pref_dict[p][subject_name]["cum_time"][-1] for p in patch_names] ) + + CUM_PREF_DIST_MIN = 1e-3 + for patch_name in patch_names: cum_pref_dist = ( all_subj_patch_pref_dict[patch_name][subject_name]["cum_dist"] / all_cum_dist ) - cum_pref_dist = np.where(cum_pref_dist < 1e-3, 0, cum_pref_dist) + cum_pref_dist = np.where(cum_pref_dist < CUM_PREF_DIST_MIN, 0, cum_pref_dist) all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_dist"] = cum_pref_dist cum_pref_time = ( @@ -679,6 +699,7 @@ class BlockPatchPlots(dj.Computed): """ def make(self, key): + """Compute and plot various block-level statistics and visualizations.""" # Define subject colors and patch styling for plotting exp_subject_names = (acquisition.Experiment.Subject & key).fetch("subject", order_by="subject") if not len(exp_subject_names): @@ -1317,6 +1338,7 @@ class BlockSubjectPositionPlots(dj.Computed): """ def make(self, key): + """Compute and plot various block-level statistics and visualizations.""" # Get some block info block_start, block_end = (Block & key).fetch1("block_start", "block_end") chunk_restriction = acquisition.create_chunk_restriction( @@ -1544,6 +1566,7 
@@ def make(self, key): # ---- Foraging Bout Analysis ---- + @schema class BlockForaging(dj.Computed): definition = """ @@ -1564,6 +1587,7 @@ class Bout(dj.Part): """ def make(self, key): + """Compute and store foraging bouts for each subject in the block.""" foraging_bout_df = get_foraging_bouts(key) foraging_bout_df.rename( columns={ @@ -1595,24 +1619,32 @@ class AnalysisNote(dj.Manual): # ---- Helper Functions ---- + def get_threshold_associated_pellets(patch_key, start, end): - """Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time. + """Gets pellet delivery timestamps for each patch threshold update within the specified time range. 1. Get all patch state update timestamps (DepletionState): let's call these events "A" - - Remove all events within 1 second of each other - - Remove all events without threshold value (NaN) + + - Remove all events within 1 second of each other + - Remove all events without threshold value (NaN) 2. Get all pellet delivery timestamps (DeliverPellet): let's call these events "B" - - Find matching beam break timestamps within 1.2s after each pellet delivery + + - Find matching beam break timestamps within 1.2s after each pellet delivery 3. For each event "A", find the nearest event "B" within 100ms before or after the event "A" - - These are the pellet delivery events "B" associated with the previous threshold update event "A" - 4. Shift back the pellet delivery timestamps by 1 to match the pellet delivery with the previous threshold update + + - These are the pellet delivery events "B" associated with the previous threshold update event "A" + 4. Shift back the pellet delivery timestamps by 1 to match the pellet delivery with the + previous threshold update 5. 
Remove all threshold updates events "A" without a corresponding pellet delivery event "B" + Args: patch_key (dict): primary key for the patch start (datetime): start timestamp end (datetime): end timestamp + Returns: pd.DataFrame: DataFrame with the following columns: + - threshold_update_timestamp (index) - pellet_timestamp - beam_break_timestamp @@ -1646,19 +1678,22 @@ def get_threshold_associated_pellets(patch_key, start, end): ) # Step 2 - Remove invalid rows (back-to-back events) - # pellet delivery trigger - time difference is less than 1.2 seconds - invalid_rows = delivered_pellet_df.index.to_series().diff().dt.total_seconds() < 1.2 + BTB_MIN_TIME_DIFF = 1.2 # pellet delivery trigger - time diff is less than 1.2 seconds + BB_MIN_TIME_DIFF = 1.0 # beambreak - time difference is less than 1 second + PT_MIN_TIME_DIFF = 1.0 # patch threshold - time difference is less than 1 second + + invalid_rows = delivered_pellet_df.index.to_series().diff().dt.total_seconds() < BTB_MIN_TIME_DIFF delivered_pellet_df = delivered_pellet_df[~invalid_rows] # exclude manual deliveries delivered_pellet_df = delivered_pellet_df.loc[ delivered_pellet_df.index.difference(manual_delivery_df.index) ] - # beambreak - time difference is less than 1 seconds - invalid_rows = beambreak_df.index.to_series().diff().dt.total_seconds() < 1 + + invalid_rows = beambreak_df.index.to_series().diff().dt.total_seconds() < BB_MIN_TIME_DIFF beambreak_df = beambreak_df[~invalid_rows] - # patch threshold - time difference is less than 1 seconds + depletion_state_df = depletion_state_df.dropna(subset=["threshold"]) - invalid_rows = depletion_state_df.index.to_series().diff().dt.total_seconds() < 1 + invalid_rows = depletion_state_df.index.to_series().diff().dt.total_seconds() < PT_MIN_TIME_DIFF depletion_state_df = depletion_state_df[~invalid_rows] # Return empty if no data @@ -1675,7 +1710,7 @@ def get_threshold_associated_pellets(patch_key, start, end): beambreak_df.reset_index().rename(columns={"time": "beam_break_timestamp"}), left_on="time", right_on="beam_break_timestamp", - tolerance=pd.Timedelta("1.2s"), + tolerance=pd.Timedelta(f"{BTB_MIN_TIME_DIFF}s"), direction="forward", ) .set_index("time") @@ -1768,7 +1803,8 @@ def get_foraging_bouts( spun_indices = np.where(diffs > wheel_spun_thresh) patch_spun[spun_indices[1]] = patch_names[spun_indices[0]] patch_spun_df = pd.DataFrame( - {"cum_wheel_dist": comb_cum_wheel_dist, "patch_spun": patch_spun}, index=wheel_ts + {"cum_wheel_dist": comb_cum_wheel_dist, "patch_spun": patch_spun}, + index=wheel_ts, ) wheel_s_r = pd.Timedelta(wheel_ts[1] - wheel_ts[0], unit="ns") max_inactive_win_len = int(max_inactive_time / wheel_s_r) @@ -1788,7 +1824,8 @@ def get_foraging_bouts( if bout_start_indxs[-1] >= len(wheel_ts): bout_start_indxs = bout_start_indxs[:-1] bout_end_indxs = bout_end_indxs[:-1] - assert len(bout_start_indxs) == len(bout_end_indxs) + if len(bout_start_indxs) != len(bout_end_indxs): + raise ValueError("Mismatch between the lengths of bout_start_indxs and bout_end_indxs.") bout_durations = (wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs]).astype( # in seconds "timedelta64[ns]" ).astype(float) / 1e9 diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index babae2fb..2d3d43fd 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -1,13 +1,14 @@ -import datetime +"""Module for visit-related tables in the analysis schema.""" + +from collections import deque +from datetime import UTC, datetime + import
datajoint as dj -import pandas as pd import numpy as np -from collections import deque +import pandas as pd from aeon.analysis import utils as analysis_utils - -from aeon.dj_pipeline import get_schema_name, fetch_stream -from aeon.dj_pipeline import acquisition, lab, qc, tracking +from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name schema = dj.schema(get_schema_name("analysis")) @@ -67,11 +68,13 @@ class Visit(dj.Part): @property def key_source(self): + """Key source for OverlapVisit.""" return dj.U("experiment_name", "place", "overlap_start") & (Visit & VisitEnd).proj( overlap_start="visit_start" ) def make(self, key): + """Populate OverlapVisit table with overlapping visits.""" visit_starts, visit_ends = (Visit * VisitEnd & key & {"visit_start": key["overlap_start"]}).fetch( "visit_start", "visit_end" ) @@ -113,12 +116,17 @@ def make(self, key): def ingest_environment_visits(experiment_names: list | None = None): - """Function to populate into `Visit` and `VisitEnd` for specified experiments (default: 'exp0.2-r0'). This ingestion routine handles only those "complete" visits, not ingesting any "on-going" visits using "analyze" method: `aeon.analyze.utils.visits()`. + """Populates ``Visit`` and ``VisitEnd`` for the specified experiment names. + + This ingestion routine includes only "complete" visits and + does not ingest any "on-going" visits. + Visits are retrieved using :func:`aeon.analysis.utils.visits`. Args: - experiment_names (list, optional): list of names of the experiment to populate into the Visit table. Defaults to None. + experiment_names (list, optional): list of names of the experiment + to populate into the ``Visit`` table. + If unspecified, defaults to ``None`` and ``['exp0.2-r0']`` is used. """ - if experiment_names is None: experiment_names = ["exp0.2-r0"] place_key = {"place": "environment"} @@ -131,7 +139,7 @@ def ingest_environment_visits(experiment_names: list | None = None): .fetch("last_visit") ) start = min(subjects_last_visits) if len(subjects_last_visits) else "1900-01-01" - end = datetime.datetime.now() if start else "2200-01-01" + end = datetime.now(UTC) if start else "2200-01-01" enter_exit_query = ( acquisition.SubjectEnterExit.Time * acquisition.EventType @@ -146,11 +154,9 @@ def ingest_environment_visits(experiment_names: list | None = None): enter_exit_df = pd.DataFrame( zip( *enter_exit_query.fetch( - "subject", - "enter_exit_time", - "event_type", - order_by="enter_exit_time", - ) + "subject", "enter_exit_time", "event_type", order_by="enter_exit_time" + ), + strict=False, ) ) enter_exit_df.columns = ["id", "time", "event"] @@ -187,6 +193,7 @@ def ingest_environment_visits(experiment_names: list | None = None): def get_maintenance_periods(experiment_name, start, end): + """Get maintenance periods for the specified experiment and time range.""" # get states from acquisition.Environment.EnvironmentState chunk_restriction = acquisition.create_chunk_restriction(experiment_name, start, end) state_query = ( @@ -219,12 +226,13 @@ def get_maintenance_periods(experiment_name, start, end): return deque( [ (pd.Timestamp(start), pd.Timestamp(end)) - for start, end in zip(maintenance_starts, maintenance_ends) + for start, end in zip(maintenance_starts, maintenance_ends, strict=False) ] ) # queue object. 
pop out from left after use def filter_out_maintenance_periods(data_df, maintenance_period, end_time, dropna=False): + """Filter out maintenance periods from the data_df.""" maint_period = maintenance_period.copy() while maint_period: (maintenance_start, maintenance_end) = maint_period[0] diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 4569193e..fe6db2c3 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -1,3 +1,5 @@ +"""Module for visit analysis.""" + import datetime from datetime import time @@ -5,19 +7,22 @@ import numpy as np import pandas as pd -from aeon.dj_pipeline import get_schema_name from aeon.dj_pipeline import acquisition, lab, tracking from aeon.dj_pipeline.analysis.visit import ( Visit, VisitEnd, - get_maintenance_periods, filter_out_maintenance_periods, + get_maintenance_periods, ) logger = dj.logger + # schema = dj.schema(get_schema_name("analysis")) schema = dj.schema() +# Constants values +MIN_AREA = 0 +MAX_AREA = 1000 # ---------- Position Filtering Method ------------------ @@ -66,8 +71,7 @@ class VisitSubjectPosition(dj.Computed): """ class TimeSlice(dj.Part): - definition = """ - # A short time-slice (e.g. 10 minutes) of the recording of a given animal for a visit + definition = """ # A short time-slice (e.g. 10min) of the recording of a given animal for a visit -> master time_slice_start: datetime(6) # datetime of the start of this time slice --- @@ -83,7 +87,8 @@ class TimeSlice(dj.Part): @property def key_source(self): - """Chunk for all visits: + """Chunk for all visits as the following conditions. + + visit_start during this Chunk - i.e. first chunk of the visit + visit_end during this Chunk - i.e. last chunk of the visit + chunk starts after visit_start and ends before visit_end (or NOW() - i.e. ongoing visits). @@ -96,20 +101,20 @@ def key_source(self): "visit_end BETWEEN chunk_start AND chunk_end", "chunk_start >= visit_start AND chunk_end <= visit_end", ] - & "chunk_start < chunk_end" # in some chunks, end timestamp comes before start (timestamp error) + & "chunk_start < chunk_end" + # in some chunks, end timestamp comes before start (timestamp error) ) def make(self, key): + """Populate VisitSubjectPosition for each visit.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") # -- Determine the time to start time_slicing in this chunk - if chunk_start < key["visit_start"] < chunk_end: - # For chunk containing the visit_start - i.e. first chunk of this visit - start_time = key["visit_start"] - else: - # For chunks after the first chunk of this visit - start_time = chunk_start - + start_time = ( + key["visit_start"] + if chunk_start < key["visit_start"] < chunk_end + else chunk_start # Use chunk_start if the visit has not yet started. 
+ ) # -- Determine the time to end time_slicing in this chunk if VisitEnd & key: # finished visit visit_end = (VisitEnd & key).fetch1("visit_end") @@ -139,7 +144,8 @@ def make(self, key): object_id = (tracking.CameraTracking.Object & key).fetch1("object_id") else: logger.info( - '"More than one unique object ID found - using animal/object mapping from AnimalObjectMapping"' + "More than one unique object ID found - " + "using animal/object mapping from AnimalObjectMapping" ) if not (AnimalObjectMapping & key): raise ValueError( @@ -190,22 +196,33 @@ def make(self, key): @classmethod def get_position(cls, visit_key=None, subject=None, start=None, end=None): - """Given a key to a single Visit, return a Pandas DataFrame for the position data of the subject for the specified Visit time period.""" + """Retrieves a Pandas DataFrame of a subject's position data for a specified ``Visit``. + + A ``Visit`` is specified by either a ``visit_key`` or + a combination of ``subject``, ``start``, and ``end``. + If all four arguments are provided, the ``visit_key`` is ignored. + + Args: + visit_key (dict, optional): key to a single ``Visit``. + Only required if ``subject``, ``start``, and ``end`` are not provided. + subject (str, optional): subject name. + Only required if ``visit_key`` is not provided. + start (datetime): start time of the period of interest. + Only required if ``visit_key`` is not provided. + end (datetime, optional): end time of the period of interest. + Only required if ``visit_key`` is not provided. + """ if visit_key is not None: - assert len(Visit & visit_key) == 1 + if len(Visit & visit_key) != 1: + raise ValueError("The `visit_key` must correspond to exactly one Visit.") start, end = ( Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") & visit_key ).fetch1("visit_start", "visit_end") subject = visit_key["subject"] - elif all((subject, start, end)): - start = start - end = end - subject = subject - else: + elif not all((subject, start, end)): raise ValueError( - 'Either "visit_key" or all three "subject", "start" and "end" has to be specified' - ) - + 'Either "visit_key" or all three "subject", "start", and "end" must be specified.' 
+ ) return tracking._get_position( cls.TimeSlice, object_attr="subject", @@ -242,7 +259,7 @@ class Nest(dj.Part): -> lab.ArenaNest --- time_fraction_in_nest: float # fraction of time the animal spent in this nest in this visit - in_nest: longblob # array of indices for when the animal is in this nest (index into the position data) + in_nest: longblob # indices array marking when the animal is in this nest """ class FoodPatch(dj.Part): @@ -258,6 +275,7 @@ class FoodPatch(dj.Part): key_source = Visit & (VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end") def make(self, key): + """Populate VisitTimeDistribution for each visit.""" visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) @@ -285,7 +303,7 @@ def make(self, key): position = filter_out_maintenance_periods(position, maintenance_period, day_end) # filter for objects of the correct size - valid_position = (position.area > 0) & (position.area < 1000) + valid_position = (position.area > MIN_AREA) & (position.area < MAX_AREA) position[~valid_position] = np.nan position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) # in corridor @@ -386,7 +404,7 @@ class VisitSummary(dj.Computed): --- day_duration: float # total duration (in hours) total_distance_travelled: float # (m) total distance the animal travelled during this visit - total_pellet_count: int # total pellet delivered (triggered) for all patches during this visit + total_pellet_count: int # total pellet triggered for all patches during this visit total_wheel_distance_travelled: float # total wheel travelled distance for all patches """ @@ -395,7 +413,7 @@ class FoodPatch(dj.Part): -> master -> acquisition.ExperimentFoodPatch --- - pellet_count: int # number of pellets being delivered (triggered) by this patch during this visit + pellet_count: int # number of pellets delivered by this patch during this visit wheel_distance_travelled: float # wheel travelled distance during this visit for this patch """ @@ -403,6 +421,7 @@ class FoodPatch(dj.Part): key_source = Visit & (VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end") def make(self, key): + """Populate VisitSummary for each visit.""" visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) @@ -430,7 +449,7 @@ def make(self, key): # filter out maintenance period based on logs position = filter_out_maintenance_periods(position, maintenance_period, day_end) # filter for objects of the correct size - valid_position = (position.area > 0) & (position.area < 1000) + valid_position = (position.area > MIN_AREA) & (position.area < MAX_AREA) position[~valid_position] = np.nan position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) position_diff = np.sqrt(np.square(np.diff(position.x)) + np.square(np.diff(position.y))) @@ -513,7 +532,9 @@ def make(self, key): @schema class VisitForagingBout(dj.Computed): - definition = """ # A time period spanning the time when the animal enters a food patch and moves the wheel to when it leaves the food patch + """Time period when a subject enters a food patch, moves the wheel, and then leaves the patch.""" + + definition = """ # Time from subject's entry to exit of a food patch to interact with the wheel. 
-> Visit -> acquisition.ExperimentFoodPatch bout_start: datetime(6) # start time of bout @@ -530,6 +551,7 @@ class VisitForagingBout(dj.Computed): ) * acquisition.ExperimentFoodPatch def make(self, key): + """Populate VisitForagingBout for each visit.""" visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") # get in_patch timestamps diff --git a/aeon/dj_pipeline/create_experiments/create_experiment_01.py b/aeon/dj_pipeline/create_experiments/create_experiment_01.py index 7b87bf11..18edb4c3 100644 --- a/aeon/dj_pipeline/create_experiments/create_experiment_01.py +++ b/aeon/dj_pipeline/create_experiments/create_experiment_01.py @@ -1,3 +1,5 @@ +"""Functions to populate the database with the metadata for experiment 0.1.""" + import pathlib import yaml @@ -9,6 +11,7 @@ def ingest_exp01_metadata(metadata_yml_filepath, experiment_name): + """Ingest metadata from a yml file into the database for experiment 0.1.""" with open(metadata_yml_filepath) as f: arena_setup = yaml.full_load(f) @@ -163,6 +166,7 @@ def ingest_exp01_metadata(metadata_yml_filepath, experiment_name): def create_new_experiment(): + """Create a new experiment and add subjects to it.""" # ---------------- Subject ----------------- subject.Subject.insert( [ @@ -241,6 +245,7 @@ def create_new_experiment(): def add_arena_setup(): + """Add arena setup.""" # Arena Setup - Experiment Devices this_file = pathlib.Path(__file__).expanduser().absolute().resolve() metadata_yml_filepath = this_file.parent / "setup_yml" / "Experiment0.1.yml" @@ -264,6 +269,7 @@ def add_arena_setup(): def main(): + """Main function to create a new experiment and set up the arena.""" create_new_experiment() add_arena_setup() diff --git a/aeon/dj_pipeline/create_experiments/create_experiment_02.py b/aeon/dj_pipeline/create_experiments/create_experiment_02.py index 82a8f03f..081e6911 100644 --- a/aeon/dj_pipeline/create_experiments/create_experiment_02.py +++ b/aeon/dj_pipeline/create_experiments/create_experiment_02.py @@ -1,3 +1,5 @@ +"""Functions to create new experiments for experiment0.2.""" + from aeon.dj_pipeline import acquisition, lab, subject # ============ Manual and automatic steps to for experiment 0.2 populate ============ @@ -6,6 +8,7 @@ def create_new_experiment(): + """Create new experiment for experiment0.2.""" # ---------------- Subject ----------------- subject_list = [ {"subject": "BAA-1100699", "sex": "U", "subject_birth_date": "2021-01-01"}, @@ -76,6 +79,7 @@ def create_new_experiment(): def main(): + """Main function to create a new experiment.""" create_new_experiment() diff --git a/aeon/dj_pipeline/create_experiments/create_octagon_1.py b/aeon/dj_pipeline/create_experiments/create_octagon_1.py index 4b077e65..3fc01ef6 100644 --- a/aeon/dj_pipeline/create_experiments/create_octagon_1.py +++ b/aeon/dj_pipeline/create_experiments/create_octagon_1.py @@ -1,3 +1,5 @@ +"""Functions to create new experiments for octagon1.0.""" + from aeon.dj_pipeline import acquisition, subject # ============ Manual and automatic steps to for experiment 0.2 populate ============ @@ -6,6 +8,7 @@ def create_new_experiment(): + """Create new experiment for octagon1.0.""" # ---------------- Subject ----------------- # This will get replaced by content from colony.csv subject_list = [ @@ -57,6 +60,7 @@ def create_new_experiment(): def main(): + """Main function to create a new experiment.""" create_new_experiment() diff --git a/aeon/dj_pipeline/create_experiments/create_presocial.py b/aeon/dj_pipeline/create_experiments/create_presocial.py 
index 05dc0dc8..3e9b8f76 100644 --- a/aeon/dj_pipeline/create_experiments/create_presocial.py +++ b/aeon/dj_pipeline/create_experiments/create_presocial.py @@ -1,3 +1,5 @@ +"""Functions to create new experiments for presocial0.1.""" + from aeon.dj_pipeline import acquisition, lab, subject experiment_type = "presocial0.1" @@ -7,6 +9,7 @@ def create_new_experiment(): + """Create new experiments for presocial0.1.""" lab.Location.insert1({"lab": "SWC", "location": location}, skip_duplicates=True) acquisition.ExperimentType.insert1({"experiment_type": experiment_type}, skip_duplicates=True) @@ -44,13 +47,14 @@ def create_new_experiment(): "directory_type": "raw", "directory_path": f"aeon/data/raw/{computer}/{experiment_type}", } - for experiment_name, computer in zip(experiment_names, computers) + for experiment_name, computer in zip(experiment_names, computers, strict=False) ], skip_duplicates=True, ) def main(): + """Main function to create a new experiment.""" create_new_experiment() diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py index 9a67eadc..a3e60a32 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py @@ -1,9 +1,10 @@ -from pathlib import Path +"""Functions to create new social experiments.""" + from datetime import datetime + from aeon.dj_pipeline import acquisition from aeon.dj_pipeline.utils.paths import get_repository_path - # ---- Programmatic creation of a new social experiment ---- # Infer experiment metadata from the experiment name # User-specified "experiment_name" (everything else should be automatically inferred) @@ -16,6 +17,7 @@ def create_new_social_experiment(experiment_name): + """Create new social experiment.""" exp_name, machine_name = experiment_name.split("-") raw_dir = ceph_data_dir / "raw" / machine_name.upper() / exp_name if not raw_dir.exists(): @@ -55,4 +57,3 @@ def create_new_social_experiment(experiment_name): {"experiment_name": experiment_name, "devices_schema_name": exp_name.replace(".", "")}, skip_duplicates=True, ) - diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py index b50c0a17..7d10f734 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py @@ -1,3 +1,5 @@ +"""Functions to create new experiments for social0-r1.""" + import pathlib from aeon.dj_pipeline import acquisition, lab, subject @@ -8,6 +10,7 @@ def create_new_experiment(): + """Create new experiments for social0-r1.""" # ---------------- Subject ----------------- subject_list = [ {"subject": "BAA-1100704", "sex": "U", "subject_birth_date": "2021-01-01"}, @@ -79,6 +82,7 @@ def create_new_experiment(): def add_arena_setup(): + """Add arena setup.""" # Arena Setup - Experiment Devices this_file = pathlib.Path(__file__).expanduser().absolute().resolve() metadata_yml_filepath = this_file.parent / "setup_yml" / "SocialExperiment0.yml" @@ -102,6 +106,7 @@ def add_arena_setup(): def main(): + """Main function to create a new experiment and set up the arena.""" create_new_experiment() add_arena_setup() @@ -148,19 +153,21 @@ def fixID(subjid, valid_ids=None, valid_id_file=None): if ";" in subjid: subjidA, subjidB = subjid.split(";") return ( - f"{fixID(subjidA.strip(), valid_ids=valid_ids)};{fixID(subjidB.strip(), 
valid_ids=valid_ids)}" + f"{fixID(subjidA.strip(), valid_ids=valid_ids)};" + f"{fixID(subjidB.strip(), valid_ids=valid_ids)}" ) if "vs" in subjid: subjidA, tmp, subjidB = subjid.split(" ")[1:] return ( - f"{fixID(subjidA.strip(), valid_ids=valid_ids)};{fixID(subjidB.strip(), valid_ids=valid_ids)}" + f"{fixID(subjidA.strip(), valid_ids=valid_ids)};" + f"{fixID(subjidB.strip(), valid_ids=valid_ids)}" ) try: ld = [jl.levenshtein_distance(subjid, x[-len(subjid) :]) for x in valid_ids] return valid_ids[np.argmin(ld)] - except: + except ValueError: return subjid diff --git a/aeon/dj_pipeline/docs/notebooks/social_experiments_block_analysis.ipynb b/aeon/dj_pipeline/docs/notebooks/social_experiments_block_analysis.ipynb index 02b0b919..f84b786b 100644 --- a/aeon/dj_pipeline/docs/notebooks/social_experiments_block_analysis.ipynb +++ b/aeon/dj_pipeline/docs/notebooks/social_experiments_block_analysis.ipynb @@ -83,7 +83,10 @@ "outputs": [], "source": [ "# Pick a block of interest\n", - "block_key = {\"experiment_name\": \"social0.1-aeon3\", \"block_start\": \"2023-11-30 18:49:05.001984\"}" + "block_key = {\n", + " \"experiment_name\": \"social0.1-aeon3\",\n", + " \"block_start\": \"2023-11-30 18:49:05.001984\",\n", + "}" ], "metadata": { "collapsed": false, diff --git a/aeon/dj_pipeline/lab.py b/aeon/dj_pipeline/lab.py index 2f10665f..141c40cd 100644 --- a/aeon/dj_pipeline/lab.py +++ b/aeon/dj_pipeline/lab.py @@ -1,8 +1,11 @@ +"""DataJoint schema for the lab pipeline.""" + import datajoint as dj from . import get_schema_name schema = dj.schema(get_schema_name("lab")) +logger = dj.logger # ------------------- GENERAL LAB INFORMATION -------------------- @@ -82,12 +85,13 @@ class ArenaShape(dj.Lookup): definition = """ arena_shape: varchar(32) """ - contents = zip(["square", "circular", "rectangular", "linear", "octagon"]) + contents = zip(["square", "circular", "rectangular", "linear", "octagon"], strict=False) @schema class Arena(dj.Lookup): - """Coordinate frame convention: + """Coordinate frame convention as the following items. + + x-dimension: x=0 is the left most point of the bounding box of the arena + y-dimension: y=0 is the top most point of the bounding box of the arena + z-dimension: z=0 is the lowest point of the arena (e.g. the ground) diff --git a/aeon/dj_pipeline/populate/__init__.py b/aeon/dj_pipeline/populate/__init__.py index e69de29b..7178aec4 100644 --- a/aeon/dj_pipeline/populate/__init__.py +++ b/aeon/dj_pipeline/populate/__init__.py @@ -0,0 +1 @@ +"""Utilities for the workflow orchestration.""" diff --git a/aeon/dj_pipeline/populate/process.py b/aeon/dj_pipeline/populate/process.py index 5c2e4d15..d3699b25 100644 --- a/aeon/dj_pipeline/populate/process.py +++ b/aeon/dj_pipeline/populate/process.py @@ -1,8 +1,12 @@ """Start an Aeon ingestion process. -This script defines auto-processing routines to operate the DataJoint pipeline for the Aeon project. Three separate "process" functions are defined to call `populate()` for different groups of tables, depending on their priority in the ingestion routines (high, mid, low). +This script defines auto-processing routines to operate the DataJoint pipeline +for the Aeon project. Three separate "process" functions are defined to call +`populate()` for different groups of tables, depending on their priority in +the ingestion routines (high, mid, low). 
-Each process function is run in a while-loop with the total run-duration configurable via command line argument '--duration' (if not set, runs perpetually) +Each process function is run in a while-loop with the total run-duration configurable +via command line argument '--duration' (if not set, runs perpetually) - the loop will not begin a new cycle after this period of time (in seconds) - the loop will run perpetually if duration<0 or if duration==None - the script will not be killed _at_ this limit, it will keep executing, @@ -22,7 +26,8 @@ Usage from python: - `from aeon.dj_pipeline.populate.process import run; run(worker_name='high_priority', duration=20, sleep=5)` + `from aeon.dj_pipeline.populate.process import run + run(worker_name='high_priority', duration=20, sleep=5)` """ @@ -33,10 +38,10 @@ from aeon.dj_pipeline.populate.worker import ( acquisition_worker, - logger, analysis_worker, - streams_worker, + logger, pyrat_worker, + streams_worker, ) # ---- some wrappers to support execution as script or CLI diff --git a/aeon/dj_pipeline/populate/worker.py b/aeon/dj_pipeline/populate/worker.py index 81e9cb18..e758f9d5 100644 --- a/aeon/dj_pipeline/populate/worker.py +++ b/aeon/dj_pipeline/populate/worker.py @@ -1,9 +1,10 @@ +"""This module defines the workers for the AEON pipeline.""" + import datajoint as dj -from datajoint_utilities.dj_worker import DataJointWorker, ErrorLog, WorkerLog, RegisteredWorker +from datajoint_utilities.dj_worker import DataJointWorker, ErrorLog, WorkerLog from datajoint_utilities.dj_worker.worker_schema import is_djtable -from aeon.dj_pipeline import db_prefix -from aeon.dj_pipeline import subject, acquisition, tracking, qc +from aeon.dj_pipeline import acquisition, db_prefix, qc, subject, tracking from aeon.dj_pipeline.analysis import block_analysis from aeon.dj_pipeline.utils import streams_maker @@ -106,6 +107,7 @@ def ingest_epochs_chunks(): def get_workflow_operation_overview(): + """Get the workflow operation overview for the worker schema.""" from datajoint_utilities.dj_worker.utils import get_workflow_operation_overview return get_workflow_operation_overview(worker_schema_name=worker_schema_name, db_prefixes=[db_prefix]) diff --git a/aeon/dj_pipeline/qc.py b/aeon/dj_pipeline/qc.py index 7044da0e..b54951dd 100644 --- a/aeon/dj_pipeline/qc.py +++ b/aeon/dj_pipeline/qc.py @@ -1,13 +1,14 @@ +"""DataJoint schema for the quality control pipeline.""" + import datajoint as dj import numpy as np import pandas as pd +from aeon.dj_pipeline import acquisition, get_schema_name, streams from aeon.io import api as io_api -from aeon.dj_pipeline import get_schema_name -from aeon.dj_pipeline import acquisition, streams - schema = dj.schema(get_schema_name("qc")) +logger = dj.logger # -------------- Quality Control --------------------- @@ -28,7 +29,7 @@ class QCRoutine(dj.Lookup): --- qc_routine_order: int # the order in which this qc routine is executed qc_routine_description: varchar(255) # description of this QC routine - qc_module: varchar(64) # the module where the qc_function can be imported from - e.g. aeon.analysis.quality_control + qc_module: varchar(64) # module path, e.g., aeon.analysis.quality_control qc_function: varchar(64) # the function used to evaluate this QC - e.g. 
check_drop_frame """ @@ -38,7 +39,7 @@ class QCRoutine(dj.Lookup): @schema class CameraQC(dj.Imported): - definition = """ # Quality controls performed on a particular camera for a particular acquisition chunk + definition = """ # Quality controls performed on a particular camera for one acquisition chunk -> acquisition.Chunk -> streams.SpinnakerVideoSource --- @@ -55,6 +56,7 @@ class CameraQC(dj.Imported): @property def key_source(self): + """Return the keys for the CameraQC table.""" return ( acquisition.Chunk * ( @@ -66,6 +68,7 @@ def key_source(self): ) # CameraTop def make(self, key): + """Perform quality control checks on the CameraTop stream.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name") diff --git a/aeon/dj_pipeline/report.py b/aeon/dj_pipeline/report.py index ec88ae7c..6b1bce02 100644 --- a/aeon/dj_pipeline/report.py +++ b/aeon/dj_pipeline/report.py @@ -1,3 +1,5 @@ +"""DataJoint schema dedicated for tables containing figures.""" + import datetime import json import os @@ -10,20 +12,16 @@ from aeon.analysis import plotting as analysis_plotting from aeon.dj_pipeline.analysis.visit import Visit, VisitEnd -from aeon.dj_pipeline.analysis.visit_analysis import * -from . import acquisition, analysis, get_schema_name +from . import acquisition, analysis + +logger = dj.logger # schema = dj.schema(get_schema_name("report")) schema = dj.schema() os.environ["DJ_SUPPORT_FILEPATH_MANAGEMENT"] = "TRUE" -""" - DataJoint schema dedicated for tables containing figures -""" - - @schema class InArenaSummaryPlot(dj.Computed): definition = """ @@ -44,6 +42,7 @@ class InArenaSummaryPlot(dj.Computed): } def make(self, key): + """Make method for InArenaSummaryPlot table.""" in_arena_start, in_arena_end = (analysis.InArena * analysis.InArenaEnd & key).fetch1( "in_arena_start", "in_arena_end" ) @@ -192,7 +191,7 @@ def make(self, key): alpha=0.6, label="nest", ) - for patch_idx, (patch_name, in_patch) in enumerate(zip(patch_names, in_patches)): + for patch_idx, (patch_name, in_patch) in enumerate(zip(patch_names, in_patches, strict=False)): ethogram_ax.plot( position_minutes_elapsed[in_patch], np.full_like(position_minutes_elapsed[in_patch], (patch_idx + 3)), @@ -276,12 +275,13 @@ class SubjectRewardRateDifference(dj.Computed): -> acquisition.Experiment.Subject --- in_arena_count: int - reward_rate_difference_plotly: longblob # dictionary storing the plotly object (from fig.to_plotly_json()) + reward_rate_difference_plotly: longblob # dict storing the plotly object (from fig.to_plotly_json()) """ key_source = acquisition.Experiment.Subject & analysis.InArenaRewardRate def make(self, key): + """Insert reward rate differences plot in SubjectRewardRateDifference table.""" from aeon.dj_pipeline.utils.plotting import plot_reward_rate_differences fig = plot_reward_rate_differences(key) @@ -298,7 +298,14 @@ def make(self, key): @classmethod def delete_outdated_entries(cls): - """Each entry in this table correspond to one subject. However, the plot is capturing data for all sessions.Hence a dynamic update routine is needed to recompute the plot as new sessions become available.""" + """Dynamically update the plot for all sessions. + + Each entry in this table correspond to one subject. + However, the plot is capturing data for all sessions. + Hence a dynamic update routine is needed to recompute + the plot as new sessions become available. 
+ + """ outdated_entries = ( cls * ( @@ -319,12 +326,13 @@ class SubjectWheelTravelledDistance(dj.Computed): -> acquisition.Experiment.Subject --- in_arena_count: int - wheel_travelled_distance_plotly: longblob # dictionary storing the plotly object (from fig.to_plotly_json()) + wheel_travelled_distance_plotly: longblob # dict storing the plotly object (from fig.to_plotly_json()) """ key_source = acquisition.Experiment.Subject & analysis.InArenaSummary def make(self, key): + """Insert wheel travelled distance plot in SubjectWheelTravelledDistance table.""" from aeon.dj_pipeline.utils.plotting import plot_wheel_travelled_distance in_arena_keys = (analysis.InArenaSummary & key).fetch("KEY") @@ -343,7 +351,13 @@ def make(self, key): @classmethod def delete_outdated_entries(cls): - """Each entry in this table correspond to one subject. However the plot is capturing data for all sessions. Hence a dynamic update routine is needed to recompute the plot as new sessions become available.""" + """Dynamically update the plot for all sessions. + + Each entry in this table correspond to one subject. + However the plot is capturing data for all sessions. + Hence a dynamic update routine is needed to recompute + the plot as new sessions become available. + """ outdated_entries = ( cls * ( @@ -368,6 +382,7 @@ class ExperimentTimeDistribution(dj.Computed): """ def make(self, key): + """Insert average time distribution plot into ExperimentTimeDistribution table.""" from aeon.dj_pipeline.utils.plotting import plot_average_time_distribution in_arena_keys = (analysis.InArenaTimeDistribution & key).fetch("KEY") @@ -386,7 +401,13 @@ def make(self, key): @classmethod def delete_outdated_entries(cls): - """Each entry in this table correspond to one subject. However the plot is capturing data for all sessions. Hence a dynamic update routine is needed to recompute the plot as new sessions become available.""" + """Dynamically update the plot for all sessions. + + Each entry in this table correspond to one subject. + However the plot is capturing data for all sessions. + Hence a dynamic update routine is needed to recompute + the plot as new sessions become available. 
+ """ outdated_entries = ( cls * ( @@ -402,6 +423,7 @@ def delete_outdated_entries(cls): def delete_outdated_plot_entries(): + """Delete outdated entries in the tables that store plots.""" for tbl in ( SubjectRewardRateDifference, SubjectWheelTravelledDistance, @@ -415,7 +437,7 @@ class VisitDailySummaryPlot(dj.Computed): definition = """ -> Visit --- - pellet_count_plotly: longblob # Dictionary storing the plotly object (from fig.to_plotly_json()) + pellet_count_plotly: longblob # Dict storing the plotly object (from fig.to_plotly_json()) wheel_distance_travelled_plotly: longblob total_distance_travelled_plotly: longblob weight_patch_plotly: longblob @@ -431,6 +453,7 @@ class VisitDailySummaryPlot(dj.Computed): ) def make(self, key): + """Make method for VisitDailySummaryPlot table.""" from aeon.dj_pipeline.utils.plotting import ( plot_foraging_bouts_count, plot_foraging_bouts_distribution, @@ -530,6 +553,7 @@ def make(self, key): def _make_path(in_arena_key): + """Make path for saving figures.""" store_stage = pathlib.Path(dj.config["stores"]["djstore"]["stage"]) experiment_name, subject, in_arena_start = (analysis.InArena & in_arena_key).fetch1( "experiment_name", "subject", "in_arena_start" @@ -540,8 +564,9 @@ def _make_path(in_arena_key): def _save_figs(figs, fig_names, save_dir, prefix, extension=".png"): + """Save figures and return a dictionary with figure names and file paths.""" fig_dict = {} - for fig, figname in zip(figs, fig_names): + for fig, figname in zip(figs, fig_names, strict=False): fig_fp = save_dir / (prefix + "_" + figname + extension) fig.tight_layout() fig.savefig(fig_fp, dpi=300) diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py index 593bc7fe..6e906f92 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py @@ -1,6 +1,4 @@ -"""March 2022 -Cloning and archiving schemas and data for experiment 0.1. -""" +"""March 2022. Cloning and archiving schemas and data for experiment 0.1.""" import os @@ -26,6 +24,7 @@ def clone_pipeline(): + """Clone the pipeline for experiment 0.1.""" diagram = None for orig_schema_name in schema_name_mapper: virtual_module = dj.create_virtual_module(orig_schema_name, orig_schema_name) @@ -39,6 +38,7 @@ def clone_pipeline(): def data_copy(restriction, table_block_list, batch_size=None): + """Migrate schema.""" for orig_schema_name, cloned_schema_name in schema_name_mapper.items(): orig_schema = dj.create_virtual_module(orig_schema_name, orig_schema_name) cloned_schema = dj.create_virtual_module(cloned_schema_name, cloned_schema_name) diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py index 740932b2..82d7815c 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py @@ -1,4 +1,4 @@ -"""Jan 2024: Cloning and archiving schemas and data for experiment 0.2. +"""Jan 2024. Cloning and archiving schemas and data for experiment 0.2. 
The pipeline code associated with this archived data pipeline is here https://github.com/SainsburyWellcomeCentre/aeon_mecha/releases/tag/dj_exp02_stable @@ -30,6 +30,7 @@ def clone_pipeline(): + """Clone the pipeline for experiment 0.2.""" diagram = None for orig_schema_name in schema_name_mapper: virtual_module = dj.create_virtual_module(orig_schema_name, orig_schema_name) @@ -43,6 +44,7 @@ def clone_pipeline(): def data_copy(restriction, table_block_list, batch_size=None): + """Migrate schema.""" for orig_schema_name, cloned_schema_name in schema_name_mapper.items(): orig_schema = dj.create_virtual_module(orig_schema_name, orig_schema_name) cloned_schema = dj.create_virtual_module(cloned_schema_name, cloned_schema_name) diff --git a/aeon/dj_pipeline/scripts/reingest_fullpose_sleap_data.py b/aeon/dj_pipeline/scripts/reingest_fullpose_sleap_data.py index b3586f82..9411064b 100644 --- a/aeon/dj_pipeline/scripts/reingest_fullpose_sleap_data.py +++ b/aeon/dj_pipeline/scripts/reingest_fullpose_sleap_data.py @@ -1,4 +1,7 @@ +"""Functions to find and reingest chunks with newly available full-pose SLEAP data.""" + from datetime import datetime + from aeon.dj_pipeline import acquisition, tracking aeon_schemas = acquisition.aeon_schemas @@ -8,11 +11,10 @@ def find_chunks_to_reingest(exp_key, delete_not_fullpose=False): - """ - Find chunks with newly available full pose data to reingest. + """Find chunks with newly available full pose data to reingest. + If available, fullpose data can be found in the `processed` folder """ - device_name = "CameraTop" devices_schema = getattr( @@ -21,13 +23,14 @@ def find_chunks_to_reingest(exp_key, delete_not_fullpose=False): "devices_schema_name" ), ) - stream_reader = getattr(getattr(devices_schema, device_name), "Pose") + stream_reader = getattr(devices_schema, device_name).Pose # special ingestion case for social0.2 full-pose data (using Pose reader from social03) if exp_key["experiment_name"].startswith("social0.2"): from aeon.io import reader as io_reader - stream_reader = getattr(getattr(devices_schema, device_name), "Pose03") - assert isinstance(stream_reader, io_reader.Pose), "Pose03 is not a Pose reader" + stream_reader = getattr(devices_schema, device_name).Pose03 + if not isinstance(stream_reader, io_reader.Pose): + raise TypeError("Pose03 is not a Pose reader") # find processed path for exp_key processed_dir = acquisition.Experiment.get_data_directory(exp_key, "processed") diff --git a/aeon/dj_pipeline/scripts/sync_ingested_and_raw_epochs.py b/aeon/dj_pipeline/scripts/sync_ingested_and_raw_epochs.py index 186355ce..3b247b7c 100644 --- a/aeon/dj_pipeline/scripts/sync_ingested_and_raw_epochs.py +++ b/aeon/dj_pipeline/scripts/sync_ingested_and_raw_epochs.py @@ -1,6 +1,9 @@ -import datajoint as dj +"""Functions to find and delete orphaned epochs that have been ingested but are no longer valid.""" + from datetime import datetime +import datajoint as dj + from aeon.dj_pipeline import acquisition, streams from aeon.dj_pipeline.analysis import block_analysis @@ -11,8 +14,8 @@ def find_orphaned_ingested_epochs(exp_key, delete_invalid_epochs=False): - """ - Find ingested epochs that are no longer valid + """Find ingested epochs that are no longer valid. + This is due to the raw epoch/chunk files/directories being deleted for whatever reason (e.g. bad data, testing, etc.)
""" diff --git a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py index 4182c2e5..31ee8109 100644 --- a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py +++ b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py @@ -1,6 +1,4 @@ -"""July 2022 -Upgrade all timestamps longblob fields with datajoint 0.13.7. -""" +"""July 2022. Upgrade all timestamps longblob fields with datajoint 0.13.7.""" from datetime import datetime @@ -8,7 +6,11 @@ import numpy as np from tqdm import tqdm -assert dj.__version__ >= "0.13.7" +logger = dj.logger + + +if dj.__version__ < "0.13.7": + raise ImportError(f"DataJoint version must be at least 0.13.7, but found {dj.__version__}.") schema = dj.schema("u_thinh_aeonfix") @@ -32,6 +34,7 @@ class TimestampFix(dj.Manual): def main(): + """Update all timestamps longblob fields in the specified schemas.""" for schema_name in schema_names: vm = dj.create_virtual_module(schema_name, schema_name) table_names = [ @@ -54,7 +57,10 @@ def main(): if not len(ts) or isinstance(ts[0], np.datetime64): TimestampFix.insert1(fix_key) continue - assert isinstance(ts[0], datetime) + if not isinstance(ts[0], datetime): + raise TypeError( + f"Expected ts[0] to be of type 'datetime', but got {type(ts[0])}." + ) with table.connection.transaction: table.update1( { @@ -66,6 +72,7 @@ def main(): def get_table(schema_object, table_object_name): + """Get the table object from the schema object.""" if "." in table_object_name: master_name, part_name = table_object_name.split(".") return getattr(getattr(schema_object, master_name), part_name) diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index 18aee718..3ff95770 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -1,7 +1,9 @@ +"""DataJoint schema for animal subjects.""" + import json import os import time -from datetime import datetime, timedelta +from datetime import UTC, datetime, timedelta import datajoint as dj import requests @@ -56,6 +58,7 @@ class SubjectDetail(dj.Imported): """ def make(self, key): + """Automatically import and update entries in the Subject table.""" eartag_or_id = key["subject"] # cage id, sex, line/strain, genetic background, dob, lab id params = { @@ -175,6 +178,7 @@ class SubjectReferenceWeight(dj.Manual): @classmethod def get_reference_weight(cls, subject_name): + """Get the reference weight for the subject.""" subj_key = {"subject": subject_name} food_restrict_query = SubjectProcedure & subj_key & "procedure_name = 'R02 - food restriction'" @@ -183,7 +187,7 @@ def get_reference_weight(cls, subject_name): 0 ] else: - ref_date = datetime.now().date() + ref_date = datetime.now(UTC).date() weight_query = SubjectWeight & subj_key & f"weight_time < '{ref_date}'" ref_weight = ( @@ -193,7 +197,7 @@ def get_reference_weight(cls, subject_name): entry = { "subject": subject_name, "reference_weight": ref_weight, - "last_updated_time": datetime.utcnow(), + "last_updated_time": datetime.now(UTC), } cls.update1(entry) if cls & {"subject": subject_name} else cls.insert1(entry) @@ -235,7 +239,8 @@ class PyratIngestion(dj.Imported): schedule_interval = 12 # schedule interval in number of hours def _auto_schedule(self): - utc_now = datetime.utcnow() + """Automatically schedule the next task.""" + utc_now = datetime.now(UTC) next_task_schedule_time = utc_now + timedelta(hours=self.schedule_interval) if ( @@ -247,8 +252,8 @@ def _auto_schedule(self): PyratIngestionTask.insert1({"pyrat_task_scheduled_time": 
next_task_schedule_time}) def make(self, key): - execution_time = datetime.utcnow() """Automatically import or update entries in the Subject table.""" + execution_time = datetime.now(UTC) new_eartags = [] for responsible_id in lab.User.fetch("responsible_id"): # 1 - retrieve all animals from this user @@ -283,7 +288,7 @@ def make(self, key): new_entry_count += 1 logger.info(f"Inserting {new_entry_count} new subject(s) from Pyrat") - completion_time = datetime.utcnow() + completion_time = datetime.now(UTC) self.insert1( { **key, @@ -313,7 +318,8 @@ class PyratCommentWeightProcedure(dj.Imported): key_source = (PyratIngestion * SubjectDetail) & "available = 1" def make(self, key): - execution_time = datetime.utcnow() + """Automatically import or update entries in the PyratCommentWeightProcedure table.""" + execution_time = datetime.now(UTC) logger.info("Extracting weights/comments/procedures") eartag_or_id = key["subject"] @@ -366,8 +372,7 @@ def make(self, key): "lab_id": animal_resp["labid"], } ) - - completion_time = datetime.utcnow() + completion_time = datetime.now(UTC) self.insert1( { **key, @@ -385,7 +390,7 @@ class CreatePyratIngestionTask(dj.Computed): def make(self, key): """Create one new PyratIngestionTask for every newly added users.""" - PyratIngestionTask.insert1({"pyrat_task_scheduled_time": datetime.utcnow()}) + PyratIngestionTask.insert1({"pyrat_task_scheduled_time": datetime.now(UTC)}) time.sleep(1) self.insert1(key) @@ -455,8 +460,8 @@ def make(self, key): def get_pyrat_data(endpoint: str, params: dict = None, **kwargs): - """ - Get data from PyRat API. + """Get data from PyRat API. + See docs at: https://swc.pyrat.cloud/api/v3/docs (production) """ base_url = "https://swc.pyrat.cloud/api/v3/" @@ -465,7 +470,8 @@ def get_pyrat_data(endpoint: str, params: dict = None, **kwargs): if pyrat_system_token is None or pyrat_user_token is None: raise ValueError( - "The PYRAT tokens must be defined as an environment variable named 'PYRAT_SYSTEM_TOKEN' and 'PYRAT_USER_TOKEN'" + "The PYRAT tokens must be defined as environment variables " + "named 'PYRAT_SYSTEM_TOKEN' and 'PYRAT_USER_TOKEN'" ) session = requests.Session() @@ -474,7 +480,7 @@ def get_pyrat_data(endpoint: str, params: dict = None, **kwargs): if params is not None: params_str_list = [] for k, v in params.items(): - if isinstance(v, (list, tuple)): + if isinstance(v, list | tuple): for i in v: params_str_list.append(f"{k}={i}") else: @@ -485,7 +491,9 @@ def get_pyrat_data(endpoint: str, params: dict = None, **kwargs): response = session.get(base_url + endpoint + params_str, **kwargs) - if response.status_code != 200: + RESPONSE_STATUS_CODE_OK = 200 + + if response.status_code != RESPONSE_STATUS_CODE_OK: raise requests.exceptions.HTTPError( f"PyRat API errored out with response code: {response.status_code}" ) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 584e0384..559d08fb 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -1,16 +1,17 @@ -from pathlib import Path +"""DataJoint schema for tracking data.""" import datajoint as dj import matplotlib.path import numpy as np import pandas as pd -from aeon.dj_pipeline import acquisition, dict_to_uuid, get_schema_name, lab, qc, streams, fetch_stream +from aeon.dj_pipeline import acquisition, dict_to_uuid, fetch_stream, get_schema_name, lab, streams from aeon.io import api as io_api aeon_schemas = acquisition.aeon_schemas schema = dj.schema(get_schema_name("tracking")) +logger = dj.logger pixel_scale = 0.00192 # 1 px
= 1.92 mm arena_center_x, arena_center_y = 1.475, 1.075 # center @@ -72,6 +73,7 @@ def insert_new_params( params: dict, tracking_paramset_id: int = None, ): + """Insert a new set of parameters for a given tracking method.""" if tracking_paramset_id is None: tracking_paramset_id = (dj.U().aggr(cls, n="max(tracking_paramset_id)").fetch1("n") or 0) + 1 @@ -109,7 +111,9 @@ def insert_new_params( @schema class SLEAPTracking(dj.Imported): - definition = """ # Tracked objects position data from a particular VideoSource for multi-animal experiment using the SLEAP tracking method per chunk + """Tracking data from SLEAP for multi-animal experiments.""" + + definition = """ # Position data from a VideoSource for multi-animal experiments using SLEAP per chunk -> acquisition.Chunk -> streams.SpinnakerVideoSource -> TrackingParamSet @@ -139,6 +143,7 @@ class Part(dj.Part): @property def key_source(self): + """Return the keys to be processed.""" return ( acquisition.Chunk * ( @@ -151,6 +156,7 @@ def key_source(self): ) # SLEAP & CameraTop def make(self, key): + """Ingest SLEAP tracking data for a given chunk.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -164,14 +170,15 @@ def make(self, key): ), ) - stream_reader = getattr(getattr(devices_schema, device_name), "Pose") + stream_reader = getattr(devices_schema, device_name).Pose # special ingestion case for social0.2 full-pose data (using Pose reader from social03) # fullpose for social0.2 has a different "pattern" for non-fullpose, hence the Pose03 reader if key["experiment_name"].startswith("social0.2"): from aeon.io import reader as io_reader - stream_reader = getattr(getattr(devices_schema, device_name), "Pose03") - assert isinstance(stream_reader, io_reader.Pose), "Pose03 is not a Pose reader" + stream_reader = getattr(devices_schema, device_name).Pose03 + if not isinstance(stream_reader, io_reader.Pose): + raise TypeError("Pose03 is not a Pose reader") data_dirs = [acquisition.Experiment.get_data_directory(key, "processed")] pose_data = io_api.load( @@ -251,7 +258,7 @@ class BlobPosition(dj.Imported): class Object(dj.Part): definition = """ # Position data of object tracked by a particular camera tracking -> master - object_id: int # object with id = -1 means "unknown/not sure", could potentially be the same object as those with other id value + object_id: int # id=-1 means "unknown"; could be the same object as those with other values --- identity_name='': varchar(16) sample_count: int # number of data points acquired from this stream for a given chunk @@ -263,6 +270,7 @@ class Object(dj.Part): @property def key_source(self): + """Return the keys to be processed.""" ks = ( acquisition.Chunk * ( @@ -275,6 +283,7 @@ def key_source(self): return ks - SLEAPTracking # do this only when SLEAPTracking is not available def make(self, key): + """Ingest blob position data for a given chunk.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -353,13 +362,31 @@ def make(self, key): def compute_distance(position_df, target, xcol="x", ycol="y"): - assert len(target) == 2 + """Compute the distance between the position and the target. + + Args: + position_df (pd.DataFrame): DataFrame containing the position data. + target (tuple): Tuple of length 2 indicating the target x and y position. + xcol (str): x column name in ``position_df``. Default is 'x'. 
+ ycol (str): y column name in ``position_df``. Default is 'y'. + """ + COORDS = 2 # x, y + if len(target) != COORDS: + raise ValueError("Target must be a list of tuple of length 2.") return np.sqrt(np.square(position_df[[xcol, ycol]] - target).sum(axis=1)) def is_position_in_patch( position_df, patch_position, wheel_distance_travelled, patch_radius=0.2 ) -> pd.Series: + """Returns a boolean array of whether a given position is inside the patch and the wheel is moving. + + Args: + position_df (pd.DataFrame): DataFrame containing the position data. + patch_position (tuple): Tuple of length 2 indicating the patch x and y position. + wheel_distance_travelled (pd.Series): distance travelled by the wheel. + patch_radius (float): Radius of the patch. Default is 0.2. + """ distance_from_patch = compute_distance(position_df, patch_position) in_patch = distance_from_patch < patch_radius exit_patch = in_patch.astype(np.int8).diff() < 0 @@ -371,10 +398,17 @@ def is_position_in_patch( def is_position_in_nest(position_df, nest_key, xcol="x", ycol="y") -> pd.Series: - """Given the session key and the position data - arrays of x and y + """Check if a position is inside the nest. + + Notes: Given the session key and the position data - arrays of x and y return an array of boolean indicating whether or not a position is inside the nest. """ - nest_vertices = list(zip(*(lab.ArenaNest.Vertex & nest_key).fetch("vertex_x", "vertex_y"))) + nest_vertices = list( + zip( + *(lab.ArenaNest.Vertex & nest_key).fetch("vertex_x", "vertex_y"), + strict=False, + ) + ) nest_path = matplotlib.path.Path(nest_vertices) position_df["in_nest"] = nest_path.contains_points(position_df[[xcol, ycol]]) return position_df["in_nest"] @@ -392,6 +426,7 @@ def _get_position( attrs_to_scale: list, scale_factor=1.0, ): + """Get the position data for a given object between the specified time range.""" obj_restriction = {object_attr: object_name} start_restriction = f'"{start}" BETWEEN {start_attr} AND {end_attr}' @@ -419,7 +454,7 @@ def _get_position( position = pd.DataFrame( { k: np.hstack(v) * scale_factor if k in attrs_to_scale else np.hstack(v) - for k, v in zip(fetch_attrs, fetched_data) + for k, v in zip(fetch_attrs, fetched_data, strict=False) } ) position.set_index(timestamp_attr, inplace=True) diff --git a/aeon/dj_pipeline/utils/__init__.py b/aeon/dj_pipeline/utils/__init__.py index e69de29b..0bb46925 100644 --- a/aeon/dj_pipeline/utils/__init__.py +++ b/aeon/dj_pipeline/utils/__init__.py @@ -0,0 +1 @@ +"""Helper functions and utilities for the Aeon project.""" diff --git a/aeon/dj_pipeline/utils/load_metadata.py b/aeon/dj_pipeline/utils/load_metadata.py index ce0a248e..e30f91ca 100644 --- a/aeon/dj_pipeline/utils/load_metadata.py +++ b/aeon/dj_pipeline/utils/load_metadata.py @@ -1,9 +1,12 @@ +"""Load metadata from the experiment and insert into streams schema.""" + import datetime import inspect import json import pathlib from collections import defaultdict from pathlib import Path + import datajoint as dj import numpy as np from dotmap import DotMap @@ -35,13 +38,16 @@ def insert_stream_types(): existing_stream = (streams.StreamType.proj( "stream_reader", "stream_reader_kwargs") & {"stream_type": entry["stream_type"]}).fetch1() - if existing_stream["stream_reader_kwargs"].get("columns") != entry["stream_reader_kwargs"].get( - "columns"): + existing_columns = existing_stream["stream_reader_kwargs"].get("columns") + entry_columns = entry["stream_reader_kwargs"].get("columns") + if existing_columns != entry_columns: 
logger.warning(f"Stream type already exists:\n\t{entry}\n\t{existing_stream}") def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): - """Use aeon.schema.schemas and metadata.yml to insert into streams.DeviceType and streams.Device. + """Insert device types into streams.DeviceType and streams.Device. + + Notes: Use aeon.schema.schemas and metadata.yml to insert into streams.DeviceType and streams.Device. Only insert device types that were defined both in the device schema (e.g., exp02) and Metadata.yml. It then creates new device tables under streams schema. """ @@ -138,13 +144,15 @@ def extract_epoch_config(experiment_name: str, devices_schema: DotMap, metadata_ if isinstance(commit, float) and np.isnan(commit): commit = epoch_config["metadata"]["Revision"] - assert commit, f'Neither "Commit" nor "Revision" found in {metadata_yml_filepath}' + if not commit: + raise ValueError(f'Neither "Commit" nor "Revision" found in {metadata_yml_filepath}') devices: list[dict] = json.loads( json.dumps(epoch_config["metadata"]["Devices"], default=lambda x: x.__dict__, indent=4) ) - # Maintain backward compatibility - In exp02, it is a list of dict. From presocial onward, it's a dict of dict. + # Maintain backward compatibility - In exp02, it is a list of dict. + # From presocial onward, it's a dict of dict. if isinstance(devices, list): devices: dict = {d.pop("Name"): d for d in devices} # {deivce_name: device_config} @@ -199,7 +207,9 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath if not (streams.Device & device_key): logger.warning( - f"Device {device_name} (serial number: {device_sn}) is not yet registered in streams.Device.\nThis should not happen - check if metadata.yml and schemas dotmap are consistent. Skipping..." + f"Device {device_name} (serial number: {device_sn}) is not " + "yet registered in streams.Device.\nThis should not happen - " + "check if metadata.yml and schemas dotmap are consistent. Skipping..." ) # skip if this device (with a serial number) is not yet inserted in streams.Device continue @@ -231,7 +241,9 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath } ) - """Check if this device is currently installed. If the same device serial number is currently installed check for any changes in configuration. If not, skip this""" + # Check if this device is currently installed. + # If the same device serial number is currently installed check for changes in configuration. + # If not, skip this. current_device_query = table - table.RemovalTime & experiment_key & device_key if current_device_query: @@ -349,6 +361,7 @@ def get_device_info(devices_schema: DotMap) -> dict[dict]: """ def _get_class_path(obj): + """Returns the class path of the object.""" return f"{obj.__class__.__module__}.{obj.__class__.__name__}" schema_json = json.dumps(devices_schema, default=lambda x: x.__dict__, indent=4) @@ -401,7 +414,12 @@ def _get_class_path(obj): def get_device_mapper(devices_schema: DotMap, metadata_yml_filepath: Path): - """Returns a mapping dictionary between device name and device type based on the dataset schema and metadata.yml from the experiment. Store the mapper dictionary and read from it if the type info doesn't exist in Metadata.yml. + """Returns a mapping dictionary of device names to types based on the dataset schema and metadata.yml. + + Notes: Returns a mapping dictionary between device name and device type + based on the dataset schema and metadata.yml from the experiment.
+ Store the mapper dictionary and read from it if the type info doesn't + exist in Metadata.yml. Args: devices_schema (DotMap): DotMap object (e.g., exp02) @@ -441,7 +459,8 @@ def get_device_mapper(devices_schema: DotMap, metadata_yml_filepath: Path): device_type_mapper[item.Name] = item.Type device_sn[item.Name] = ( item.SerialNumber or item.PortName or None - ) # assign either the serial number (if it exists) or port name. If neither exists, assign None + ) # assign either the serial number (if it exists) or port name. + # If neither exists, assign None elif isinstance(item, str): # presocial if meta_data.Devices[item].get("Type"): device_type_mapper[item] = meta_data.Devices[item].get("Type") diff --git a/aeon/dj_pipeline/utils/paths.py b/aeon/dj_pipeline/utils/paths.py index 63a13a1f..75b459d4 100644 --- a/aeon/dj_pipeline/utils/paths.py +++ b/aeon/dj_pipeline/utils/paths.py @@ -1,3 +1,5 @@ +"""Utility functions for working with paths in the context of the DJ pipeline.""" + from __future__ import annotations import pathlib @@ -32,7 +34,7 @@ def get_repository_path(repository_name: str) -> pathlib.Path: def find_root_directory( root_directories: str | pathlib.Path, full_path: str | pathlib.Path ) -> pathlib.Path: - """Given multiple potential root directories and a full-path, search and return one directory that is the parent of the given path. + """Finds the parent directory of a given full path among multiple potential root directories. Args: root_directories (str | pathlib.Path): A list of potential root directories. @@ -51,7 +53,7 @@ def find_root_directory( raise FileNotFoundError(f"{full_path} does not exist!") # turn to list if only a single root directory is provided - if isinstance(root_directories, (str, pathlib.Path)): + if isinstance(root_directories, (str | pathlib.Path)): root_directories = [root_directories] try: @@ -61,7 +63,7 @@ def find_root_directory( if pathlib.Path(root_dir) in set(full_path.parents) ) - except StopIteration: + except StopIteration as err: raise FileNotFoundError( f"No valid root directory found (from {root_directories})" f" for {full_path}" - ) + ) from err diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index 01cd14e7..5a160b70 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -1,3 +1,5 @@ +"""Utility functions for plotting visit data.""" + import datajoint as dj import numpy as np import pandas as pd @@ -23,14 +25,16 @@ def plot_reward_rate_differences(subject_keys): - """Plotting the reward rate differences between food patches (Patch 2 - Patch 1) for all sessions from all subjects specified in "subject_keys". + """Plots the reward rate differences between two food patches (Patch 2 - Patch 1). - Examples: - ``` - subject_keys = (acquisition.Experiment.Subject & 'experiment_name = "exp0.1-r0"').fetch('KEY') + The reward rate differences between the two food patches are plotted + for all sessions from all subjects in ``subject_keys``. - fig = plot_reward_rate_differences(subject_keys) - ``` + Examples: + >>> subject_keys = ( + ... acquisition.Experiment.Subject + ... 
& 'experiment_name = "exp0.1-r0"').fetch('KEY') + >>> fig = plot_reward_rate_differences(subject_keys) """ subj_names, sess_starts, rate_timestamps, rate_diffs = ( analysis.InArenaRewardRate & subject_keys @@ -45,7 +49,7 @@ def plot_reward_rate_differences(subject_keys): y_labels = [ f'{subj_name}_{sess_start.strftime("%m/%d/%Y")}' - for subj_name, sess_start in zip(subj_names, sess_starts) + for subj_name, sess_start in zip(subj_names, sess_starts, strict=False) ] rateDiffs_matrix = np.full((nSessions, longest_rateDiff), np.nan) @@ -74,7 +78,7 @@ def plot_reward_rate_differences(subject_keys): def plot_wheel_travelled_distance(session_keys): - """Plotting the wheel travelled distance for different patches for all sessions specified in "session_keys". + """Plot wheel travelled distance for different patches for all sessions specified in session_keys. Examples: ``` @@ -98,7 +102,7 @@ def plot_wheel_travelled_distance(session_keys): distance_travelled_df["in_arena"] = [ f'{subj_name}_{sess_start.strftime("%m/%d/%Y")}' for subj_name, sess_start in zip( distance_travelled_df.subject, distance_travelled_df.in_arena_start, strict=False ) ] @@ -124,6 +128,7 @@ def plot_wheel_travelled_distance(session_keys): def plot_average_time_distribution(session_keys): + """Plots the average time spent in different regions.""" subject_list, arena_location_list, avg_time_spent_list = [], [], [] # Time spent in arena and corridor @@ -205,15 +210,21 @@ def plot_visit_daily_summary( Args: visit_key (dict) : Key from the VisitSummary table - attr (str): Name of the attribute to plot (e.g., 'pellet_count', 'wheel_distance_travelled', 'total_distance_travelled') - per_food_patch (bool, optional): Separately plot results from different food patches. Defaults to False. + attr (str): Name of the attribute to plot (e.g., 'pellet_count', + 'wheel_distance_travelled', 'total_distance_travelled') + per_food_patch (bool, optional): Separately plot results from + different food patches. Defaults to False. Returns: fig: Figure object Examples: >>> fig = plot_visit_daily_summary(visit_key, attr='pellet_count', per_food_patch=True) - >>> fig = plot_visit_daily_summary(visit_key, attr='wheel_distance_travelled', per_food_patch=True) + >>> fig = plot_visit_daily_summary( + ... visit_key, + ... attr="wheel_distance_travelled", + ... per_food_patch=True, + ... ) >>> fig = plot_visit_daily_summary(visit_key, attr='total_distance_travelled') """ per_food_patch = not attr.startswith("total") @@ -282,8 +293,10 @@ def plot_foraging_bouts_count( Args: visit_key (dict): Key from the Visit table - freq (str): Frequency level at which the visit time distribution is plotted. Corresponds to pandas freq. - per_food_patch (bool, optional): Separately plot results from different food patches. Defaults to False. + freq (str): Frequency level at which the visit time + distribution is plotted. Corresponds to pandas freq. + per_food_patch (bool, optional): Separately plot results from + different food patches. Defaults to False. min_bout_duration (int): Minimum foraging bout duration (in seconds) min_pellet_count (int): Minimum number of pellets min_wheel_dist (int): Minimum wheel distance travelled (in cm) @@ -292,7 +305,13 @@ def plot_foraging_bouts_count( fig: Figure object Examples: - >>> fig = plot_foraging_bouts_count(visit_key, freq="D", per_food_patch=True, min_bout_duration=1, min_wheel_dist=1) + >>> fig = plot_foraging_bouts_count( + ... visit_key, + ...
freq="D", + ... per_food_patch=True, + ... min_bout_duration=1, + ... min_wheel_dist=1 + ... ) """ # Get all foraging bouts for the visit foraging_bouts = ( @@ -373,8 +392,10 @@ def plot_foraging_bouts_distribution( Args: visit_key (dict): Key from the Visit table - attr (str): Options include: pellet_count, bout_duration, wheel_distance_travelled - per_food_patch (bool, optional): Separately plot results from different food patches. Defaults to False. + attr (str): Options include: pellet_count, bout_duration, + wheel_distance_travelled + per_food_patch (bool, optional): Separately plot results from + different food patches. Defaults to False. min_bout_duration (int): Minimum foraging bout duration (in seconds) min_pellet_count (int): Minimum number of pellets min_wheel_dist (int): Minimum wheel distance travelled (in cm) @@ -460,7 +481,8 @@ def plot_visit_time_distribution(visit_key, freq="D"): Args: visit_key (dict): Key from the Visit table - freq (str): Frequency level at which the visit time distribution is plotted. Corresponds to pandas freq. + freq (str): Frequency level at which the visit time distribution + is plotted. Corresponds to pandas freq. Returns: fig: Figure object @@ -511,16 +533,21 @@ def plot_visit_time_distribution(visit_key, freq="D"): return fig -def _get_region_data(visit_key, attrs=["in_nest", "in_arena", "in_corridor", "in_patch"]): +def _get_region_data(visit_key, attrs=None): """Retrieve region data from VisitTimeDistribution tables. Args: visit_key (dict): Key from the Visit table - attrs (list, optional): List of column names (in VisitTimeDistribution tables) to retrieve. Defaults to all. + attrs (list, optional): List of column names (in VisitTimeDistribution tables) to retrieve. + If unspecified, defaults to `None` and ``["in_nest", "in_arena", "in_corridor", "in_patch"]`` + is used. Returns: region (pd.DataFrame): Timestamped region info """ + if attrs is None: + attrs = ["in_nest", "in_arena", "in_corridor", "in_patch"] + visit_start, visit_end = (VisitEnd & visit_key).fetch1("visit_start", "visit_end") region = pd.DataFrame() diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index bfd669e9..f04af930 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -1,3 +1,5 @@ +"""Module for stream-related tables in the analysis schema.""" + import importlib import inspect import re @@ -22,12 +24,18 @@ class StreamType(dj.Lookup): - """Catalog of all steam types for the different device types used across Project Aeon. One StreamType corresponds to one reader class in `aeon.io.reader`. The combination of `stream_reader` and `stream_reader_kwargs` should fully specify the data loading routine for a particular device, using the `aeon.io.utils`.""" + """Catalog of all stream types used across Project Aeon. + + Catalog of all stream types for the different device types used across Project Aeon. + One StreamType corresponds to one Reader class in :mod:`aeon.io.reader`. + The combination of ``stream_reader`` and ``stream_reader_kwargs`` should fully specify the data + loading routine for a particular device, using :func:`aeon.io.api.load`. + """ - definition = """ # Catalog of all stream types used across Project Aeon + definition = """ # Catalog of all stream types used across Project Aeon stream_type : varchar(36) --- - stream_reader : varchar(256) # name of the reader class found in `aeon_mecha` package (e.g. 
aeon.io.reader.Video) + stream_reader : varchar(256) # reader class name in aeon.io.reader (e.g. aeon.io.reader.Video) stream_reader_kwargs : longblob # keyword arguments to instantiate the reader class stream_description='': varchar(256) stream_hash : uuid # hash of dict(stream_reader_kwargs, stream_reader=stream_reader) @@ -67,17 +75,16 @@ def get_device_template(device_type: str): device_type = dj.utils.from_camel_case(device_type) class ExperimentDevice(dj.Manual): - definition = f""" - # {device_title} placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-{aeon.__version__}) + definition = f"""# {device_title} operation for time, location, experiment (v{aeon.__version__}) -> acquisition.Experiment -> Device - {device_type}_install_time : datetime(6) # time of the {device_type} placed and started operation at this position + {device_type}_install_time : datetime(6) # {device_type} time of placement and start operation --- - {device_type}_name : varchar(36) + {device_type}_name : varchar(36) """ class Attribute(dj.Part): - definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device + definition = """ # Metadata (e.g. FPS, config, calibration) for this experimental device -> master attribute_name : varchar(32) --- @@ -106,8 +113,9 @@ def get_device_stream_template(device_type: str, stream_type: str, streams_modul & (streams_module.DeviceType.Stream & {"device_type": device_type, "stream_type": stream_type}) ).fetch1() - for i, n in enumerate(stream_detail["stream_reader"].split(".")): - reader = aeon if i == 0 else getattr(reader, n) + reader = aeon + for n in stream_detail["stream_reader"].split(".")[1:]: + reader = getattr(reader, n) if reader is aeon.io.reader.Pose: logger.warning("Automatic generation of stream table for Pose reader is not supported. Skipping...") @@ -115,7 +123,8 @@ def get_device_stream_template(device_type: str, stream_type: str, streams_modul stream = reader(**stream_detail["stream_reader_kwargs"]) - table_definition = f""" # Raw per-chunk {stream_type} data stream from {device_type} (auto-generated with aeon_mecha-{aeon.__version__}) + ver = aeon.__version__ + table_definition = f""" # Raw per-chunk {stream_type} from {device_type} (auto-generated with v{ver}) -> {device_type} -> acquisition.Chunk --- @@ -134,18 +143,21 @@ class DeviceDataStream(dj.Imported): @property def key_source(self): - f""" - Only the combination of Chunk and {device_type} with overlapping time - + Chunk(s) that started after {device_type} install time and ended before {device_type} remove time - + Chunk(s) that started after {device_type} install time for {device_type} that are not yet removed + docstring = f"""Only the combination of Chunk and {device_type} with overlapping time.
+ + Chunk(s) started after {device_type} install time & ended before {device_type} remove time + + Chunk(s) started after {device_type} install time for {device_type} and not yet removed """ + self.__doc__ = docstring + device_type_name = dj.utils.from_camel_case(device_type) return ( acquisition.Chunk * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) - & f"chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time" - & f'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, "2200-01-01")' + & f"chunk_start >= {device_type_name}_install_time" + & f'chunk_start < IFNULL({device_type_name}_removal_time,"2200-01-01")' ) def make(self, key): + """Load and insert the data for the DeviceDataStream table.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -189,6 +201,7 @@ def make(self, key): def main(create_tables=True): + """Main function to create and update stream-related tables in the analysis schema.""" if not _STREAMS_MODULE_FILE.exists(): with open(_STREAMS_MODULE_FILE, "w") as f: imports_str = ( @@ -259,12 +272,19 @@ def main(create_tables=True): device_stream_table_def = inspect.getsource(table_class).lstrip() # Replace the definition + device_type_name = dj.utils.from_camel_case(device_type) replacements = { "DeviceDataStream": f"{device_type}{stream_type}", "ExperimentDevice": device_type, - 'f"chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time"': f"'chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time'", - """f'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, "2200-01-01")'""": f"""'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, "2200-01-01")'""", - 'f"{dj.utils.from_camel_case(device_type)}_name"': f"'{dj.utils.from_camel_case(device_type)}_name'", + 'f"chunk_start >= {device_type_name}_install_time"': ( + f"'chunk_start >= {device_type_name}_install_time'" + ), + """f'chunk_start < IFNULL({device_type_name}_removal_time,"2200-01-01")'""": ( + f"""'chunk_start < IFNULL({device_type_name}_removal_time,"2200-01-01")'""" + ), + 'f"{device_type_name}_name"': ( + f"'{device_type_name}_name'" + ), "{device_type}": device_type, "{stream_type}": stream_type, "{aeon.__version__}": aeon.__version__, diff --git a/aeon/dj_pipeline/utils/video.py b/aeon/dj_pipeline/utils/video.py index 63b64f24..b9bf40e8 100644 --- a/aeon/dj_pipeline/utils/video.py +++ b/aeon/dj_pipeline/utils/video.py @@ -1,3 +1,5 @@ +"""Utility functions for video processing.""" + import base64 from pathlib import Path @@ -22,8 +24,8 @@ ): """Retrive video trames from the raw data directory.""" raw_data_dir = Path(raw_data_dir) - assert raw_data_dir.exists() - + if not raw_data_dir.exists(): + raise FileNotFoundError(f"The specified raw data directory does not exist: {raw_data_dir}") # Load video data videodata = io_api.load( root=raw_data_dir.as_posix(), diff --git a/aeon/dj_pipeline/webapps/sciviz/specsheet.yaml b/aeon/dj_pipeline/webapps/sciviz/specsheet.yaml index 629ee571..7c8793ec 100644 --- a/aeon/dj_pipeline/webapps/sciviz/specsheet.yaml +++ b/aeon/dj_pipeline/webapps/sciviz/specsheet.yaml @@ -850,7 +850,7 @@ SciViz: def dj_query(aeon_block_analysis): aeon_analysis = aeon_block_analysis return {'query': aeon_block_analysis.BlockSubjectPositionPlots(), 'fetch_args': ['position_plot']} - + VideoStream: route: /videostream diff --git a/aeon/io/__init__.py
b/aeon/io/__init__.py index e69de29b..e23efc36 100644 --- a/aeon/io/__init__.py +++ b/aeon/io/__init__.py @@ -0,0 +1 @@ +"""Utilities for I/O operations.""" diff --git a/aeon/io/api.py b/aeon/io/api.py index 5d505ea6..22a11ce7 100644 --- a/aeon/io/api.py +++ b/aeon/io/api.py @@ -1,3 +1,5 @@ +"""API for reading Aeon data from disk.""" + import bisect import datetime from os import PathLike @@ -73,7 +75,7 @@ def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=No :param datetime, optional end: The right bound of the time range to extract. :param datetime, optional time: An object or series specifying the timestamps to extract. :param datetime, optional tolerance: - The maximum distance between original and new timestamps for inexact matches. + The maximum distance between original and new timestamps for inexact matches. :param str, optional epoch: A wildcard pattern to use when searching epoch data. :param optional kwargs: Optional keyword arguments to forward to the reader when reading chunk data. :return: A pandas data frame containing epoch event metadata, sorted by time. diff --git a/aeon/io/device.py b/aeon/io/device.py index d7707fb0..8e473662 100644 --- a/aeon/io/device.py +++ b/aeon/io/device.py @@ -1,3 +1,5 @@ +"""Deprecated Device class for grouping multiple Readers into a logical device.""" + import inspect from typing_extensions import deprecated @@ -33,10 +35,12 @@ class Device: """ def __init__(self, name, *args, pattern=None): + """Initializes the Device class.""" self.name = name self.registry = compositeStream(name if pattern is None else pattern, *args) def __iter__(self): + """Iterates over the device registry.""" if len(self.registry) == 1: singleton = self.registry.get(self.name, None) if singleton: diff --git a/aeon/io/reader.py b/aeon/io/reader.py index 16af3096..dbf574ec 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -1,3 +1,5 @@ +"""Module for reading data from raw files in an Aeon dataset.""" + from __future__ import annotations import datetime @@ -38,6 +40,7 @@ class Reader: """ def __init__(self, pattern, columns, extension): + """Initialize the object with specified pattern, columns, and file extension.""" self.pattern = pattern self.columns = columns self.extension = extension @@ -51,6 +54,7 @@ class Harp(Reader): """Extracts data from raw binary files encoded using the Harp protocol.""" def __init__(self, pattern, columns, extension="bin"): + """Initialize the object.""" super().__init__(pattern, columns, extension) def read(self, file): @@ -87,6 +91,7 @@ class Chunk(Reader): """Extracts path and epoch information from chunk files in the dataset.""" def __init__(self, reader=None, pattern=None, extension=None): + """Initialize the object with optional reader, pattern, and file extension.""" if isinstance(reader, Reader): pattern = reader.pattern extension = reader.extension @@ -103,6 +108,7 @@ class Metadata(Reader): """Extracts metadata information from all epochs in the dataset.""" def __init__(self, pattern="Metadata"): + """Initialize the object with the specified pattern.""" super().__init__(pattern, columns=["workflow", "commit", "metadata"], extension="yml") def read(self, file): @@ -125,6 +131,7 @@ class Csv(Reader): """ def __init__(self, pattern, columns, dtype=None, extension="csv"): + """Initialize the object with the specified pattern, columns, and data type.""" super().__init__(pattern, columns, extension) self.dtype = dtype @@ -140,22 +147,21 @@ def read(self, file): class JsonList(Reader): - """Extracts data 
from json list (.jsonl) files, where the key "seconds" - stores the Aeon timestamp, in seconds. - """ + """Extracts data from .jsonl files, where the key "seconds" stores the Aeon timestamp (s).""" def __init__(self, pattern, columns=(), root_key="value", extension="jsonl"): + """Initialize the object with the specified pattern, columns, and root key.""" super().__init__(pattern, columns, extension) self.columns = columns self.root_key = root_key def read(self, file): """Reads data from the specified jsonl file.""" - with open(file, "r") as f: + with open(file) as f: df = pd.read_json(f, lines=True) df.set_index("seconds", inplace=True) for column in self.columns: - df[column] = df[self.root_key].apply(lambda x: x[column]) + df[column] = df[self.root_key].apply(lambda x: x[column]) # noqa B023 return df @@ -163,13 +169,15 @@ class Subject(Csv): """Extracts metadata for subjects entering and exiting the environment. Columns: - id (str): Unique identifier of a subject in the environment. - weight (float): Weight measurement of the subject on entering - or exiting the environment. - event (str): Event type. Can be one of `Enter`, `Exit` or `Remain`. + + - id (str): Unique identifier of a subject in the environment. + - weight (float): Weight measurement of the subject on entering + or exiting the environment. + - event (str): Event type. Can be one of `Enter`, `Exit` or `Remain`. """ def __init__(self, pattern): + """Initialize the object with a specified pattern.""" super().__init__(pattern, columns=["id", "weight", "event"]) @@ -177,13 +185,15 @@ class Log(Csv): """Extracts message log data. Columns: - priority (str): Priority level of the message. - type (str): Type of the log message. - message (str): Log message data. Can be structured using tab - separated values. + + - priority (str): Priority level of the message. + - type (str): Type of the log message. + - message (str): Log message data. Can be structured using tab + separated values. """ def __init__(self, pattern): + """Initialize the object with a specified pattern and columns.""" super().__init__(pattern, columns=["priority", "type", "message"]) @@ -191,10 +201,12 @@ class Heartbeat(Harp): """Extract periodic heartbeat event data. Columns: - second (int): The whole second corresponding to the heartbeat, in seconds. + + - second (int): The whole second corresponding to the heartbeat, in seconds. """ def __init__(self, pattern): + """Initialize the object with a specified pattern.""" super().__init__(pattern, columns=["second"]) @@ -202,11 +214,13 @@ class Encoder(Harp): """Extract magnetic encoder data. Columns: - angle (float): Absolute angular position, in radians, of the magnetic encoder. - intensity (float): Intensity of the magnetic field. + + - angle (float): Absolute angular position, in radians, of the magnetic encoder. + - intensity (float): Intensity of the magnetic field. """ def __init__(self, pattern): + """Initialize the object with a specified pattern and columns.""" super().__init__(pattern, columns=["angle", "intensity"]) @@ -214,18 +228,20 @@ class Position(Harp): """Extract 2D position tracking data for a specific camera. Columns: - x (float): x-coordinate of the object center of mass. - y (float): y-coordinate of the object center of mass. - angle (float): angle, in radians, of the ellipse fit to the object. - major (float): length, in pixels, of the major axis of the ellipse - fit to the object. - minor (float): length, in pixels, of the minor axis of the ellipse - fit to the object. 
- area (float): number of pixels in the object mass. - id (float): unique tracking ID of the object in a frame. + + - x (float): x-coordinate of the object center of mass. + - y (float): y-coordinate of the object center of mass. + - angle (float): angle, in radians, of the ellipse fit to the object. + - major (float): length, in pixels, of the major axis of the ellipse + fit to the object. + - minor (float): length, in pixels, of the minor axis of the ellipse + fit to the object. + - area (float): number of pixels in the object mass. + - id (float): unique tracking ID of the object in a frame. """ def __init__(self, pattern): + """Initialize the object with a specified pattern and columns.""" super().__init__(pattern, columns=["x", "y", "angle", "major", "minor", "area", "id"]) @@ -233,10 +249,12 @@ class BitmaskEvent(Harp): """Extracts event data matching a specific digital I/O bitmask. Columns: - event (str): Unique identifier for the event code. + + - event (str): Unique identifier for the event code. """ def __init__(self, pattern, value, tag): + """Initialize the object with specified pattern, value, and tag.""" super().__init__(pattern, columns=["event"]) self.value = value self.tag = tag @@ -256,10 +274,12 @@ class DigitalBitmask(Harp): """Extracts event data matching a specific digital I/O bitmask. Columns: - event (str): Unique identifier for the event code. + + - event (str): Unique identifier for the event code. """ def __init__(self, pattern, mask, columns): + """Initialize the object with specified pattern, mask, and columns.""" super().__init__(pattern, columns) self.mask = mask @@ -277,11 +297,13 @@ class Video(Csv): """Extracts video frame metadata. Columns: - hw_counter (int): Hardware frame counter value for the current frame. - hw_timestamp (int): Internal camera timestamp for the current frame. + + - hw_counter (int): Hardware frame counter value for the current frame. + - hw_timestamp (int): Internal camera timestamp for the current frame. """ def __init__(self, pattern): + """Initialize the object with a specified pattern.""" super().__init__(pattern, columns=["hw_counter", "hw_timestamp", "_frame", "_path", "_epoch"]) self._rawcolumns = ["time"] + self.columns[0:2] @@ -299,12 +321,13 @@ class Pose(Harp): """Reader for Harp-binarized tracking data given a model that outputs id, parts, and likelihoods. Columns: - class (int): Int ID of a subject in the environment. - class_likelihood (float): Likelihood of the subject's identity. - part (str): Bodypart on the subject. - part_likelihood (float): Likelihood of the specified bodypart. - x (float): X-coordinate of the bodypart. - y (float): Y-coordinate of the bodypart. + + - class (int): Int ID of a subject in the environment. + - class_likelihood (float): Likelihood of the subject's identity. + - part (str): Bodypart on the subject. + - part_likelihood (float): Likelihood of the specified bodypart. + - x (float): X-coordinate of the bodypart. + - y (float): Y-coordinate of the bodypart. 
""" def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/processed"): @@ -387,7 +410,8 @@ def read(self, file: Path) -> pd.DataFrame: if bonsai_sleap_v == BONSAI_SLEAP_V3: # combine all identity_likelihood cols into a single col as dict part_data["identity_likelihood"] = part_data.apply( - lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, axis=1 + lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, + axis=1, ) part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True) part_data = part_data[ # reorder columns diff --git a/aeon/io/video.py b/aeon/io/video.py index 26c49827..bd5326fa 100644 --- a/aeon/io/video.py +++ b/aeon/io/video.py @@ -1,3 +1,5 @@ +"""Module for reading and writing video files using OpenCV.""" + import cv2 @@ -5,10 +7,10 @@ def frames(data): """Extracts the raw frames corresponding to the provided video metadata. :param DataFrame data: - A pandas DataFrame where each row specifies video acquisition path and frame number. + A pandas DataFrame where each row specifies video acquisition path and frame number. :return: - An object to iterate over numpy arrays for each row in the DataFrame, - containing the raw video frame data. + An object to iterate over numpy arrays for each row in the DataFrame, + containing the raw video frame data. """ capture = None filename = None @@ -42,7 +44,7 @@ def export(frames, file, fps, fourcc=None): :param str file: The path to the exported video file. :param fps: The frame rate of the exported video. :param optional fourcc: - Specifies the four character code of the codec used to compress the frames. + Specifies the four character code of the codec used to compress the frames. """ writer = None try: diff --git a/aeon/schema/__init__.py b/aeon/schema/__init__.py index e69de29b..bbce21ee 100644 --- a/aeon/schema/__init__.py +++ b/aeon/schema/__init__.py @@ -0,0 +1 @@ +"""Utilities for the schemas.""" diff --git a/aeon/schema/core.py b/aeon/schema/core.py index 6f70c8b4..02703e74 100644 --- a/aeon/schema/core.py +++ b/aeon/schema/core.py @@ -1,3 +1,5 @@ +"""Schema definition for core Harp data streams.""" + import aeon.io.reader as _reader from aeon.schema.streams import Stream, StreamGroup @@ -6,6 +8,7 @@ class Heartbeat(Stream): """Heartbeat event for Harp devices.""" def __init__(self, pattern): + """Initializes the Heartbeat stream.""" super().__init__(_reader.Heartbeat(f"{pattern}_8_*")) @@ -13,6 +16,7 @@ class Video(Stream): """Video frame metadata.""" def __init__(self, pattern): + """Initializes the Video stream.""" super().__init__(_reader.Video(f"{pattern}_*")) @@ -20,6 +24,7 @@ class Position(Stream): """Position tracking data for the specified camera.""" def __init__(self, pattern): + """Initializes the Position stream.""" super().__init__(_reader.Position(f"{pattern}_200_*")) @@ -27,6 +32,7 @@ class Encoder(Stream): """Wheel magnetic encoder data.""" def __init__(self, pattern): + """Initializes the Encoder stream.""" super().__init__(_reader.Encoder(f"{pattern}_90_*")) @@ -34,6 +40,7 @@ class Environment(StreamGroup): """Metadata for environment mode and subjects.""" def __init__(self, pattern): + """Initializes the Environment stream group.""" super().__init__(pattern, EnvironmentState, SubjectState) @@ -41,6 +48,7 @@ class EnvironmentState(Stream): """Environment state log.""" def __init__(self, pattern): + """Initializes the EnvironmentState stream.""" super().__init__(_reader.Csv(f"{pattern}_EnvironmentState_*", ["state"])) @@ -48,6 
+56,7 @@ class SubjectState(Stream): """Subject state log.""" def __init__(self, pattern): + """Initialises the SubjectState stream.""" super().__init__(_reader.Subject(f"{pattern}_SubjectState_*")) @@ -55,6 +64,7 @@ class MessageLog(Stream): """Message log data.""" def __init__(self, pattern): + """Initializes the MessageLog stream.""" super().__init__(_reader.Log(f"{pattern}_MessageLog_*")) @@ -62,4 +72,5 @@ class Metadata(Stream): """Metadata for acquisition epochs.""" def __init__(self, pattern): + """Initializes the Metadata stream.""" super().__init__(_reader.Metadata(pattern)) diff --git a/aeon/schema/foraging.py b/aeon/schema/foraging.py index 0eaf593c..85601db6 100644 --- a/aeon/schema/foraging.py +++ b/aeon/schema/foraging.py @@ -1,3 +1,5 @@ +"""Schema definition for foraging experiments.""" + from enum import Enum import pandas as pd @@ -18,6 +20,7 @@ class Area(Enum): class _RegionReader(_reader.Harp): def __init__(self, pattern): + """Initializes the RegionReader class.""" super().__init__(pattern, columns=["region"]) def read(self, file): @@ -37,6 +40,7 @@ class _PatchState(_reader.Csv): """ def __init__(self, pattern): + """Initializes the PatchState class.""" super().__init__(pattern, columns=["threshold", "d1", "delta"]) @@ -50,6 +54,7 @@ class _Weight(_reader.Harp): """ def __init__(self, pattern): + """Initializes the Weight class.""" super().__init__(pattern, columns=["value", "stable"]) @@ -57,6 +62,7 @@ class Region(Stream): """Region tracking data for the specified camera.""" def __init__(self, pattern): + """Initializes the Region stream.""" super().__init__(_RegionReader(f"{pattern}_201_*")) @@ -64,6 +70,7 @@ class DepletionFunction(Stream): """State of the linear depletion function for foraging patches.""" def __init__(self, pattern): + """Initializes the DepletionFunction stream.""" super().__init__(_PatchState(f"{pattern}_State_*")) @@ -71,6 +78,7 @@ class Feeder(StreamGroup): """Feeder commands and events.""" def __init__(self, pattern): + """Initializes the Feeder stream group.""" super().__init__(pattern, BeamBreak, DeliverPellet) @@ -78,6 +86,7 @@ class BeamBreak(Stream): """Beam break events for pellet detection.""" def __init__(self, pattern): + """Initializes the BeamBreak stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_32_*", 0x22, "PelletDetected")) @@ -85,6 +94,7 @@ class DeliverPellet(Stream): """Pellet delivery commands.""" def __init__(self, pattern): + """Initializes the DeliverPellet stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x01, "TriggerPellet")) @@ -92,6 +102,7 @@ class Patch(StreamGroup): """Data streams for a patch.""" def __init__(self, pattern): + """Initializes the Patch stream group.""" super().__init__(pattern, DepletionFunction, _stream.Encoder, Feeder) @@ -99,6 +110,7 @@ class Weight(StreamGroup): """Weight measurement data streams for a specific nest.""" def __init__(self, pattern): + """Initializes the Weight stream group.""" super().__init__(pattern, WeightRaw, WeightFiltered, WeightSubject) @@ -106,6 +118,7 @@ class WeightRaw(Stream): """Raw weight measurement for a specific nest.""" def __init__(self, pattern): + """Initializes the WeightRaw stream.""" super().__init__(_Weight(f"{pattern}_200_*")) @@ -113,6 +126,7 @@ class WeightFiltered(Stream): """Filtered weight measurement for a specific nest.""" def __init__(self, pattern): + """Initializes the WeightFiltered stream.""" super().__init__(_Weight(f"{pattern}_202_*")) @@ -120,6 +134,7 @@ class WeightSubject(Stream): """Subject weight 
measurement for a specific nest.""" def __init__(self, pattern): + """Initializes the WeightSubject stream.""" super().__init__(_Weight(f"{pattern}_204_*")) @@ -127,4 +142,5 @@ class SessionData(Stream): """Session metadata for Experiment 0.1.""" def __init__(self, pattern): + """Initializes the SessionData stream.""" super().__init__(_reader.Csv(f"{pattern}_2*", columns=["id", "weight", "event"])) diff --git a/aeon/schema/ingestion_schemas.py b/aeon/schema/ingestion_schemas.py index fe2ee3dd..4a2d45d7 100644 --- a/aeon/schema/ingestion_schemas.py +++ b/aeon/schema/ingestion_schemas.py @@ -6,7 +6,8 @@ import aeon.schema.core as stream from aeon.io import reader -from aeon.io.api import aeon as aeon_time, chunk as aeon_chunk +from aeon.io.api import aeon as aeon_time +from aeon.io.api import chunk as aeon_chunk from aeon.schema import foraging, octagon, social_01, social_02, social_03 from aeon.schema.streams import Device, Stream, StreamGroup @@ -26,7 +27,8 @@ def read(self, file: PathLike[str], sr_hz: int = 50) -> pd.DataFrame: freq = 1 / sr_hz * 1e3 # convert to ms if first_index is not None: chunk_origin = aeon_chunk(first_index) - data = data.resample(f"{freq}ms", origin=chunk_origin).first() # take first sample in each resampled bin + data = data.resample(f"{freq}ms", origin=chunk_origin).first() + # take first sample in each resampled bin return data @@ -214,7 +216,10 @@ def __init__(self, path): social04 = DotMap( [ Device("Metadata", stream.Metadata), - Device("Environment", social_02.Environment, social_02.SubjectData, social_03.EnvironmentActiveConfiguration), + Device("Environment", + social_02.Environment, + social_02.SubjectData, + social_03.EnvironmentActiveConfiguration), Device("CameraTop", Video, stream.Position, social_03.Pose), Device("CameraNorth", Video), Device("CameraSouth", Video), diff --git a/aeon/schema/octagon.py b/aeon/schema/octagon.py index 2ea85b4e..ae085abe 100644 --- a/aeon/schema/octagon.py +++ b/aeon/schema/octagon.py @@ -1,42 +1,55 @@ +"""Schema definition for octagon experiments-specific data streams.""" + import aeon.io.reader as _reader from aeon.schema.streams import Stream, StreamGroup class Photodiode(Stream): def __init__(self, path): + """Initializes the Photodiode stream.""" super().__init__(_reader.Harp(f"{path}_44_*", columns=["adc", "encoder"])) class OSC(StreamGroup): def __init__(self, path): + """Initializes the OSC stream group.""" super().__init__(path) class BackgroundColor(Stream): def __init__(self, pattern): + """Initializes the BackgroundColor stream.""" super().__init__( _reader.Csv(f"{pattern}_backgroundcolor_*", columns=["typetag", "r", "g", "b", "a"]) ) class ChangeSubjectState(Stream): def __init__(self, pattern): + """Initializes the ChangeSubjectState stream.""" super().__init__( - _reader.Csv(f"{pattern}_changesubjectstate_*", columns=["typetag", "id", "weight", "event"]) + _reader.Csv( + f"{pattern}_changesubjectstate_*", + columns=["typetag", "id", "weight", "event"], + ) ) class EndTrial(Stream): def __init__(self, pattern): + """Initialises the EndTrial stream.""" super().__init__(_reader.Csv(f"{pattern}_endtrial_*", columns=["typetag", "value"])) class Slice(Stream): def __init__(self, pattern): + """Initialises the Slice.""" super().__init__( _reader.Csv( - f"{pattern}_octagonslice_*", columns=["typetag", "wall_id", "r", "g", "b", "a", "delay"] + f"{pattern}_octagonslice_*", + columns=["typetag", "wall_id", "r", "g", "b", "a", "delay"], ) ) class GratingsSlice(Stream): def __init__(self, pattern): + """Initialises the 
GratingsSlice stream.""" super().__init__( _reader.Csv( f"{pattern}_octagongratingsslice_*", @@ -55,6 +68,7 @@ def __init__(self, pattern): class Poke(Stream): def __init__(self, pattern): + """Initializes the Poke class.""" super().__init__( _reader.Csv( f"{pattern}_poke_*", @@ -72,6 +86,7 @@ def __init__(self, pattern): class Response(Stream): def __init__(self, pattern): + """Initialises the Response class.""" super().__init__( _reader.Csv( f"{pattern}_response_*", columns=["typetag", "wall_id", "poke_id", "response_time"] @@ -80,6 +95,7 @@ def __init__(self, pattern): class RunPreTrialNoPoke(Stream): def __init__(self, pattern): + """Initialises the RunPreTrialNoPoke class.""" super().__init__( _reader.Csv( f"{pattern}_run_pre_no_poke_*", @@ -96,102 +112,127 @@ def __init__(self, pattern): class StartNewSession(Stream): def __init__(self, pattern): + """Initializes the StartNewSession class.""" super().__init__(_reader.Csv(f"{pattern}_startnewsession_*", columns=["typetag", "path"])) class TaskLogic(StreamGroup): def __init__(self, path): + """Initialises the TaskLogic stream group.""" super().__init__(path) class TrialInitiation(Stream): def __init__(self, pattern): + """Initializes the TrialInitiation stream.""" super().__init__(_reader.Harp(f"{pattern}_1_*", columns=["trial_type"])) class Response(Stream): def __init__(self, pattern): + """Initializes the Response stream.""" super().__init__(_reader.Harp(f"{pattern}_2_*", columns=["wall_id", "poke_id"])) class PreTrialState(Stream): def __init__(self, pattern): + """Initializes the PreTrialState stream.""" super().__init__(_reader.Harp(f"{pattern}_3_*", columns=["state"])) class InterTrialInterval(Stream): def __init__(self, pattern): + """Initializes the InterTrialInterval stream.""" super().__init__(_reader.Harp(f"{pattern}_4_*", columns=["state"])) class SliceOnset(Stream): def __init__(self, pattern): + """Initializes the SliceOnset stream.""" super().__init__(_reader.Harp(f"{pattern}_10_*", columns=["wall_id"])) class DrawBackground(Stream): def __init__(self, pattern): + """Initializes the DrawBackground stream.""" super().__init__(_reader.Harp(f"{pattern}_11_*", columns=["state"])) class GratingsSliceOnset(Stream): def __init__(self, pattern): + """Initializes the GratingsSliceOnset stream.""" super().__init__(_reader.Harp(f"{pattern}_12_*", columns=["wall_id"])) class Wall(StreamGroup): def __init__(self, path): + """Initialises the Wall stream group.""" super().__init__(path) class BeamBreak0(Stream): def __init__(self, pattern): + """Initialises the BeamBreak0 stream.""" super().__init__(_reader.DigitalBitmask(f"{pattern}_32_*", 0x1, columns=["state"])) class BeamBreak1(Stream): def __init__(self, pattern): + """Initialises the BeamBreak1 stream.""" super().__init__(_reader.DigitalBitmask(f"{pattern}_32_*", 0x2, columns=["state"])) class BeamBreak2(Stream): def __init__(self, pattern): + """Initialises the BeamBreak2 stream.""" super().__init__(_reader.DigitalBitmask(f"{pattern}_32_*", 0x4, columns=["state"])) class SetLed0(Stream): def __init__(self, pattern): + """Initialises the SetLed0 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_34_*", 0x1, "Set")) class SetLed1(Stream): def __init__(self, pattern): + """Initialises the SetLed1 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_34_*", 0x2, "Set")) class SetLed2(Stream): def __init__(self, pattern): + """Initialises the SetLed2 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_34_*", 0x4, "Set")) class SetValve0(Stream): def 
__init__(self, pattern): + """Initialises the SetValve0 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_34_*", 0x8, "Set")) class SetValve1(Stream): def __init__(self, pattern): + """Initialises the SetValve1 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_34_*", 0x10, "Set")) class SetValve2(Stream): def __init__(self, pattern): + """Initialises the SetValve2 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_34_*", 0x20, "Set")) class ClearLed0(Stream): def __init__(self, pattern): + """Initialises the ClearLed0 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x1, "Clear")) class ClearLed1(Stream): def __init__(self, pattern): + """Initializes the ClearLed1 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x2, "Clear")) class ClearLed2(Stream): def __init__(self, pattern): + """Initializes the ClearLed2 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x4, "Clear")) class ClearValve0(Stream): def __init__(self, pattern): + """Initializes the ClearValve0 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x8, "Clear")) class ClearValve1(Stream): def __init__(self, pattern): + """Initializes the ClearValve1 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x10, "Clear")) class ClearValve2(Stream): def __init__(self, pattern): + """Initializes the ClearValve2 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x20, "Clear")) diff --git a/aeon/schema/schemas.py b/aeon/schema/schemas.py index 0da2f1bf..d62dc6ec 100644 --- a/aeon/schema/schemas.py +++ b/aeon/schema/schemas.py @@ -1,3 +1,5 @@ +"""Schemas for different experiments.""" + from dotmap import DotMap import aeon.schema.core as stream @@ -115,7 +117,12 @@ social03 = DotMap( [ Device("Metadata", stream.Metadata), - Device("Environment", social_02.Environment, social_02.SubjectData, social_03.EnvironmentActiveConfiguration), + Device( + "Environment", + social_02.Environment, + social_02.SubjectData, + social_03.EnvironmentActiveConfiguration, + ), Device("CameraTop", stream.Video, social_03.Pose), Device("CameraNorth", stream.Video), Device("CameraSouth", stream.Video), @@ -146,7 +153,12 @@ social04 = DotMap( [ Device("Metadata", stream.Metadata), - Device("Environment", social_02.Environment, social_02.SubjectData, social_03.EnvironmentActiveConfiguration), + Device( + "Environment", + social_02.Environment, + social_02.SubjectData, + social_03.EnvironmentActiveConfiguration, + ), Device("CameraTop", stream.Video, social_03.Pose), Device("CameraNorth", stream.Video), Device("CameraSouth", stream.Video), diff --git a/aeon/schema/social_01.py b/aeon/schema/social_01.py index 7f6e2ab0..3230e1aa 100644 --- a/aeon/schema/social_01.py +++ b/aeon/schema/social_01.py @@ -1,9 +1,12 @@ +"""Schema definition for social_01 experiments-specific data streams.""" + import aeon.io.reader as _reader from aeon.schema.streams import Stream class RfidEvents(Stream): def __init__(self, path): + """Initializes the RfidEvents stream.""" path = path.replace("Rfid", "") if path.startswith("Events"): path = path.replace("Events", "") @@ -13,4 +16,5 @@ def __init__(self, path): class Pose(Stream): def __init__(self, path): + """Initializes the Pose stream.""" super().__init__(_reader.Pose(f"{path}_node-0*")) diff --git a/aeon/schema/social_02.py b/aeon/schema/social_02.py index 9b50cf60..0df58e32 100644 --- a/aeon/schema/social_02.py +++ b/aeon/schema/social_02.py @@ -1,3 +1,5 @@ +"""Schema definition for social_02 
experiment-specific data streams.""" + import aeon.io.reader as _reader from aeon.schema import core, foraging from aeon.schema.streams import Stream, StreamGroup @@ -5,18 +7,21 @@ class Environment(StreamGroup): def __init__(self, path): + """Initializes the Environment stream group.""" super().__init__(path) EnvironmentState = core.EnvironmentState class BlockState(Stream): def __init__(self, path): + """Initializes the BlockState stream.""" super().__init__( _reader.Csv(f"{path}_BlockState_*", columns=["pellet_ct", "pellet_ct_thresh", "due_time"]) ) class LightEvents(Stream): def __init__(self, path): + """Initializes the LightEvents stream.""" super().__init__(_reader.Csv(f"{path}_LightEvents_*", columns=["channel", "value"])) MessageLog = core.MessageLog @@ -24,52 +29,63 @@ def __init__(self, path): class SubjectData(StreamGroup): def __init__(self, path): + """Initializes the SubjectData stream group.""" super().__init__(path) class SubjectState(Stream): def __init__(self, path): + """Initializes the SubjectState stream.""" super().__init__(_reader.Csv(f"{path}_SubjectState_*", columns=["id", "weight", "type"])) class SubjectVisits(Stream): def __init__(self, path): + """Initializes the SubjectVisits stream.""" super().__init__(_reader.Csv(f"{path}_SubjectVisits_*", columns=["id", "type", "region"])) class SubjectWeight(Stream): def __init__(self, path): + """Initializes the SubjectWeight stream.""" super().__init__( _reader.Csv( - f"{path}_SubjectWeight_*", columns=["weight", "confidence", "subject_id", "int_id"] + f"{path}_SubjectWeight_*", + columns=["weight", "confidence", "subject_id", "int_id"], ) ) class Pose(Stream): def __init__(self, path): + """Initializes the Pose stream.""" super().__init__(_reader.Pose(f"{path}_test-node1*")) class Pose03(Stream): def __init__(self, path): + """Initializes the Pose03 stream.""" super().__init__(_reader.Pose(f"{path}_202_*")) class WeightRaw(Stream): def __init__(self, path): + """Initializes the WeightRaw stream.""" super().__init__(_reader.Harp(f"{path}_200_*", ["weight(g)", "stability"])) class WeightFiltered(Stream): def __init__(self, path): + """Initializes the WeightFiltered stream.""" super().__init__(_reader.Harp(f"{path}_202_*", ["weight(g)", "stability"])) class Patch(StreamGroup): def __init__(self, path): + """Initializes the Patch stream group.""" super().__init__(path) class DepletionState(Stream): def __init__(self, path): + """Initializes the DepletionState stream.""" super().__init__(_reader.Csv(f"{path}_State_*", columns=["threshold", "offset", "rate"])) Encoder = core.Encoder @@ -78,17 +94,21 @@ def __init__(self, path): class ManualDelivery(Stream): def __init__(self, path): + """Initializes the ManualDelivery stream.""" super().__init__(_reader.Harp(f"{path}_201_*", ["manual_delivery"])) class MissedPellet(Stream): def __init__(self, path): + """Initializes the MissedPellet stream.""" super().__init__(_reader.Harp(f"{path}_202_*", ["missed_pellet"])) class RetriedDelivery(Stream): def __init__(self, path): + """Initializes the RetriedDelivery stream.""" super().__init__(_reader.Harp(f"{path}_203_*", ["retried_delivery"])) class RfidEvents(Stream): def __init__(self, path): + """Initializes the RfidEvents stream.""" super().__init__(_reader.Harp(f"{path}_32*", ["rfid"])) diff --git a/aeon/schema/social_03.py b/aeon/schema/social_03.py index fdb1f7df..5954d35b 100644 --- a/aeon/schema/social_03.py +++ b/aeon/schema/social_03.py @@ -1,15 +1,17 @@ -import json -import pandas as pd +"""Schema definition for social_03
experiment-specific data streams.""" + import aeon.io.reader as _reader from aeon.schema.streams import Stream class Pose(Stream): def __init__(self, path): + """Initializes the Pose stream.""" super().__init__(_reader.Pose(f"{path}_202_*")) class EnvironmentActiveConfiguration(Stream): def __init__(self, path): + """Initializes the EnvironmentActiveConfiguration stream.""" super().__init__(_reader.JsonList(f"{path}_ActiveConfiguration_*", columns=["name"])) diff --git a/aeon/schema/streams.py b/aeon/schema/streams.py index 2c5d57b2..c29e2779 100644 --- a/aeon/schema/streams.py +++ b/aeon/schema/streams.py @@ -1,3 +1,5 @@ +"""Classes for defining data streams and devices.""" + import inspect from itertools import chain from warnings import warn @@ -11,9 +13,11 @@ class Stream: """ def __init__(self, reader): + """Initializes the stream with a reader.""" self.reader = reader def __iter__(self): + """Yields the stream name and reader.""" yield (self.__class__.__name__, self.reader) @@ -26,6 +30,7 @@ class StreamGroup: """ def __init__(self, path, *args): + """Initializes the stream group with a path and a list of data streams.""" self.path = path self._args = args self._nested = ( @@ -35,6 +40,7 @@ def __init__(self, path, *args): ) def __iter__(self): + """Yields the stream name and reader for each data stream in the group.""" for factory in chain(self._nested, self._args): yield from iter(factory(self.path)) @@ -53,6 +59,7 @@ class Device: """ def __init__(self, name, *args, path=None): + """Initializes the device with a name and a list of data streams.""" if name is None: raise ValueError("name cannot be None.") @@ -77,6 +84,7 @@ def _createStreams(path, args): return streams def __iter__(self): + """Iterates over the device streams.""" if len(self._streams) == 1: singleton = self._streams.get(self.name, None) if singleton: diff --git a/pyproject.toml b/pyproject.toml index 658abf5b..0d7d7ec7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,20 +97,9 @@ lint.select = [ ] line-length = 108 lint.ignore = [ - "D100", # skip adding docstrings for module - "D104", # ignore missing docstring in public package - "D105", # skip adding docstrings for magic methods - "D107", # skip adding docstrings for __init__ - "E201", - "E202", - "E203", - "E231", "E731", - "E702", - "S101", - "PT004", # Rule `PT004` is deprecated and will be removed in a future release. - "PT013", - "PLR0912", + "PT004", # Deprecated and will be removed in a future release.
+ "PLR0912", "PLR0913", "PLR0915", ] @@ -124,39 +113,18 @@ extend-exclude = [ ] [tool.ruff.lint.per-file-ignores] "tests/*" = [ - "D103", # skip adding docstrings for public functions + "D103", # skip adding docstrings for public functions + "S101", # skip using assert ] "aeon/schema/*" = [ "D101", # skip adding docstrings for schema classes "D106", # skip adding docstrings for nested streams ] "aeon/dj_pipeline/*" = [ - "B006", - "B021", - "D101", # skip adding docstrings for table class since it is added inside definition - "D102", # skip adding docstrings for make function - "D103", # skip adding docstrings for public functions - "D106", # skip adding docstrings for Part tables - "E501", - "F401", # ignore unused import errors - "B905", # ignore unused import errors - "E999", - "S324", - "E722", - "S110", - "F821", - "B904", - "UP038", - "S607", - "S605", - "D205", - "D202", - "F403", - "PLR2004", - "SIM108", - "PLW0127", - "PLR2004", - "I001", + "D101", # skip adding docstrings for schema classes + "D106", # skip adding docstrings for nested streams + "S324", # skip hashlib insecure hash function (md5) warning + "F401", # skip incorrectly detecting `aeon.dj_pipeline` dependencies as unused ] [tool.ruff.lint.pydocstyle] diff --git a/tests/dj_pipeline/conftest.py b/tests/dj_pipeline/conftest.py index 3604890f..d1383fe7 100644 --- a/tests/dj_pipeline/conftest.py +++ b/tests/dj_pipeline/conftest.py @@ -47,17 +47,33 @@ def test_params(): @pytest.fixture(autouse=True, scope="session") def dj_config(): - """Configures DataJoint connection and loads custom settings.""" + """Configures DataJoint connection and loads custom settings. + + This fixture sets up the DataJoint configuration using the + 'dj_local_conf.json' file. It raises FileNotFoundError if the file + does not exist, and KeyError if 'custom' is not found in the + DataJoint configuration. 
+ """ dj_config_fp = pathlib.Path("dj_local_conf.json") assert dj_config_fp.exists() dj.config.load(dj_config_fp) dj.config["safemode"] = False assert "custom" in dj.config - dj.config["custom"]["database.prefix"] = f"u_{dj.config['database.user']}_testsuite_" + dj.config["custom"][ + "database.prefix" + ] = f"u_{dj.config['database.user']}_testsuite_" def load_pipeline(): - from aeon.dj_pipeline import acquisition, analysis, lab, qc, report, subject, tracking + from aeon.dj_pipeline import ( + acquisition, + analysis, + lab, + qc, + report, + subject, + tracking, + ) return { "subject": subject, diff --git a/tests/dj_pipeline/test_acquisition.py b/tests/dj_pipeline/test_acquisition.py index 28f39d12..51cd1e77 100644 --- a/tests/dj_pipeline/test_acquisition.py +++ b/tests/dj_pipeline/test_acquisition.py @@ -1,7 +1,9 @@ -from pytest import mark +"""Tests for the acquisition pipeline.""" +import pytest -@mark.ingestion + +@pytest.mark.ingestion def test_epoch_chunk_ingestion(test_params, pipeline, epoch_chunk_ingestion): acquisition = pipeline["acquisition"] @@ -15,19 +17,30 @@ def test_epoch_chunk_ingestion(test_params, pipeline, epoch_chunk_ingestion): ) -@mark.ingestion -def test_experimentlog_ingestion(test_params, pipeline, epoch_chunk_ingestion, experimentlog_ingestion): +@pytest.mark.ingestion +def test_experimentlog_ingestion( + test_params, pipeline, epoch_chunk_ingestion, experimentlog_ingestion +): acquisition = pipeline["acquisition"] assert ( - len(acquisition.ExperimentLog.Message & {"experiment_name": test_params["experiment_name"]}) + len( + acquisition.ExperimentLog.Message + & {"experiment_name": test_params["experiment_name"]} + ) == test_params["experiment_log_message_count"] ) assert ( - len(acquisition.SubjectEnterExit.Time & {"experiment_name": test_params["experiment_name"]}) + len( + acquisition.SubjectEnterExit.Time + & {"experiment_name": test_params["experiment_name"]} + ) == test_params["subject_enter_exit_count"] ) assert ( - len(acquisition.SubjectWeight.WeightTime & {"experiment_name": test_params["experiment_name"]}) + len( + acquisition.SubjectWeight.WeightTime + & {"experiment_name": test_params["experiment_name"]} + ) == test_params["subject_weight_time_count"] ) diff --git a/tests/dj_pipeline/test_pipeline_instantiation.py b/tests/dj_pipeline/test_pipeline_instantiation.py index cb3b51fb..52da5625 100644 --- a/tests/dj_pipeline/test_pipeline_instantiation.py +++ b/tests/dj_pipeline/test_pipeline_instantiation.py @@ -1,7 +1,9 @@ -from pytest import mark +"""Tests for pipeline instantiation and experiment creation.""" +import pytest -@mark.instantiation + +@pytest.mark.instantiation def test_pipeline_instantiation(pipeline): assert hasattr(pipeline["acquisition"], "FoodPatchEvent") assert hasattr(pipeline["lab"], "Arena") @@ -11,16 +13,19 @@ def test_pipeline_instantiation(pipeline): assert hasattr(pipeline["tracking"], "CameraTracking") -@mark.instantiation +@pytest.mark.instantiation def test_experiment_creation(test_params, pipeline, experiment_creation): acquisition = pipeline["acquisition"] experiment_name = test_params["experiment_name"] assert acquisition.Experiment.fetch1("experiment_name") == experiment_name raw_dir = ( - acquisition.Experiment.Directory & {"experiment_name": experiment_name, "directory_type": "raw"} + acquisition.Experiment.Directory + & {"experiment_name": experiment_name, "directory_type": "raw"} ).fetch1("directory_path") assert raw_dir == test_params["raw_dir"] - exp_subjects = (acquisition.Experiment.Subject & 
{"experiment_name": experiment_name}).fetch("subject") + exp_subjects = ( + acquisition.Experiment.Subject & {"experiment_name": experiment_name} + ).fetch("subject") assert len(exp_subjects) == test_params["subject_count"] assert "BAA-1100701" in exp_subjects diff --git a/tests/dj_pipeline/test_qc.py b/tests/dj_pipeline/test_qc.py index 9815031e..31e6baf9 100644 --- a/tests/dj_pipeline/test_qc.py +++ b/tests/dj_pipeline/test_qc.py @@ -1,7 +1,9 @@ -from pytest import mark +"""Tests for the QC pipeline.""" +import pytest -@mark.qc + +@pytest.mark.qc def test_camera_qc_ingestion(test_params, pipeline, camera_qc_ingestion): qc = pipeline["qc"] diff --git a/tests/dj_pipeline/test_tracking.py b/tests/dj_pipeline/test_tracking.py index 973e0741..1227adb2 100644 --- a/tests/dj_pipeline/test_tracking.py +++ b/tests/dj_pipeline/test_tracking.py @@ -1,14 +1,19 @@ +"""Test tracking pipeline.""" + import datetime import pathlib +import datajoint as dj import numpy as np -from pytest import mark +import pytest + +logger = dj.logger + index = 0 column_name = "position_x" # data column to run test on -file_name = ( - "exp0.2-r0-20220524090000-21053810-20220524082942-0-0.npy" # test file to be saved with save_test_data -) +file_name = "exp0.2-r0-20220524090000-21053810-20220524082942-0-0.npy" +# test file to be saved with save_test_data def save_test_data(pipeline, test_params): @@ -19,7 +24,11 @@ def save_test_data(pipeline, test_params): file_name = ( "-".join( [ - v.strftime("%Y%m%d%H%M%S") if isinstance(v, datetime.datetime) else str(v) + ( + v.strftime("%Y%m%d%H%M%S") + if isinstance(v, datetime.datetime) + else str(v) + ) for v in key.values() ] ) @@ -33,18 +42,25 @@ def save_test_data(pipeline, test_params): return test_file -@mark.ingestion -@mark.tracking +@pytest.mark.ingestion +@pytest.mark.tracking def test_camera_tracking_ingestion(test_params, pipeline, camera_tracking_ingestion): tracking = pipeline["tracking"] - assert len(tracking.CameraTracking.Object()) == test_params["camera_tracking_object_count"] + assert ( + len(tracking.CameraTracking.Object()) + == test_params["camera_tracking_object_count"] + ) key = tracking.CameraTracking.Object().fetch("KEY")[index] file_name = ( "-".join( [ - v.strftime("%Y%m%d%H%M%S") if isinstance(v, datetime.datetime) else str(v) + ( + v.strftime("%Y%m%d%H%M%S") + if isinstance(v, datetime.datetime) + else str(v) + ) for v in key.values() ] ) diff --git a/tests/io/test_api.py b/tests/io/test_api.py index 2a491c55..8652df17 100644 --- a/tests/io/test_api.py +++ b/tests/io/test_api.py @@ -1,8 +1,9 @@ +"""Tests for the aeon API.""" + from pathlib import Path import pandas as pd import pytest -from pytest import mark import aeon from aeon.schema.ingestion_schemas import social03 @@ -12,7 +13,7 @@ nonmonotonic_path = Path(__file__).parent.parent / "data" / "nonmonotonic" -@mark.api +@pytest.mark.api def test_load_start_only(): data = aeon.load( nonmonotonic_path, exp02.Patch2.Encoder, start=pd.Timestamp("2022-06-06T13:00:49") @@ -20,7 +21,7 @@ def test_load_start_only(): assert len(data) > 0 -@mark.api +@pytest.mark.api def test_load_end_only(): data = aeon.load( nonmonotonic_path, exp02.Patch2.Encoder, end=pd.Timestamp("2022-06-06T13:00:49") @@ -28,26 +29,28 @@ def test_load_end_only(): assert len(data) > 0 -@mark.api +@pytest.mark.api def test_load_filter_nonchunked(): - data = aeon.load(nonmonotonic_path, exp02.Metadata, start=pd.Timestamp("2022-06-06T09:00:00")) + data = aeon.load( + nonmonotonic_path, exp02.Metadata, 
start=pd.Timestamp("2022-06-06T09:00:00") + ) assert len(data) > 0 -@mark.api +@pytest.mark.api def test_load_monotonic(): data = aeon.load(monotonic_path, exp02.Patch2.Encoder) assert len(data) > 0 assert data.index.is_monotonic_increasing -@mark.api +@pytest.mark.api def test_load_nonmonotonic(): data = aeon.load(nonmonotonic_path, exp02.Patch2.Encoder) assert not data.index.is_monotonic_increasing -@mark.api +@pytest.mark.api def test_load_encoder_with_downsampling(): DOWNSAMPLE_PERIOD = 0.02 data = aeon.load(monotonic_path, social03.Patch2.Encoder) diff --git a/tests/io/test_reader.py b/tests/io/test_reader.py index 640768ab..e702c7e1 100644 --- a/tests/io/test_reader.py +++ b/tests/io/test_reader.py @@ -1,25 +1,25 @@ +"""Tests for the Pose stream.""" + from pathlib import Path import pytest -from pytest import mark import aeon from aeon.schema.schemas import social02, social03 pose_path = Path(__file__).parent.parent / "data" / "pose" - -@mark.api +@pytest.mark.api def test_Pose_read_local_model_dir(): + """Test that the Pose stream can read a local model directory.""" data = aeon.load(pose_path, social02.CameraTop.Pose) assert len(data) > 0 - -@mark.api +@pytest.mark.api def test_Pose_read_local_model_dir_with_register_prefix(): + """Test that the Pose stream can read a local model directory with a register prefix.""" data = aeon.load(pose_path, social03.CameraTop.Pose) assert len(data) > 0 - if __name__ == "__main__": pytest.main()