From 6005df1288b06724aeb4a506d86a3b50d9a81185 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 11:09:46 +0000 Subject: [PATCH 001/143] fix: resolve D100 error --- aeon/analysis/movies.py | 6 +- aeon/analysis/plotting.py | 11 +- aeon/analysis/utils.py | 42 +- aeon/dj_pipeline/acquisition.py | 96 +++- aeon/dj_pipeline/analysis/block_analysis.py | 542 +++++++++++++----- aeon/dj_pipeline/analysis/visit.py | 47 +- aeon/dj_pipeline/analysis/visit_analysis.py | 132 +++-- .../create_experiment_01.py | 42 +- .../create_experiment_02.py | 7 +- .../create_experiments/create_octagon_1.py | 7 +- .../create_experiments/create_presocial.py | 6 +- .../create_socialexperiment.py | 16 +- .../create_socialexperiment_0.py | 27 +- aeon/dj_pipeline/lab.py | 2 + aeon/dj_pipeline/populate/worker.py | 15 +- aeon/dj_pipeline/qc.py | 27 +- aeon/dj_pipeline/report.py | 70 ++- aeon/dj_pipeline/subject.py | 58 +- aeon/dj_pipeline/tracking.py | 62 +- aeon/dj_pipeline/utils/load_metadata.py | 114 +++- aeon/dj_pipeline/utils/paths.py | 7 +- aeon/dj_pipeline/utils/plotting.py | 147 +++-- aeon/dj_pipeline/utils/streams_maker.py | 33 +- aeon/dj_pipeline/utils/video.py | 2 + aeon/io/api.py | 34 +- aeon/io/device.py | 2 + aeon/io/reader.py | 93 ++- aeon/io/video.py | 10 +- aeon/schema/core.py | 2 + aeon/schema/dataset.py | 2 + aeon/schema/foraging.py | 14 +- aeon/schema/octagon.py | 42 +- aeon/schema/schemas.py | 26 +- aeon/schema/social_01.py | 2 + aeon/schema/social_02.py | 26 +- aeon/schema/social_03.py | 6 +- aeon/schema/streams.py | 2 + pyproject.toml | 1 - tests/dj_pipeline/test_acquisition.py | 21 +- .../test_pipeline_instantiation.py | 9 +- tests/dj_pipeline/test_qc.py | 2 + tests/dj_pipeline/test_tracking.py | 23 +- tests/io/test_api.py | 16 +- 43 files changed, 1382 insertions(+), 469 deletions(-) diff --git a/aeon/analysis/movies.py b/aeon/analysis/movies.py index 3ac3c1e9..b7d49ad7 100644 --- a/aeon/analysis/movies.py +++ b/aeon/analysis/movies.py @@ -1,3 +1,5 @@ +"""Helper functions for processing video data.""" + import math import cv2 @@ -107,7 +109,9 @@ def collatemovie(clipdata, fun): :return: The sequence of processed frames representing the collated movie. """ clipcount = len(clipdata.groupby("clip_sequence").frame_sequence.count()) - allframes = video.frames(clipdata.sort_values(by=["frame_sequence", "clip_sequence"])) + allframes = video.frames( + clipdata.sort_values(by=["frame_sequence", "clip_sequence"]) + ) return groupframes(allframes, clipcount, fun) diff --git a/aeon/analysis/plotting.py b/aeon/analysis/plotting.py index dc6157a7..ed82a519 100644 --- a/aeon/analysis/plotting.py +++ b/aeon/analysis/plotting.py @@ -1,5 +1,6 @@ -import math +"""Helper functions for plotting data.""" +import math import matplotlib.pyplot as plt import numpy as np import pandas as pd @@ -76,7 +77,9 @@ def rateplot( :param Axes, optional ax: The Axes on which to draw the rate plot and raster. """ label = kwargs.pop("label", None) - eventrate = rate(events, window, frequency, weight, start, end, smooth=smooth, center=center) + eventrate = rate( + events, window, frequency, weight, start, end, smooth=smooth, center=center + ) if ax is None: ax = plt.gca() ax.plot( @@ -85,7 +88,9 @@ def rateplot( label=label, **kwargs, ) - ax.vlines(sessiontime(events.index, eventrate.index[0]), -0.2, -0.1, linewidth=1, **kwargs) + ax.vlines( + sessiontime(events.index, eventrate.index[0]), -0.2, -0.1, linewidth=1, **kwargs + ) def set_ymargin(ax, bottom, top): diff --git a/aeon/analysis/utils.py b/aeon/analysis/utils.py index 13fff107..150cc805 100644 --- a/aeon/analysis/utils.py +++ b/aeon/analysis/utils.py @@ -1,3 +1,5 @@ +""" Helper functions for data analysis and visualization.""" + import numpy as np import pandas as pd @@ -48,7 +50,9 @@ def visits(data, onset="Enter", offset="Exit"): data = data.reset_index() data_onset = data[data.event == onset] data_offset = data[data.event == offset] - data = pd.merge(data_onset, data_offset, on="id", how="left", suffixes=(lsuffix, rsuffix)) + data = pd.merge( + data_onset, data_offset, on="id", how="left", suffixes=(lsuffix, rsuffix) + ) # valid pairings have the smallest positive duration data["duration"] = data[time_offset] - data[time_onset] @@ -59,18 +63,29 @@ def visits(data, onset="Enter", offset="Exit"): # duplicate offsets indicate missing data from previous pairing missing_data = data.duplicated(subset=time_offset, keep="last") if missing_data.any(): - data.loc[missing_data, ["duration"] + [name for name in data.columns if rsuffix in name]] = pd.NA + data.loc[ + missing_data, + ["duration"] + [name for name in data.columns if rsuffix in name], + ] = pd.NA # rename columns and sort data - data.rename({time_onset: lonset, id_onset: "id", time_offset: loffset}, axis=1, inplace=True) - data = data[["id"] + [name for name in data.columns if "_" in name] + [lonset, loffset, "duration"]] + data.rename( + {time_onset: lonset, id_onset: "id", time_offset: loffset}, axis=1, inplace=True + ) + data = data[ + ["id"] + + [name for name in data.columns if "_" in name] + + [lonset, loffset, "duration"] + ] data.drop([event_onset, event_offset], axis=1, inplace=True) data.sort_index(inplace=True) data.reset_index(drop=True, inplace=True) return data -def rate(events, window, frequency, weight=1, start=None, end=None, smooth=None, center=False): +def rate( + events, window, frequency, weight=1, start=None, end=None, smooth=None, center=False +): """Computes the continuous event rate from a discrete event sequence. The window size and sampling frequency can be specified. @@ -99,7 +114,14 @@ def rate(events, window, frequency, weight=1, start=None, end=None, smooth=None, def get_events_rates( - events, window_len_sec, frequency, unit_len_sec=60, start=None, end=None, smooth=None, center=False + events, + window_len_sec, + frequency, + unit_len_sec=60, + start=None, + end=None, + smooth=None, + center=False, ): """Computes the event rate from a sequence of events over a specified window.""" # events is an array with the time (in seconds) of event occurence @@ -114,7 +136,9 @@ def get_events_rates( counts.sort_index(inplace=True) counts_resampled = counts.resample(frequency).sum() counts_rolled = ( - counts_resampled.rolling(window_len_sec_str, center=center).sum() * unit_len_sec / window_len_sec + counts_resampled.rolling(window_len_sec_str, center=center).sum() + * unit_len_sec + / window_len_sec ) counts_rolled_smoothed = counts_rolled.rolling( window_len_sec_str if smooth is None else smooth, center=center @@ -142,6 +166,8 @@ def activepatch(wheel, in_patch): :return: A pandas Series specifying for each timepoint whether the patch is active. """ exit_patch = in_patch.astype(np.int8).diff() < 0 - in_wheel = (wheel.diff().rolling("1s").sum() > 1).reindex(in_patch.index, method="pad") + in_wheel = (wheel.diff().rolling("1s").sum() > 1).reindex( + in_patch.index, method="pad" + ) epochs = exit_patch.cumsum() return in_wheel.groupby(epochs).apply(lambda x: x.cumsum()) > 0 diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py index 38499455..c5307ecc 100644 --- a/aeon/dj_pipeline/acquisition.py +++ b/aeon/dj_pipeline/acquisition.py @@ -1,3 +1,5 @@ +""" DataJoint schema for the acquisition pipeline. """ + import datetime import json import pathlib @@ -162,7 +164,10 @@ def get_data_directories(cls, experiment_key, directory_types=None, as_posix=Fal return [ d for dir_type in directory_types - if (d := cls.get_data_directory(experiment_key, dir_type, as_posix=as_posix)) is not None + if ( + d := cls.get_data_directory(experiment_key, dir_type, as_posix=as_posix) + ) + is not None ] @@ -190,7 +195,9 @@ def ingest_epochs(cls, experiment_name): for i, (_, chunk) in enumerate(all_chunks.iterrows()): chunk_rep_file = pathlib.Path(chunk.path) epoch_dir = pathlib.Path(chunk_rep_file.as_posix().split(device_name)[0]) - epoch_start = datetime.datetime.strptime(epoch_dir.name, "%Y-%m-%dT%H-%M-%S") + epoch_start = datetime.datetime.strptime( + epoch_dir.name, "%Y-%m-%dT%H-%M-%S" + ) # --- insert to Epoch --- epoch_key = {"experiment_name": experiment_name, "epoch_start": epoch_start} @@ -209,11 +216,15 @@ def ingest_epochs(cls, experiment_name): if i > 0: previous_chunk = all_chunks.iloc[i - 1] previous_chunk_path = pathlib.Path(previous_chunk.path) - previous_epoch_dir = pathlib.Path(previous_chunk_path.as_posix().split(device_name)[0]) + previous_epoch_dir = pathlib.Path( + previous_chunk_path.as_posix().split(device_name)[0] + ) previous_epoch_start = datetime.datetime.strptime( previous_epoch_dir.name, "%Y-%m-%dT%H-%M-%S" ) - previous_chunk_end = previous_chunk.name + datetime.timedelta(hours=io_api.CHUNK_DURATION) + previous_chunk_end = previous_chunk.name + datetime.timedelta( + hours=io_api.CHUNK_DURATION + ) previous_epoch_end = min(previous_chunk_end, epoch_start) previous_epoch_key = { "experiment_name": experiment_name, @@ -223,7 +234,11 @@ def ingest_epochs(cls, experiment_name): with cls.connection.transaction: # insert new epoch cls.insert1( - {**epoch_key, **directory, "epoch_dir": epoch_dir.relative_to(raw_data_dir).as_posix()} + { + **epoch_key, + **directory, + "epoch_dir": epoch_dir.relative_to(raw_data_dir).as_posix(), + } ) epoch_list.append(epoch_key) @@ -238,7 +253,9 @@ def ingest_epochs(cls, experiment_name): { **previous_epoch_key, "epoch_end": previous_epoch_end, - "epoch_duration": (previous_epoch_end - previous_epoch_start).total_seconds() + "epoch_duration": ( + previous_epoch_end - previous_epoch_start + ).total_seconds() / 3600, } ) @@ -310,17 +327,23 @@ def make(self, key): experiment_name = key["experiment_name"] devices_schema = getattr( aeon_schemas, - (Experiment.DevicesSchema & {"experiment_name": experiment_name}).fetch1("devices_schema_name"), + (Experiment.DevicesSchema & {"experiment_name": experiment_name}).fetch1( + "devices_schema_name" + ), ) dir_type, epoch_dir = (Epoch & key).fetch1("directory_type", "epoch_dir") data_dir = Experiment.get_data_directory(key, dir_type) metadata_yml_filepath = data_dir / epoch_dir / "Metadata.yml" - epoch_config = extract_epoch_config(experiment_name, devices_schema, metadata_yml_filepath) + epoch_config = extract_epoch_config( + experiment_name, devices_schema, metadata_yml_filepath + ) epoch_config = { **epoch_config, - "metadata_file_path": metadata_yml_filepath.relative_to(data_dir).as_posix(), + "metadata_file_path": metadata_yml_filepath.relative_to( + data_dir + ).as_posix(), } # Insert new entries for streams.DeviceType, streams.Device. @@ -331,15 +354,20 @@ def make(self, key): # Define and instantiate new devices/stream tables under `streams` schema streams_maker.main() # Insert devices' installation/removal/settings - epoch_device_types = ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath) + epoch_device_types = ingest_epoch_metadata( + experiment_name, devices_schema, metadata_yml_filepath + ) self.insert1(key) self.Meta.insert1(epoch_config) - self.DeviceType.insert(key | {"device_type": n} for n in epoch_device_types or {}) + self.DeviceType.insert( + key | {"device_type": n} for n in epoch_device_types or {} + ) with metadata_yml_filepath.open("r") as f: metadata = json.load(f) self.ActiveRegion.insert( - {**key, "region_name": k, "region_data": v} for k, v in metadata["ActiveRegion"].items() + {**key, "region_name": k, "region_data": v} + for k, v in metadata["ActiveRegion"].items() ) @@ -377,7 +405,9 @@ def ingest_chunks(cls, experiment_name): for _, chunk in all_chunks.iterrows(): chunk_rep_file = pathlib.Path(chunk.path) epoch_dir = pathlib.Path(chunk_rep_file.as_posix().split(device_name)[0]) - epoch_start = datetime.datetime.strptime(epoch_dir.name, "%Y-%m-%dT%H-%M-%S") + epoch_start = datetime.datetime.strptime( + epoch_dir.name, "%Y-%m-%dT%H-%M-%S" + ) epoch_key = {"experiment_name": experiment_name, "epoch_start": epoch_start} if not (Epoch & epoch_key): @@ -385,7 +415,9 @@ def ingest_chunks(cls, experiment_name): continue chunk_start = chunk.name - chunk_start = max(chunk_start, epoch_start) # first chunk of the epoch starts at epoch_start + chunk_start = max( + chunk_start, epoch_start + ) # first chunk of the epoch starts at epoch_start chunk_end = chunk_start + datetime.timedelta(hours=io_api.CHUNK_DURATION) if EpochEnd & epoch_key: @@ -405,8 +437,12 @@ def ingest_chunks(cls, experiment_name): ) chunk_starts.append(chunk_key["chunk_start"]) - chunk_list.append({**chunk_key, **directory, "chunk_end": chunk_end, **epoch_key}) - file_name_list.append(chunk_rep_file.name) # handle duplicated files in different folders + chunk_list.append( + {**chunk_key, **directory, "chunk_end": chunk_end, **epoch_key} + ) + file_name_list.append( + chunk_rep_file.name + ) # handle duplicated files in different folders # -- files -- file_datetime_str = chunk_rep_file.stem.replace(f"{device_name}_", "") @@ -522,9 +558,9 @@ def make(self, key): data_dirs = Experiment.get_data_directories(key) devices_schema = getattr( aeon_schemas, - (Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) device = devices_schema.Environment @@ -583,12 +619,14 @@ def make(self, key): data_dirs = Experiment.get_data_directories(key) devices_schema = getattr( aeon_schemas, - (Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) device = devices_schema.Environment - stream_reader = device.EnvironmentActiveConfiguration # expecting columns: time, name, value + stream_reader = ( + device.EnvironmentActiveConfiguration + ) # expecting columns: time, name, value stream_data = io_api.load( root=data_dirs, reader=stream_reader, @@ -611,14 +649,18 @@ def _get_all_chunks(experiment_name, device_name): directory_types = ["quality-control", "raw"] raw_data_dirs = { dir_type: Experiment.get_data_directory( - experiment_key={"experiment_name": experiment_name}, directory_type=dir_type, as_posix=False + experiment_key={"experiment_name": experiment_name}, + directory_type=dir_type, + as_posix=False, ) for dir_type in directory_types } raw_data_dirs = {k: v for k, v in raw_data_dirs.items() if v} if not raw_data_dirs: - raise ValueError(f"No raw data directory found for experiment: {experiment_name}") + raise ValueError( + f"No raw data directory found for experiment: {experiment_name}" + ) chunkdata = io_api.load( root=list(raw_data_dirs.values()), @@ -639,7 +681,9 @@ def _match_experiment_directory(experiment_name, path, directories): repo_path = paths.get_repository_path(directory.pop("repository_name")) break else: - raise FileNotFoundError(f"Unable to identify the directory" f" where this chunk is from: {path}") + raise FileNotFoundError( + f"Unable to identify the directory" f" where this chunk is from: {path}" + ) return raw_data_dir, directory, repo_path diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 7e853a5b..735ebfd3 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -1,3 +1,5 @@ +"""Module for block analysis.""" + import itertools import json from collections import defaultdict @@ -65,14 +67,18 @@ def make(self, key): # find the 0s in `pellet_ct` (these are times when the pellet count reset - i.e. new block) # that would mark the start of a new block - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) exp_key = {"experiment_name": key["experiment_name"]} chunk_restriction = acquisition.create_chunk_restriction( key["experiment_name"], chunk_start, chunk_end ) - block_state_query = acquisition.Environment.BlockState & exp_key & chunk_restriction + block_state_query = ( + acquisition.Environment.BlockState & exp_key & chunk_restriction + ) block_state_df = fetch_stream(block_state_query) if block_state_df.empty: self.insert1(key) @@ -95,8 +101,12 @@ def make(self, key): block_entries = [] if not blocks_df.empty: # calculate block end_times (use due_time) and durations - blocks_df["end_time"] = blocks_df["due_time"].apply(lambda x: io_api.aeon(x)) - blocks_df["duration"] = (blocks_df["end_time"] - blocks_df.index).dt.total_seconds() / 3600 + blocks_df["end_time"] = blocks_df["due_time"].apply( + lambda x: io_api.aeon(x) + ) + blocks_df["duration"] = ( + blocks_df["end_time"] - blocks_df.index + ).dt.total_seconds() / 3600 for _, row in blocks_df.iterrows(): block_entries.append( @@ -184,7 +194,9 @@ def make(self, key): tracking.SLEAPTracking, ) for streams_table in streams_tables: - if len(streams_table & chunk_keys) < len(streams_table.key_source & chunk_keys): + if len(streams_table & chunk_keys) < len( + streams_table.key_source & chunk_keys + ): raise ValueError( f"BlockAnalysis Not Ready - {streams_table.__name__} not yet fully ingested for block: {key}. Skipping (to retry later)..." ) @@ -193,10 +205,14 @@ def make(self, key): # For wheel data, downsample to 10Hz final_encoder_fs = 10 - maintenance_period = get_maintenance_periods(key["experiment_name"], block_start, block_end) + maintenance_period = get_maintenance_periods( + key["experiment_name"], block_start, block_end + ) patch_query = ( - streams.UndergroundFeeder.join(streams.UndergroundFeeder.RemovalTime, left=True) + streams.UndergroundFeeder.join( + streams.UndergroundFeeder.RemovalTime, left=True + ) & key & f'"{block_start}" >= underground_feeder_install_time' & f'"{block_end}" < IFNULL(underground_feeder_removal_time, "2200-01-01")' @@ -210,12 +226,14 @@ def make(self, key): streams.UndergroundFeederDepletionState & patch_key & chunk_restriction )[block_start:block_end] - pellet_ts_threshold_df = get_threshold_associated_pellets(patch_key, block_start, block_end) + pellet_ts_threshold_df = get_threshold_associated_pellets( + patch_key, block_start, block_end + ) # wheel encoder data - encoder_df = fetch_stream(streams.UndergroundFeederEncoder & patch_key & chunk_restriction)[ - block_start:block_end - ] + encoder_df = fetch_stream( + streams.UndergroundFeederEncoder & patch_key & chunk_restriction + )[block_start:block_end] # filter out maintenance period based on logs pellet_ts_threshold_df = filter_out_maintenance_periods( pellet_ts_threshold_df, @@ -234,9 +252,13 @@ def make(self, key): ) if depletion_state_df.empty: - raise ValueError(f"No depletion state data found for block {key} - patch: {patch_name}") + raise ValueError( + f"No depletion state data found for block {key} - patch: {patch_name}" + ) - encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled(encoder_df.angle) + encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled( + encoder_df.angle + ) if len(depletion_state_df.rate.unique()) > 1: # multiple patch rates per block is unexpected, log a note and pick the first rate to move forward @@ -267,7 +289,9 @@ def make(self, key): "wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values[ ::wheel_downsampling_factor ], - "wheel_timestamps": encoder_df.index.values[::wheel_downsampling_factor], + "wheel_timestamps": encoder_df.index.values[ + ::wheel_downsampling_factor + ], "patch_threshold": pellet_ts_threshold_df.threshold.values, "patch_threshold_timestamps": pellet_ts_threshold_df.index.values, "patch_rate": patch_rate, @@ -299,7 +323,9 @@ def make(self, key): # positions - query for CameraTop, identity_name matches subject_name, pos_query = ( streams.SpinnakerVideoSource - * tracking.SLEAPTracking.PoseIdentity.proj("identity_name", part_name="anchor_part") + * tracking.SLEAPTracking.PoseIdentity.proj( + "identity_name", part_name="anchor_part" + ) * tracking.SLEAPTracking.Part & key & { @@ -309,18 +335,23 @@ def make(self, key): & chunk_restriction ) pos_df = fetch_stream(pos_query)[block_start:block_end] - pos_df = filter_out_maintenance_periods(pos_df, maintenance_period, block_end) + pos_df = filter_out_maintenance_periods( + pos_df, maintenance_period, block_end + ) if pos_df.empty: continue position_diff = np.sqrt( - np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float))) + np.square(np.diff(pos_df.x.astype(float))) + + np.square(np.diff(pos_df.y.astype(float))) ) cumsum_distance_travelled = np.concatenate([[0], np.cumsum(position_diff)]) # weights - weight_query = acquisition.Environment.SubjectWeight & key & chunk_restriction + weight_query = ( + acquisition.Environment.SubjectWeight & key & chunk_restriction + ) weight_df = fetch_stream(weight_query)[block_start:block_end] weight_df.query(f"subject_id == '{subject_name}'", inplace=True) @@ -407,7 +438,10 @@ def make(self, key): subjects_positions_df = pd.concat( [ pd.DataFrame( - {"subject_name": [s["subject_name"]] * len(s["position_timestamps"])} + { + "subject_name": [s["subject_name"]] + * len(s["position_timestamps"]) + } | { k: s[k] for k in ( @@ -435,7 +469,8 @@ def make(self, key): "cum_pref_time", ] all_subj_patch_pref_dict = { - p: {s: {a: pd.Series() for a in pref_attrs} for s in subject_names} for p in patch_names + p: {s: {a: pd.Series() for a in pref_attrs} for s in subject_names} + for p in patch_names } for patch in block_patches: @@ -458,11 +493,15 @@ def make(self, key): ).fetch1("attribute_value") patch_center = (int(patch_center["X"]), int(patch_center["Y"])) subjects_xy = subjects_positions_df[["position_x", "position_y"]].values - dist_to_patch = np.sqrt(np.sum((subjects_xy - patch_center) ** 2, axis=1).astype(float)) + dist_to_patch = np.sqrt( + np.sum((subjects_xy - patch_center) ** 2, axis=1).astype(float) + ) dist_to_patch_df = subjects_positions_df[["subject_name"]].copy() dist_to_patch_df["dist_to_patch"] = dist_to_patch - dist_to_patch_wheel_ts_id_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subject_names) + dist_to_patch_wheel_ts_id_df = pd.DataFrame( + index=cum_wheel_dist.index, columns=subject_names + ) dist_to_patch_pel_ts_id_df = pd.DataFrame( index=patch["pellet_timestamps"], columns=subject_names ) @@ -470,10 +509,12 @@ def make(self, key): # Find closest match between pose_df indices and wheel indices if not dist_to_patch_wheel_ts_id_df.empty: dist_to_patch_wheel_ts_subj = pd.merge_asof( - left=pd.DataFrame(dist_to_patch_wheel_ts_id_df[subject_name].copy()).reset_index( - names="time" - ), - right=dist_to_patch_df[dist_to_patch_df["subject_name"] == subject_name] + left=pd.DataFrame( + dist_to_patch_wheel_ts_id_df[subject_name].copy() + ).reset_index(names="time"), + right=dist_to_patch_df[ + dist_to_patch_df["subject_name"] == subject_name + ] .copy() .reset_index(names="time"), on="time", @@ -482,16 +523,18 @@ def make(self, key): direction="nearest", tolerance=pd.Timedelta("100ms"), ) - dist_to_patch_wheel_ts_id_df[subject_name] = dist_to_patch_wheel_ts_subj[ - "dist_to_patch" - ].values + dist_to_patch_wheel_ts_id_df[subject_name] = ( + dist_to_patch_wheel_ts_subj["dist_to_patch"].values + ) # Find closest match between pose_df indices and pel indices if not dist_to_patch_pel_ts_id_df.empty: dist_to_patch_pel_ts_subj = pd.merge_asof( - left=pd.DataFrame(dist_to_patch_pel_ts_id_df[subject_name].copy()).reset_index( - names="time" - ), - right=dist_to_patch_df[dist_to_patch_df["subject_name"] == subject_name] + left=pd.DataFrame( + dist_to_patch_pel_ts_id_df[subject_name].copy() + ).reset_index(names="time"), + right=dist_to_patch_df[ + dist_to_patch_df["subject_name"] == subject_name + ] .copy() .reset_index(names="time"), on="time", @@ -500,9 +543,9 @@ def make(self, key): direction="nearest", tolerance=pd.Timedelta("200ms"), ) - dist_to_patch_pel_ts_id_df[subject_name] = dist_to_patch_pel_ts_subj[ - "dist_to_patch" - ].values + dist_to_patch_pel_ts_id_df[subject_name] = ( + dist_to_patch_pel_ts_subj["dist_to_patch"].values + ) # Get closest subject to patch at each pellet timestep closest_subjects_pellet_ts = dist_to_patch_pel_ts_id_df.idxmin(axis=1) @@ -514,8 +557,12 @@ def make(self, key): wheel_dist = cum_wheel_dist.diff().fillna(cum_wheel_dist.iloc[0]) # Assign wheel dist to closest subject for each wheel timestep for subject_name in subject_names: - subj_idxs = cum_wheel_dist_subj_df[closest_subjects_wheel_ts == subject_name].index - cum_wheel_dist_subj_df.loc[subj_idxs, subject_name] = wheel_dist[subj_idxs] + subj_idxs = cum_wheel_dist_subj_df[ + closest_subjects_wheel_ts == subject_name + ].index + cum_wheel_dist_subj_df.loc[subj_idxs, subject_name] = wheel_dist[ + subj_idxs + ] cum_wheel_dist_subj_df = cum_wheel_dist_subj_df.cumsum(axis=0) # In patch time @@ -523,14 +570,14 @@ def make(self, key): dt = np.median(np.diff(cum_wheel_dist.index)).astype(int) / 1e9 # s # Fill in `all_subj_patch_pref` for subject_name in subject_names: - all_subj_patch_pref_dict[patch["patch_name"]][subject_name]["cum_dist"] = ( - cum_wheel_dist_subj_df[subject_name].values - ) + all_subj_patch_pref_dict[patch["patch_name"]][subject_name][ + "cum_dist" + ] = cum_wheel_dist_subj_df[subject_name].values subject_in_patch = in_patch[subject_name] subject_in_patch_cum_time = subject_in_patch.cumsum().values * dt - all_subj_patch_pref_dict[patch["patch_name"]][subject_name]["cum_time"] = ( - subject_in_patch_cum_time - ) + all_subj_patch_pref_dict[patch["patch_name"]][subject_name][ + "cum_time" + ] = subject_in_patch_cum_time closest_subj_mask = closest_subjects_pellet_ts == subject_name subj_pellets = closest_subjects_pellet_ts[closest_subj_mask] @@ -546,7 +593,9 @@ def make(self, key): "pellet_count": len(subj_pellets), "pellet_timestamps": subj_pellets.index.values, "patch_threshold": subj_patch_thresh, - "wheel_cumsum_distance_travelled": cum_wheel_dist_subj_df[subject_name].values, + "wheel_cumsum_distance_travelled": cum_wheel_dist_subj_df[ + subject_name + ].values, } ) @@ -555,46 +604,72 @@ def make(self, key): for subject_name in subject_names: # Get sum of subj cum wheel dists and cum in patch time all_cum_dist = np.sum( - [all_subj_patch_pref_dict[p][subject_name]["cum_dist"][-1] for p in patch_names] + [ + all_subj_patch_pref_dict[p][subject_name]["cum_dist"][-1] + for p in patch_names + ] ) all_cum_time = np.sum( - [all_subj_patch_pref_dict[p][subject_name]["cum_time"][-1] for p in patch_names] + [ + all_subj_patch_pref_dict[p][subject_name]["cum_time"][-1] + for p in patch_names + ] ) for patch_name in patch_names: cum_pref_dist = ( - all_subj_patch_pref_dict[patch_name][subject_name]["cum_dist"] / all_cum_dist + all_subj_patch_pref_dict[patch_name][subject_name]["cum_dist"] + / all_cum_dist ) cum_pref_dist = np.where(cum_pref_dist < 1e-3, 0, cum_pref_dist) - all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_dist"] = cum_pref_dist + all_subj_patch_pref_dict[patch_name][subject_name][ + "cum_pref_dist" + ] = cum_pref_dist cum_pref_time = ( - all_subj_patch_pref_dict[patch_name][subject_name]["cum_time"] / all_cum_time + all_subj_patch_pref_dict[patch_name][subject_name]["cum_time"] + / all_cum_time ) - all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_time"] = cum_pref_time + all_subj_patch_pref_dict[patch_name][subject_name][ + "cum_pref_time" + ] = cum_pref_time # sum pref at each ts across patches for each subject total_dist_pref = np.sum( np.vstack( - [all_subj_patch_pref_dict[p][subject_name]["cum_pref_dist"] for p in patch_names] + [ + all_subj_patch_pref_dict[p][subject_name]["cum_pref_dist"] + for p in patch_names + ] ), axis=0, ) total_time_pref = np.sum( np.vstack( - [all_subj_patch_pref_dict[p][subject_name]["cum_pref_time"] for p in patch_names] + [ + all_subj_patch_pref_dict[p][subject_name]["cum_pref_time"] + for p in patch_names + ] ), axis=0, ) for patch_name in patch_names: - cum_pref_dist = all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_dist"] - all_subj_patch_pref_dict[patch_name][subject_name]["running_dist_pref"] = np.divide( + cum_pref_dist = all_subj_patch_pref_dict[patch_name][subject_name][ + "cum_pref_dist" + ] + all_subj_patch_pref_dict[patch_name][subject_name][ + "running_dist_pref" + ] = np.divide( cum_pref_dist, total_dist_pref, out=np.zeros_like(cum_pref_dist), where=total_dist_pref != 0, ) - cum_pref_time = all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_time"] - all_subj_patch_pref_dict[patch_name][subject_name]["running_time_pref"] = np.divide( + cum_pref_time = all_subj_patch_pref_dict[patch_name][subject_name][ + "cum_pref_time" + ] + all_subj_patch_pref_dict[patch_name][subject_name][ + "running_time_pref" + ] = np.divide( cum_pref_time, total_time_pref, out=np.zeros_like(cum_pref_time), @@ -606,12 +681,24 @@ def make(self, key): | { "patch_name": p, "subject_name": s, - "cumulative_preference_by_time": all_subj_patch_pref_dict[p][s]["cum_pref_time"], - "cumulative_preference_by_wheel": all_subj_patch_pref_dict[p][s]["cum_pref_dist"], - "running_preference_by_time": all_subj_patch_pref_dict[p][s]["running_time_pref"], - "running_preference_by_wheel": all_subj_patch_pref_dict[p][s]["running_dist_pref"], - "final_preference_by_time": all_subj_patch_pref_dict[p][s]["cum_pref_time"][-1], - "final_preference_by_wheel": all_subj_patch_pref_dict[p][s]["cum_pref_dist"][-1], + "cumulative_preference_by_time": all_subj_patch_pref_dict[p][s][ + "cum_pref_time" + ], + "cumulative_preference_by_wheel": all_subj_patch_pref_dict[p][s][ + "cum_pref_dist" + ], + "running_preference_by_time": all_subj_patch_pref_dict[p][s][ + "running_time_pref" + ], + "running_preference_by_wheel": all_subj_patch_pref_dict[p][s][ + "running_dist_pref" + ], + "final_preference_by_time": all_subj_patch_pref_dict[p][s][ + "cum_pref_time" + ][-1], + "final_preference_by_wheel": all_subj_patch_pref_dict[p][s][ + "cum_pref_dist" + ][-1], } for p, s in itertools.product(patch_names, subject_names) ) @@ -634,7 +721,9 @@ class BlockPatchPlots(dj.Computed): def make(self, key): # Define subject colors and patch styling for plotting - exp_subject_names = (acquisition.Experiment.Subject & key).fetch("subject", order_by="subject") + exp_subject_names = (acquisition.Experiment.Subject & key).fetch( + "subject", order_by="subject" + ) if not len(exp_subject_names): raise ValueError( "No subjects found in the `acquisition.Experiment.Subject`, missing a manual insert step?." @@ -653,7 +742,10 @@ def make(self, key): # Figure 1 - Patch stats: patch means and pellet threshold boxplots # --- subj_patch_info = ( - (BlockSubjectAnalysis.Patch.proj("pellet_timestamps", "patch_threshold") & key) + ( + BlockSubjectAnalysis.Patch.proj("pellet_timestamps", "patch_threshold") + & key + ) .fetch(format="frame") .reset_index() ) @@ -667,28 +759,46 @@ def make(self, key): ["patch_name", "subject_name", "pellet_timestamps", "patch_threshold"] ] min_subj_patch_info = ( - min_subj_patch_info.explode(["pellet_timestamps", "patch_threshold"], ignore_index=True) + min_subj_patch_info.explode( + ["pellet_timestamps", "patch_threshold"], ignore_index=True + ) .dropna() .reset_index(drop=True) ) # Rename and reindex columns min_subj_patch_info.columns = ["patch", "subject", "time", "threshold"] - min_subj_patch_info = min_subj_patch_info.reindex(columns=["time", "patch", "threshold", "subject"]) + min_subj_patch_info = min_subj_patch_info.reindex( + columns=["time", "patch", "threshold", "subject"] + ) # Add patch mean values and block-normalized delivery times to pellet info n_patches = len(patch_info) - patch_mean_info = pd.DataFrame(index=np.arange(n_patches), columns=min_subj_patch_info.columns) + patch_mean_info = pd.DataFrame( + index=np.arange(n_patches), columns=min_subj_patch_info.columns + ) patch_mean_info["subject"] = "mean" patch_mean_info["patch"] = [d["patch_name"] for d in patch_info] - patch_mean_info["threshold"] = [((1 / d["patch_rate"]) + d["patch_offset"]) for d in patch_info] + patch_mean_info["threshold"] = [ + ((1 / d["patch_rate"]) + d["patch_offset"]) for d in patch_info + ] patch_mean_info["time"] = subj_patch_info["block_start"][0] - min_subj_patch_info_plus = pd.concat((patch_mean_info, min_subj_patch_info)).reset_index(drop=True) + min_subj_patch_info_plus = pd.concat( + (patch_mean_info, min_subj_patch_info) + ).reset_index(drop=True) min_subj_patch_info_plus["norm_time"] = ( - (min_subj_patch_info_plus["time"] - min_subj_patch_info_plus["time"].iloc[0]) - / (min_subj_patch_info_plus["time"].iloc[-1] - min_subj_patch_info_plus["time"].iloc[0]) + ( + min_subj_patch_info_plus["time"] + - min_subj_patch_info_plus["time"].iloc[0] + ) + / ( + min_subj_patch_info_plus["time"].iloc[-1] + - min_subj_patch_info_plus["time"].iloc[0] + ) ).round(3) # Plot it - box_colors = ["#0A0A0A"] + list(subject_colors_dict.values()) # subject colors + mean color + box_colors = ["#0A0A0A"] + list( + subject_colors_dict.values() + ) # subject colors + mean color patch_stats_fig = px.box( min_subj_patch_info_plus.sort_values("patch"), x="patch", @@ -718,7 +828,9 @@ def make(self, key): .dropna() .reset_index(drop=True) ) - weights_block.drop(columns=["experiment_name", "block_start"], inplace=True, errors="ignore") + weights_block.drop( + columns=["experiment_name", "block_start"], inplace=True, errors="ignore" + ) weights_block.rename(columns={"weight_timestamps": "time"}, inplace=True) weights_block.set_index("time", inplace=True) weights_block.sort_index(inplace=True) @@ -742,13 +854,17 @@ def make(self, key): # Figure 3 - Cumulative pellet count: over time, per subject, markered by patch # --- # Create dataframe with cumulative pellet count per subject - cum_pel_ct = min_subj_patch_info_plus.sort_values("time").copy().reset_index(drop=True) + cum_pel_ct = ( + min_subj_patch_info_plus.sort_values("time").copy().reset_index(drop=True) + ) patch_means = cum_pel_ct.loc[0:3][["patch", "threshold"]].rename( columns={"threshold": "mean_thresh"} ) patch_means["mean_thresh"] = patch_means["mean_thresh"].astype(float).round(1) cum_pel_ct = cum_pel_ct.merge(patch_means, on="patch", how="left") - cum_pel_ct = cum_pel_ct[~cum_pel_ct["subject"].str.contains("mean")].reset_index(drop=True) + cum_pel_ct = cum_pel_ct[ + ~cum_pel_ct["subject"].str.contains("mean") + ].reset_index(drop=True) cum_pel_ct = ( cum_pel_ct.groupby("subject", group_keys=False) .apply(lambda group: group.assign(counter=np.arange(len(group)) + 1)) @@ -758,7 +874,9 @@ def make(self, key): make_float_cols = ["threshold", "mean_thresh", "norm_time"] cum_pel_ct[make_float_cols] = cum_pel_ct[make_float_cols].astype(float) cum_pel_ct["patch_label"] = ( - cum_pel_ct["patch"] + " μ: " + cum_pel_ct["mean_thresh"].astype(float).round(1).astype(str) + cum_pel_ct["patch"] + + " μ: " + + cum_pel_ct["mean_thresh"].astype(float).round(1).astype(str) ) cum_pel_ct["norm_thresh_val"] = ( (cum_pel_ct["threshold"] - cum_pel_ct["threshold"].min()) @@ -788,7 +906,9 @@ def make(self, key): mode="markers", marker={ "symbol": patch_markers_dict[patch_grp["patch"].iloc[0]], - "color": gen_hex_grad(pel_mrkr_col, patch_grp["norm_thresh_val"]), + "color": gen_hex_grad( + pel_mrkr_col, patch_grp["norm_thresh_val"] + ), "size": 8, }, name=patch_val, @@ -808,7 +928,9 @@ def make(self, key): cum_pel_per_subject_fig = go.Figure() for id_val, id_grp in cum_pel_ct.groupby("subject"): for patch_val, patch_grp in id_grp.groupby("patch"): - cur_p_mean = patch_means[patch_means["patch"] == patch_val]["mean_thresh"].values[0] + cur_p_mean = patch_means[patch_means["patch"] == patch_val][ + "mean_thresh" + ].values[0] cur_p = patch_val.replace("Patch", "P") cum_pel_per_subject_fig.add_trace( go.Scatter( @@ -823,7 +945,9 @@ def make(self, key): # line=dict(width=2, color=subject_colors_dict[id_val]), marker={ "symbol": patch_markers_dict[patch_val], - "color": gen_hex_grad(pel_mrkr_col, patch_grp["norm_thresh_val"]), + "color": gen_hex_grad( + pel_mrkr_col, patch_grp["norm_thresh_val"] + ), "size": 8, }, name=f"{id_val} - {cur_p} - μ: {cur_p_mean}", @@ -840,7 +964,9 @@ def make(self, key): # Figure 5 - Cumulative wheel distance: over time, per subject-patch # --- # Get wheel timestamps for each patch - wheel_ts = (BlockAnalysis.Patch & key).fetch("patch_name", "wheel_timestamps", as_dict=True) + wheel_ts = (BlockAnalysis.Patch & key).fetch( + "patch_name", "wheel_timestamps", as_dict=True + ) wheel_ts = {d["patch_name"]: d["wheel_timestamps"] for d in wheel_ts} # Get subject patch data subj_wheel_cumsum_dist = (BlockSubjectAnalysis.Patch & key).fetch( @@ -860,7 +986,9 @@ def make(self, key): for subj in subject_names: for patch_name in patch_names: cur_cum_wheel_dist = subj_wheel_cumsum_dist[(subj, patch_name)] - cur_p_mean = patch_means[patch_means["patch"] == patch_name]["mean_thresh"].values[0] + cur_p_mean = patch_means[patch_means["patch"] == patch_name][ + "mean_thresh" + ].values[0] cur_p = patch_name.replace("Patch", "P") cum_wheel_dist_fig.add_trace( go.Scatter( @@ -877,7 +1005,10 @@ def make(self, key): ) # Add markers for each pellet cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[(cum_pel_ct["subject"] == subj) & (cum_pel_ct["patch"] == patch_name)], + cum_pel_ct[ + (cum_pel_ct["subject"] == subj) + & (cum_pel_ct["patch"] == patch_name) + ], pd.DataFrame( { "time": wheel_ts[patch_name], @@ -896,11 +1027,15 @@ def make(self, key): mode="markers", marker={ "symbol": patch_markers_dict[patch_name], - "color": gen_hex_grad(pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"]), + "color": gen_hex_grad( + pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"] + ), "size": 8, }, name=f"{subj} - {cur_p} pellets", - customdata=np.stack((cur_cum_pel_ct["threshold"],), axis=-1), + customdata=np.stack( + (cur_cum_pel_ct["threshold"],), axis=-1 + ), hovertemplate="Threshold: %{customdata[0]:.2f} cm", ) ) @@ -914,10 +1049,14 @@ def make(self, key): # --- # Get and format a dataframe with preference data patch_pref = (BlockSubjectAnalysis.Preference & key).fetch(format="frame") - patch_pref.reset_index(level=["experiment_name", "block_start"], drop=True, inplace=True) + patch_pref.reset_index( + level=["experiment_name", "block_start"], drop=True, inplace=True + ) # Replace small vals with 0 small_pref_thresh = 1e-3 - patch_pref["cumulative_preference_by_wheel"] = patch_pref["cumulative_preference_by_wheel"].apply( + patch_pref["cumulative_preference_by_wheel"] = patch_pref[ + "cumulative_preference_by_wheel" + ].apply( lambda arr: np.where(np.array(arr) < small_pref_thresh, 0, np.array(arr)) ) @@ -925,14 +1064,18 @@ def calculate_running_preference(group, pref_col, out_col): # Sum pref at each ts total_pref = np.sum(np.vstack(group[pref_col].values), axis=0) # Calculate running pref - group[out_col] = group[pref_col].apply(lambda x: np.nan_to_num(x / total_pref, 0.0)) + group[out_col] = group[pref_col].apply( + lambda x: np.nan_to_num(x / total_pref, 0.0) + ) return group patch_pref = ( patch_pref.groupby("subject_name") .apply( lambda group: calculate_running_preference( - group, "cumulative_preference_by_wheel", "running_preference_by_wheel" + group, + "cumulative_preference_by_wheel", + "running_preference_by_wheel", ) ) .droplevel(0) @@ -952,8 +1095,12 @@ def calculate_running_preference(group, pref_col, out_col): # Add trace for each subject-patch combo for subj in subject_names: for patch_name in patch_names: - cur_run_wheel_pref = patch_pref.loc[patch_name].loc[subj]["running_preference_by_wheel"] - cur_p_mean = patch_means[patch_means["patch"] == patch_name]["mean_thresh"].values[0] + cur_run_wheel_pref = patch_pref.loc[patch_name].loc[subj][ + "running_preference_by_wheel" + ] + cur_p_mean = patch_means[patch_means["patch"] == patch_name][ + "mean_thresh" + ].values[0] cur_p = patch_name.replace("Patch", "P") running_pref_by_wheel_plot.add_trace( go.Scatter( @@ -970,7 +1117,10 @@ def calculate_running_preference(group, pref_col, out_col): ) # Add markers for each pellet cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[(cum_pel_ct["subject"] == subj) & (cum_pel_ct["patch"] == patch_name)], + cum_pel_ct[ + (cum_pel_ct["subject"] == subj) + & (cum_pel_ct["patch"] == patch_name) + ], pd.DataFrame( { "time": wheel_ts[patch_name], @@ -989,11 +1139,15 @@ def calculate_running_preference(group, pref_col, out_col): mode="markers", marker={ "symbol": patch_markers_dict[patch_name], - "color": gen_hex_grad(pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"]), + "color": gen_hex_grad( + pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"] + ), "size": 8, }, name=f"{subj} - {cur_p} pellets", - customdata=np.stack((cur_cum_pel_ct["threshold"],), axis=-1), + customdata=np.stack( + (cur_cum_pel_ct["threshold"],), axis=-1 + ), hovertemplate="Threshold: %{customdata[0]:.2f} cm", ) ) @@ -1009,8 +1163,12 @@ def calculate_running_preference(group, pref_col, out_col): # Add trace for each subject-patch combo for subj in subject_names: for patch_name in patch_names: - cur_run_time_pref = patch_pref.loc[patch_name].loc[subj]["running_preference_by_time"] - cur_p_mean = patch_means[patch_means["patch"] == patch_name]["mean_thresh"].values[0] + cur_run_time_pref = patch_pref.loc[patch_name].loc[subj][ + "running_preference_by_time" + ] + cur_p_mean = patch_means[patch_means["patch"] == patch_name][ + "mean_thresh" + ].values[0] cur_p = patch_name.replace("Patch", "P") running_pref_by_patch_fig.add_trace( go.Scatter( @@ -1027,7 +1185,10 @@ def calculate_running_preference(group, pref_col, out_col): ) # Add markers for each pellet cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[(cum_pel_ct["subject"] == subj) & (cum_pel_ct["patch"] == patch_name)], + cum_pel_ct[ + (cum_pel_ct["subject"] == subj) + & (cum_pel_ct["patch"] == patch_name) + ], pd.DataFrame( { "time": wheel_ts[patch_name], @@ -1046,11 +1207,15 @@ def calculate_running_preference(group, pref_col, out_col): mode="markers", marker={ "symbol": patch_markers_dict[patch_name], - "color": gen_hex_grad(pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"]), + "color": gen_hex_grad( + pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"] + ), "size": 8, }, name=f"{subj} - {cur_p} pellets", - customdata=np.stack((cur_cum_pel_ct["threshold"],), axis=-1), + customdata=np.stack( + (cur_cum_pel_ct["threshold"],), axis=-1 + ), hovertemplate="Threshold: %{customdata[0]:.2f} cm", ) ) @@ -1064,7 +1229,9 @@ def calculate_running_preference(group, pref_col, out_col): # Figure 8 - Weighted patch preference: weighted by 'wheel_dist_spun : pel_ct' ratio # --- # Create multi-indexed dataframe with weighted distance for each subject-patch pair - pel_patches = [p for p in patch_names if "dummy" not in p.lower()] # exclude dummy patches + pel_patches = [ + p for p in patch_names if "dummy" not in p.lower() + ] # exclude dummy patches data = [] for patch in pel_patches: for subject in subject_names: @@ -1077,12 +1244,16 @@ def calculate_running_preference(group, pref_col, out_col): } ) subj_wheel_pel_weighted_dist = pd.DataFrame(data) - subj_wheel_pel_weighted_dist.set_index(["patch_name", "subject_name"], inplace=True) + subj_wheel_pel_weighted_dist.set_index( + ["patch_name", "subject_name"], inplace=True + ) subj_wheel_pel_weighted_dist["weighted_dist"] = np.nan # Calculate weighted distance subject_patch_data = (BlockSubjectAnalysis.Patch() & key).fetch(format="frame") - subject_patch_data.reset_index(level=["experiment_name", "block_start"], drop=True, inplace=True) + subject_patch_data.reset_index( + level=["experiment_name", "block_start"], drop=True, inplace=True + ) subj_wheel_pel_weighted_dist = defaultdict(lambda: defaultdict(dict)) for s in subject_names: for p in pel_patches: @@ -1090,11 +1261,14 @@ def calculate_running_preference(group, pref_col, out_col): cur_wheel_cum_dist_df = pd.DataFrame(columns=["time", "cum_wheel_dist"]) cur_wheel_cum_dist_df["time"] = wheel_ts[p] cur_wheel_cum_dist_df["cum_wheel_dist"] = ( - subject_patch_data.loc[p].loc[s]["wheel_cumsum_distance_travelled"] + 1 + subject_patch_data.loc[p].loc[s]["wheel_cumsum_distance_travelled"] + + 1 ) # Get cumulative pellet count cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[(cum_pel_ct["subject"] == s) & (cum_pel_ct["patch"] == p)], + cum_pel_ct[ + (cum_pel_ct["subject"] == s) & (cum_pel_ct["patch"] == p) + ], cur_wheel_cum_dist_df.sort_values("time"), on="time", direction="forward", @@ -1113,7 +1287,9 @@ def calculate_running_preference(group, pref_col, out_col): on="time", direction="forward", ) - max_weight = cur_cum_pel_ct.iloc[-1]["counter"] + 1 # for values after last pellet + max_weight = ( + cur_cum_pel_ct.iloc[-1]["counter"] + 1 + ) # for values after last pellet merged_df["counter"] = merged_df["counter"].fillna(max_weight) merged_df["weighted_cum_wheel_dist"] = ( merged_df.groupby("counter") @@ -1124,7 +1300,9 @@ def calculate_running_preference(group, pref_col, out_col): else: weighted_dist = cur_wheel_cum_dist_df["cum_wheel_dist"].values # Assign to dict - subj_wheel_pel_weighted_dist[p][s]["time"] = cur_wheel_cum_dist_df["time"].values + subj_wheel_pel_weighted_dist[p][s]["time"] = cur_wheel_cum_dist_df[ + "time" + ].values subj_wheel_pel_weighted_dist[p][s]["weighted_dist"] = weighted_dist # Convert back to dataframe data = [] @@ -1135,11 +1313,15 @@ def calculate_running_preference(group, pref_col, out_col): "patch_name": p, "subject_name": s, "time": subj_wheel_pel_weighted_dist[p][s]["time"], - "weighted_dist": subj_wheel_pel_weighted_dist[p][s]["weighted_dist"], + "weighted_dist": subj_wheel_pel_weighted_dist[p][s][ + "weighted_dist" + ], } ) subj_wheel_pel_weighted_dist = pd.DataFrame(data) - subj_wheel_pel_weighted_dist.set_index(["patch_name", "subject_name"], inplace=True) + subj_wheel_pel_weighted_dist.set_index( + ["patch_name", "subject_name"], inplace=True + ) # Calculate normalized weighted value def norm_inv_norm(group): @@ -1148,20 +1330,28 @@ def norm_inv_norm(group): inv_norm_dist = 1 / norm_dist inv_norm_dist = inv_norm_dist / (np.sum(inv_norm_dist, axis=0)) # Map each inv_norm_dist back to patch name. - return pd.Series(inv_norm_dist.tolist(), index=group.index, name="norm_value") + return pd.Series( + inv_norm_dist.tolist(), index=group.index, name="norm_value" + ) subj_wheel_pel_weighted_dist["norm_value"] = ( subj_wheel_pel_weighted_dist.groupby("subject_name") .apply(norm_inv_norm) .reset_index(level=0, drop=True) ) - subj_wheel_pel_weighted_dist["wheel_pref"] = patch_pref["running_preference_by_wheel"] + subj_wheel_pel_weighted_dist["wheel_pref"] = patch_pref[ + "running_preference_by_wheel" + ] # Plot it weighted_patch_pref_fig = make_subplots( rows=len(pel_patches), cols=len(subject_names), - subplot_titles=[f"{patch} - {subject}" for patch in pel_patches for subject in subject_names], + subplot_titles=[ + f"{patch} - {subject}" + for patch in pel_patches + for subject in subject_names + ], specs=[[{"secondary_y": True}] * len(subject_names)] * len(pel_patches), shared_xaxes=True, vertical_spacing=0.1, @@ -1342,7 +1532,9 @@ def make(self, key): for id_val, id_grp in centroid_df.groupby("identity_name"): # Add counts of x,y points to a grid that will be used for heatmap img_grid = np.zeros((max_x + 1, max_y + 1)) - points, counts = np.unique(id_grp[["x", "y"]].values, return_counts=True, axis=0) + points, counts = np.unique( + id_grp[["x", "y"]].values, return_counts=True, axis=0 + ) for point, count in zip(points, counts, strict=True): img_grid[point[0], point[1]] = count img_grid /= img_grid.max() # normalize @@ -1351,7 +1543,9 @@ def make(self, key): # so 45 cm/frame ~= 9 px/frame win_sz = 9 # in pixels (ensure odd for centering) kernel = np.ones((win_sz, win_sz)) / win_sz**2 # moving avg kernel - img_grid_p = np.pad(img_grid, win_sz // 2, mode="edge") # pad for full output from convolution + img_grid_p = np.pad( + img_grid, win_sz // 2, mode="edge" + ) # pad for full output from convolution img_grid_smooth = conv2d(img_grid_p, kernel) heatmaps.append((id_val, img_grid_smooth)) @@ -1380,11 +1574,17 @@ def make(self, key): # Figure 3 - Position ethogram # --- # Get Active Region (ROI) locations - epoch_query = acquisition.Epoch & (acquisition.Chunk & key & chunk_restriction).proj("epoch_start") + epoch_query = acquisition.Epoch & ( + acquisition.Chunk & key & chunk_restriction + ).proj("epoch_start") active_region_query = acquisition.EpochConfig.ActiveRegion & epoch_query - roi_locs = dict(zip(*active_region_query.fetch("region_name", "region_data"), strict=True)) + roi_locs = dict( + zip(*active_region_query.fetch("region_name", "region_data"), strict=True) + ) # get RFID reader locations - recent_rfid_query = (acquisition.Experiment.proj() * streams.Device.proj() & key).aggr( + recent_rfid_query = ( + acquisition.Experiment.proj() * streams.Device.proj() & key + ).aggr( streams.RfidReader & f"rfid_reader_install_time <= '{block_start}'", rfid_reader_install_time="max(rfid_reader_install_time)", ) @@ -1394,7 +1594,10 @@ def make(self, key): & "attribute_name = 'Location'" ) rfid_locs = dict( - zip(*rfid_location_query.fetch("rfid_reader_name", "attribute_value"), strict=True) + zip( + *rfid_location_query.fetch("rfid_reader_name", "attribute_value"), + strict=True, + ) ) ## Create position ethogram df @@ -1419,18 +1622,30 @@ def make(self, key): # For each ROI, compute if within ROI for roi in rois: - if roi == "Corridor": # special case for corridor, based on between inner and outer radius + if ( + roi == "Corridor" + ): # special case for corridor, based on between inner and outer radius dist = np.linalg.norm( (np.vstack((centroid_df["x"], centroid_df["y"])).T) - arena_center, axis=1, ) - pos_eth_df[roi] = (dist >= arena_inner_radius) & (dist <= arena_outer_radius) + pos_eth_df[roi] = (dist >= arena_inner_radius) & ( + dist <= arena_outer_radius + ) elif roi == "Nest": # special case for nest, based on 4 corners nest_corners = roi_locs["NestRegion"]["ArrayOfPoint"] - nest_br_x, nest_br_y = int(nest_corners[0]["X"]), int(nest_corners[0]["Y"]) - nest_bl_x, nest_bl_y = int(nest_corners[1]["X"]), int(nest_corners[1]["Y"]) - nest_tl_x, nest_tl_y = int(nest_corners[2]["X"]), int(nest_corners[2]["Y"]) - nest_tr_x, nest_tr_y = int(nest_corners[3]["X"]), int(nest_corners[3]["Y"]) + nest_br_x, nest_br_y = int(nest_corners[0]["X"]), int( + nest_corners[0]["Y"] + ) + nest_bl_x, nest_bl_y = int(nest_corners[1]["X"]), int( + nest_corners[1]["Y"] + ) + nest_tl_x, nest_tl_y = int(nest_corners[2]["X"]), int( + nest_corners[2]["Y"] + ) + nest_tr_x, nest_tr_y = int(nest_corners[3]["X"]), int( + nest_corners[3]["Y"] + ) pos_eth_df[roi] = ( (centroid_df["x"] <= nest_br_x) & (centroid_df["y"] >= nest_br_y) @@ -1444,10 +1659,13 @@ def make(self, key): else: roi_radius = gate_radius if roi == "Gate" else patch_radius # Get ROI coords - roi_x, roi_y = int(rfid_locs[roi + "Rfid"]["X"]), int(rfid_locs[roi + "Rfid"]["Y"]) + roi_x, roi_y = int(rfid_locs[roi + "Rfid"]["X"]), int( + rfid_locs[roi + "Rfid"]["Y"] + ) # Check if in ROI dist = np.linalg.norm( - (np.vstack((centroid_df["x"], centroid_df["y"])).T) - (roi_x, roi_y), + (np.vstack((centroid_df["x"], centroid_df["y"])).T) + - (roi_x, roi_y), axis=1, ) pos_eth_df[roi] = dist < roi_radius @@ -1498,6 +1716,7 @@ def make(self, key): # ---- Foraging Bout Analysis ---- + @schema class BlockForaging(dj.Computed): definition = """ @@ -1549,6 +1768,7 @@ class AnalysisNote(dj.Manual): # ---- Helper Functions ---- + def get_threshold_associated_pellets(patch_key, start, end): """Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time. @@ -1573,7 +1793,9 @@ def get_threshold_associated_pellets(patch_key, start, end): - offset - rate """ - chunk_restriction = acquisition.create_chunk_restriction(patch_key["experiment_name"], start, end) + chunk_restriction = acquisition.create_chunk_restriction( + patch_key["experiment_name"], start, end + ) # Step 1 - fetch data # pellet delivery trigger @@ -1581,9 +1803,9 @@ def get_threshold_associated_pellets(patch_key, start, end): streams.UndergroundFeederDeliverPellet & patch_key & chunk_restriction )[start:end] # beambreak - beambreak_df = fetch_stream(streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction)[ - start:end - ] + beambreak_df = fetch_stream( + streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction + )[start:end] # patch threshold depletion_state_df = fetch_stream( streams.UndergroundFeederDepletionState & patch_key & chunk_restriction @@ -1635,14 +1857,18 @@ def get_threshold_associated_pellets(patch_key, start, end): .set_index("time") .dropna(subset=["beam_break_timestamp"]) ) - pellet_beam_break_df.drop_duplicates(subset="beam_break_timestamp", keep="last", inplace=True) + pellet_beam_break_df.drop_duplicates( + subset="beam_break_timestamp", keep="last", inplace=True + ) # Find pellet delivery triggers that approximately coincide with each threshold update # i.e. nearest pellet delivery within 100ms before or after threshold update pellet_ts_threshold_df = ( pd.merge_asof( depletion_state_df.reset_index(), - pellet_beam_break_df.reset_index().rename(columns={"time": "pellet_timestamp"}), + pellet_beam_break_df.reset_index().rename( + columns={"time": "pellet_timestamp"} + ), left_on="time", right_on="pellet_timestamp", tolerance=pd.Timedelta("100ms"), @@ -1655,8 +1881,12 @@ def get_threshold_associated_pellets(patch_key, start, end): # Clean up the df pellet_ts_threshold_df = pellet_ts_threshold_df.drop(columns=["event_x", "event_y"]) # Shift back the pellet_timestamp values by 1 to match with the previous threshold update - pellet_ts_threshold_df.pellet_timestamp = pellet_ts_threshold_df.pellet_timestamp.shift(-1) - pellet_ts_threshold_df.beam_break_timestamp = pellet_ts_threshold_df.beam_break_timestamp.shift(-1) + pellet_ts_threshold_df.pellet_timestamp = ( + pellet_ts_threshold_df.pellet_timestamp.shift(-1) + ) + pellet_ts_threshold_df.beam_break_timestamp = ( + pellet_ts_threshold_df.beam_break_timestamp.shift(-1) + ) pellet_ts_threshold_df = pellet_ts_threshold_df.dropna( subset=["pellet_timestamp", "beam_break_timestamp"] ) @@ -1683,8 +1913,12 @@ def get_foraging_bouts( Returns: DataFrame containing foraging bouts. Columns: duration, n_pellets, cum_wheel_dist, subject. """ - max_inactive_time = pd.Timedelta(seconds=60) if max_inactive_time is None else max_inactive_time - bout_data = pd.DataFrame(columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"]) + max_inactive_time = ( + pd.Timedelta(seconds=60) if max_inactive_time is None else max_inactive_time + ) + bout_data = pd.DataFrame( + columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"] + ) subject_patch_data = (BlockSubjectAnalysis.Patch() & key).fetch(format="frame") if subject_patch_data.empty: return bout_data @@ -1722,38 +1956,55 @@ def get_foraging_bouts( spun_indices = np.where(diffs > wheel_spun_thresh) patch_spun[spun_indices[1]] = patch_names[spun_indices[0]] patch_spun_df = pd.DataFrame( - {"cum_wheel_dist": comb_cum_wheel_dist, "patch_spun": patch_spun}, index=wheel_ts + {"cum_wheel_dist": comb_cum_wheel_dist, "patch_spun": patch_spun}, + index=wheel_ts, ) wheel_s_r = pd.Timedelta(wheel_ts[1] - wheel_ts[0], unit="ns") max_inactive_win_len = int(max_inactive_time / wheel_s_r) # Find times when foraging - max_windowed_wheel_vals = patch_spun_df["cum_wheel_dist"].shift(-(max_inactive_win_len - 1)).ffill() - foraging_mask = max_windowed_wheel_vals > (patch_spun_df["cum_wheel_dist"] + min_wheel_movement) + max_windowed_wheel_vals = ( + patch_spun_df["cum_wheel_dist"].shift(-(max_inactive_win_len - 1)).ffill() + ) + foraging_mask = max_windowed_wheel_vals > ( + patch_spun_df["cum_wheel_dist"] + min_wheel_movement + ) # Discretize into foraging bouts - bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + (max_inactive_win_len - 1) + bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + ( + max_inactive_win_len - 1 + ) n_samples_in_1s = int(1 / wheel_s_r.total_seconds()) bout_end_indxs = ( np.where(np.diff(foraging_mask, prepend=0) == -1)[0] + (max_inactive_win_len - 1) + n_samples_in_1s ) - bout_end_indxs[-1] = min(bout_end_indxs[-1], len(wheel_ts) - 1) # ensure last bout ends in block + bout_end_indxs[-1] = min( + bout_end_indxs[-1], len(wheel_ts) - 1 + ) # ensure last bout ends in block # Remove bout that starts at block end if bout_start_indxs[-1] >= len(wheel_ts): bout_start_indxs = bout_start_indxs[:-1] bout_end_indxs = bout_end_indxs[:-1] assert len(bout_start_indxs) == len(bout_end_indxs) - bout_durations = (wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs]).astype( # in seconds + bout_durations = ( + wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs] + ).astype( # in seconds "timedelta64[ns]" - ).astype(float) / 1e9 + ).astype( + float + ) / 1e9 bout_starts_ends = np.array( [ (wheel_ts[start_idx], wheel_ts[end_idx]) - for start_idx, end_idx in zip(bout_start_indxs, bout_end_indxs, strict=True) + for start_idx, end_idx in zip( + bout_start_indxs, bout_end_indxs, strict=True + ) ] ) all_pel_ts = np.sort( - np.concatenate([arr for arr in cur_subject_data["pellet_timestamps"] if len(arr) > 0]) + np.concatenate( + [arr for arr in cur_subject_data["pellet_timestamps"] if len(arr) > 0] + ) ) bout_pellets = np.array( [ @@ -1767,7 +2018,8 @@ def get_foraging_bouts( bout_pellets = bout_pellets[bout_pellets >= min_pellets] bout_cum_wheel_dist = np.array( [ - patch_spun_df.loc[end, "cum_wheel_dist"] - patch_spun_df.loc[start, "cum_wheel_dist"] + patch_spun_df.loc[end, "cum_wheel_dist"] + - patch_spun_df.loc[start, "cum_wheel_dist"] for start, end in bout_starts_ends ] ) diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index babae2fb..5dcea011 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -1,3 +1,5 @@ +"""Module for visit-related tables in the analysis schema.""" + import datetime import datajoint as dj import pandas as pd @@ -67,14 +69,14 @@ class Visit(dj.Part): @property def key_source(self): - return dj.U("experiment_name", "place", "overlap_start") & (Visit & VisitEnd).proj( - overlap_start="visit_start" - ) + return dj.U("experiment_name", "place", "overlap_start") & ( + Visit & VisitEnd + ).proj(overlap_start="visit_start") def make(self, key): - visit_starts, visit_ends = (Visit * VisitEnd & key & {"visit_start": key["overlap_start"]}).fetch( - "visit_start", "visit_end" - ) + visit_starts, visit_ends = ( + Visit * VisitEnd & key & {"visit_start": key["overlap_start"]} + ).fetch("visit_start", "visit_end") visit_start = min(visit_starts) visit_end = max(visit_ends) @@ -88,7 +90,9 @@ def make(self, key): if len(overlap_query) <= 1: break overlap_visits.extend( - overlap_query.proj(overlap_start=f'"{key["overlap_start"]}"').fetch(as_dict=True) + overlap_query.proj(overlap_start=f'"{key["overlap_start"]}"').fetch( + as_dict=True + ) ) visit_starts, visit_ends = overlap_query.fetch("visit_start", "visit_end") if visit_start == max(visit_starts) and visit_end == max(visit_ends): @@ -102,7 +106,10 @@ def make(self, key): { **key, "overlap_end": visit_end, - "overlap_duration": (visit_end - key["overlap_start"]).total_seconds() / 3600, + "overlap_duration": ( + visit_end - key["overlap_start"] + ).total_seconds() + / 3600, "subject_count": len({v["subject"] for v in overlap_visits}), } ) @@ -188,16 +195,22 @@ def ingest_environment_visits(experiment_names: list | None = None): def get_maintenance_periods(experiment_name, start, end): # get states from acquisition.Environment.EnvironmentState - chunk_restriction = acquisition.create_chunk_restriction(experiment_name, start, end) + chunk_restriction = acquisition.create_chunk_restriction( + experiment_name, start, end + ) state_query = ( - acquisition.Environment.EnvironmentState & {"experiment_name": experiment_name} & chunk_restriction + acquisition.Environment.EnvironmentState + & {"experiment_name": experiment_name} + & chunk_restriction ) env_state_df = fetch_stream(state_query)[start:end] if env_state_df.empty: return deque([]) env_state_df.reset_index(inplace=True) - env_state_df = env_state_df[env_state_df["state"].shift() != env_state_df["state"]].reset_index( + env_state_df = env_state_df[ + env_state_df["state"].shift() != env_state_df["state"] + ].reset_index( drop=True ) # remove duplicates and keep the first one # An experiment starts with visit start (anything before the first maintenance is experiment) @@ -213,8 +226,12 @@ def get_maintenance_periods(experiment_name, start, end): env_state_df = pd.concat([env_state_df, log_df_end]) env_state_df.reset_index(drop=True, inplace=True) - maintenance_starts = env_state_df.loc[env_state_df["state"] == "Maintenance", "time"].values - maintenance_ends = env_state_df.loc[env_state_df["state"] != "Maintenance", "time"].values + maintenance_starts = env_state_df.loc[ + env_state_df["state"] == "Maintenance", "time" + ].values + maintenance_ends = env_state_df.loc[ + env_state_df["state"] != "Maintenance", "time" + ].values return deque( [ @@ -230,7 +247,9 @@ def filter_out_maintenance_periods(data_df, maintenance_period, end_time, dropna (maintenance_start, maintenance_end) = maint_period[0] if end_time < maintenance_start: # no more maintenance for this date break - maintenance_filter = (data_df.index >= maintenance_start) & (data_df.index <= maintenance_end) + maintenance_filter = (data_df.index >= maintenance_start) & ( + data_df.index <= maintenance_end + ) data_df[maintenance_filter] = np.nan if end_time >= maintenance_end: # remove this range maint_period.popleft() diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 4569193e..8a882f48 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -1,3 +1,5 @@ +"""Module for visit analysis.""" + import datetime from datetime import time @@ -89,7 +91,8 @@ def key_source(self): + chunk starts after visit_start and ends before visit_end (or NOW() - i.e. ongoing visits). """ return ( - Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") * acquisition.Chunk + Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") + * acquisition.Chunk & acquisition.SubjectEnterExit & [ "visit_start BETWEEN chunk_start AND chunk_end", @@ -100,7 +103,9 @@ def key_source(self): ) def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) # -- Determine the time to start time_slicing in this chunk if chunk_start < key["visit_start"] < chunk_end: @@ -169,8 +174,12 @@ def make(self, key): end_time = np.array(end_time, dtype="datetime64[ns]") while time_slice_start < end_time: - time_slice_end = time_slice_start + min(self._time_slice_duration, end_time - time_slice_start) - in_time_slice = np.logical_and(timestamps >= time_slice_start, timestamps < time_slice_end) + time_slice_end = time_slice_start + min( + self._time_slice_duration, end_time - time_slice_start + ) + in_time_slice = np.logical_and( + timestamps >= time_slice_start, timestamps < time_slice_end + ) chunk_time_slices.append( { **key, @@ -194,7 +203,10 @@ def get_position(cls, visit_key=None, subject=None, start=None, end=None): if visit_key is not None: assert len(Visit & visit_key) == 1 start, end = ( - Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") & visit_key + Visit.join(VisitEnd, left=True).proj( + visit_end="IFNULL(visit_end, NOW())" + ) + & visit_key ).fetch1("visit_start", "visit_end") subject = visit_key["subject"] elif all((subject, start, end)): @@ -255,14 +267,18 @@ class FoodPatch(dj.Part): """ # Work on finished visits - key_source = Visit & (VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end") + key_source = Visit & ( + VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end" + ) def make(self, key): visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) ) - maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) + maintenance_period = get_maintenance_periods( + key["experiment_name"], visit_start, visit_end + ) for visit_date in visit_dates: day_start = datetime.datetime.combine(visit_date.date(), time.min) @@ -282,12 +298,16 @@ def make(self, key): subject=key["subject"], start=day_start, end=day_end ) # filter out maintenance period based on logs - position = filter_out_maintenance_periods(position, maintenance_period, day_end) + position = filter_out_maintenance_periods( + position, maintenance_period, day_end + ) # filter for objects of the correct size valid_position = (position.area > 0) & (position.area < 1000) position[~valid_position] = np.nan - position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) + position.rename( + columns={"position_x": "x", "position_y": "y"}, inplace=True + ) # in corridor distance_from_center = tracking.compute_distance( position[["x", "y"]], @@ -331,9 +351,9 @@ def make(self, key): in_food_patch_times = [] for food_patch_key in food_patch_keys: # wheel data - food_patch_description = (acquisition.ExperimentFoodPatch & food_patch_key).fetch1( - "food_patch_description" - ) + food_patch_description = ( + acquisition.ExperimentFoodPatch & food_patch_key + ).fetch1("food_patch_description") wheel_data = acquisition.FoodPatchWheel.get_wheel_data( experiment_name=key["experiment_name"], start=pd.Timestamp(day_start), @@ -342,10 +362,12 @@ def make(self, key): using_aeon_io=True, ) # filter out maintenance period based on logs - wheel_data = filter_out_maintenance_periods(wheel_data, maintenance_period, day_end) - patch_position = (acquisition.ExperimentFoodPatch.Position & food_patch_key).fetch1( - "food_patch_position_x", "food_patch_position_y" + wheel_data = filter_out_maintenance_periods( + wheel_data, maintenance_period, day_end ) + patch_position = ( + acquisition.ExperimentFoodPatch.Position & food_patch_key + ).fetch1("food_patch_position_x", "food_patch_position_y") in_patch = tracking.is_position_in_patch( position, patch_position, @@ -400,14 +422,18 @@ class FoodPatch(dj.Part): """ # Work on finished visits - key_source = Visit & (VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end") + key_source = Visit & ( + VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end" + ) def make(self, key): visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) ) - maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) + maintenance_period = get_maintenance_periods( + key["experiment_name"], visit_start, visit_end + ) for visit_date in visit_dates: day_start = datetime.datetime.combine(visit_date.date(), time.min) @@ -428,12 +454,18 @@ def make(self, key): subject=key["subject"], start=day_start, end=day_end ) # filter out maintenance period based on logs - position = filter_out_maintenance_periods(position, maintenance_period, day_end) + position = filter_out_maintenance_periods( + position, maintenance_period, day_end + ) # filter for objects of the correct size valid_position = (position.area > 0) & (position.area < 1000) position[~valid_position] = np.nan - position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) - position_diff = np.sqrt(np.square(np.diff(position.x)) + np.square(np.diff(position.y))) + position.rename( + columns={"position_x": "x", "position_y": "y"}, inplace=True + ) + position_diff = np.sqrt( + np.square(np.diff(position.x)) + np.square(np.diff(position.y)) + ) total_distance_travelled = np.nansum(position_diff) # in food patches - loop through all in-use patches during this visit @@ -469,9 +501,9 @@ def make(self, key): dropna=True, ).index.values # wheel data - food_patch_description = (acquisition.ExperimentFoodPatch & food_patch_key).fetch1( - "food_patch_description" - ) + food_patch_description = ( + acquisition.ExperimentFoodPatch & food_patch_key + ).fetch1("food_patch_description") wheel_data = acquisition.FoodPatchWheel.get_wheel_data( experiment_name=key["experiment_name"], start=pd.Timestamp(day_start), @@ -480,7 +512,9 @@ def make(self, key): using_aeon_io=True, ) # filter out maintenance period based on logs - wheel_data = filter_out_maintenance_periods(wheel_data, maintenance_period, day_end) + wheel_data = filter_out_maintenance_periods( + wheel_data, maintenance_period, day_end + ) food_patch_statistics.append( { @@ -488,11 +522,15 @@ def make(self, key): **food_patch_key, "visit_date": visit_date.date(), "pellet_count": len(pellet_events), - "wheel_distance_travelled": wheel_data.distance_travelled.values[-1], + "wheel_distance_travelled": wheel_data.distance_travelled.values[ + -1 + ], } ) - total_pellet_count = np.sum([p["pellet_count"] for p in food_patch_statistics]) + total_pellet_count = np.sum( + [p["pellet_count"] for p in food_patch_statistics] + ) total_wheel_distance_travelled = np.sum( [p["wheel_distance_travelled"] for p in food_patch_statistics] ) @@ -526,20 +564,27 @@ class VisitForagingBout(dj.Computed): # Work on 24/7 experiments key_source = ( - Visit & VisitSummary & (VisitEnd & "visit_duration > 24") & "experiment_name= 'exp0.2-r0'" + Visit + & VisitSummary + & (VisitEnd & "visit_duration > 24") + & "experiment_name= 'exp0.2-r0'" ) * acquisition.ExperimentFoodPatch def make(self, key): visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") # get in_patch timestamps - food_patch_description = (acquisition.ExperimentFoodPatch & key).fetch1("food_patch_description") + food_patch_description = (acquisition.ExperimentFoodPatch & key).fetch1( + "food_patch_description" + ) in_patch_times = np.concatenate( - (VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & key).fetch( - "in_patch", order_by="visit_date" - ) + ( + VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & key + ).fetch("in_patch", order_by="visit_date") + ) + maintenance_period = get_maintenance_periods( + key["experiment_name"], visit_start, visit_end ) - maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) in_patch_times = filter_out_maintenance_periods( pd.DataFrame( [[food_patch_description]] * len(in_patch_times), @@ -567,8 +612,12 @@ def make(self, key): .set_index("event_time") ) # TODO: handle multiple retries of pellet delivery - maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) - patch = filter_out_maintenance_periods(patch, maintenance_period, visit_end, True) + maintenance_period = get_maintenance_periods( + key["experiment_name"], visit_start, visit_end + ) + patch = filter_out_maintenance_periods( + patch, maintenance_period, visit_end, True + ) if len(in_patch_times): change_ind = ( @@ -584,7 +633,9 @@ def make(self, key): ts_array = in_patch_times[change_ind[i - 1] : change_ind[i]] wheel_start, wheel_end = ts_array[0], ts_array[-1] - if wheel_start >= wheel_end: # skip if timestamps were misaligned or a single timestamp + if ( + wheel_start >= wheel_end + ): # skip if timestamps were misaligned or a single timestamp continue wheel_data = acquisition.FoodPatchWheel.get_wheel_data( @@ -594,14 +645,19 @@ def make(self, key): patch_name=food_patch_description, using_aeon_io=True, ) - maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) - wheel_data = filter_out_maintenance_periods(wheel_data, maintenance_period, visit_end, True) + maintenance_period = get_maintenance_periods( + key["experiment_name"], visit_start, visit_end + ) + wheel_data = filter_out_maintenance_periods( + wheel_data, maintenance_period, visit_end, True + ) self.insert1( { **key, "bout_start": ts_array[0], "bout_end": ts_array[-1], - "bout_duration": (ts_array[-1] - ts_array[0]) / np.timedelta64(1, "s"), + "bout_duration": (ts_array[-1] - ts_array[0]) + / np.timedelta64(1, "s"), "wheel_distance_travelled": wheel_data.distance_travelled[-1], "pellet_count": len(patch.loc[wheel_start:wheel_end]), } diff --git a/aeon/dj_pipeline/create_experiments/create_experiment_01.py b/aeon/dj_pipeline/create_experiments/create_experiment_01.py index 7b87bf11..aa1f8675 100644 --- a/aeon/dj_pipeline/create_experiments/create_experiment_01.py +++ b/aeon/dj_pipeline/create_experiments/create_experiment_01.py @@ -1,3 +1,5 @@ +"""Functions to populate the database with the metadata for experiment 0.1.""" + import pathlib import yaml @@ -32,7 +34,10 @@ def ingest_exp01_metadata(metadata_yml_filepath, experiment_name): & camera_key ) if current_camera_query: # If the same camera is currently installed - if current_camera_query.fetch1("camera_install_time") == arena_setup["start-time"]: + if ( + current_camera_query.fetch1("camera_install_time") + == arena_setup["start-time"] + ): # If it is installed at the same time as that read from this yml file # then it is the same ExperimentCamera instance, no need to do anything continue @@ -52,7 +57,9 @@ def ingest_exp01_metadata(metadata_yml_filepath, experiment_name): "experiment_name": experiment_name, "camera_install_time": arena_setup["start-time"], "camera_description": camera["description"], - "camera_sampling_rate": device_frequency_mapper[camera["trigger-source"].lower()], + "camera_sampling_rate": device_frequency_mapper[ + camera["trigger-source"].lower() + ], } ) acquisition.ExperimentCamera.Position.insert1( @@ -68,17 +75,23 @@ def ingest_exp01_metadata(metadata_yml_filepath, experiment_name): # ---- Load food patches ---- for patch in arena_setup["patches"]: # ---- Check if this is a new food patch, add to lab.FoodPatch if needed - patch_key = {"food_patch_serial_number": patch["serial-number"] or patch["port-name"]} + patch_key = { + "food_patch_serial_number": patch["serial-number"] or patch["port-name"] + } if patch_key not in lab.FoodPatch(): lab.FoodPatch.insert1(patch_key) # ---- Check if this food patch is currently installed - if so, remove it current_patch_query = ( - acquisition.ExperimentFoodPatch - acquisition.ExperimentFoodPatch.RemovalTime + acquisition.ExperimentFoodPatch + - acquisition.ExperimentFoodPatch.RemovalTime & {"experiment_name": experiment_name} & patch_key ) if current_patch_query: # If the same food-patch is currently installed - if current_patch_query.fetch1("food_patch_install_time") == arena_setup["start-time"]: + if ( + current_patch_query.fetch1("food_patch_install_time") + == arena_setup["start-time"] + ): # If it is installed at the same time as that read from this yml file # then it is the same ExperimentFoodPatch instance, no need to do anything continue @@ -113,16 +126,21 @@ def ingest_exp01_metadata(metadata_yml_filepath, experiment_name): ) # ---- Load weight scales ---- for weight_scale in arena_setup["weight-scales"]: - weight_scale_key = {"weight_scale_serial_number": weight_scale["serial-number"]} + weight_scale_key = { + "weight_scale_serial_number": weight_scale["serial-number"] + } if weight_scale_key not in lab.WeightScale(): lab.WeightScale.insert1(weight_scale_key) # ---- Check if this weight scale is currently installed - if so, remove it current_weight_scale_query = ( - acquisition.ExperimentWeightScale - acquisition.ExperimentWeightScale.RemovalTime + acquisition.ExperimentWeightScale + - acquisition.ExperimentWeightScale.RemovalTime & {"experiment_name": experiment_name} & weight_scale_key ) - if current_weight_scale_query: # If the same weight scale is currently installed + if ( + current_weight_scale_query + ): # If the same weight scale is currently installed if ( current_weight_scale_query.fetch1("weight_scale_install_time") == arena_setup["start-time"] @@ -250,8 +268,12 @@ def add_arena_setup(): # manually update coordinates of foodpatch and nest patch_coordinates = {"Patch1": (1.13, 1.59, 0), "Patch2": (1.19, 0.50, 0)} - for patch_key in (acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name}).fetch("KEY"): - patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1("food_patch_description") + for patch_key in ( + acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name} + ).fetch("KEY"): + patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1( + "food_patch_description" + ) x, y, z = patch_coordinates[patch] acquisition.ExperimentFoodPatch.Position.update1( { diff --git a/aeon/dj_pipeline/create_experiments/create_experiment_02.py b/aeon/dj_pipeline/create_experiments/create_experiment_02.py index 82a8f03f..ef3611ca 100644 --- a/aeon/dj_pipeline/create_experiments/create_experiment_02.py +++ b/aeon/dj_pipeline/create_experiments/create_experiment_02.py @@ -1,3 +1,5 @@ +"""Function to create new experiments for experiment0.2""" + from aeon.dj_pipeline import acquisition, lab, subject # ============ Manual and automatic steps to for experiment 0.2 populate ============ @@ -30,7 +32,10 @@ def create_new_experiment(): skip_duplicates=True, ) acquisition.Experiment.Subject.insert( - [{"experiment_name": experiment_name, "subject": s["subject"]} for s in subject_list], + [ + {"experiment_name": experiment_name, "subject": s["subject"]} + for s in subject_list + ], skip_duplicates=True, ) diff --git a/aeon/dj_pipeline/create_experiments/create_octagon_1.py b/aeon/dj_pipeline/create_experiments/create_octagon_1.py index 4b077e65..ae3e831d 100644 --- a/aeon/dj_pipeline/create_experiments/create_octagon_1.py +++ b/aeon/dj_pipeline/create_experiments/create_octagon_1.py @@ -1,3 +1,5 @@ +"""Function to create new experiments for octagon1.0""" + from aeon.dj_pipeline import acquisition, subject # ============ Manual and automatic steps to for experiment 0.2 populate ============ @@ -33,7 +35,10 @@ def create_new_experiment(): skip_duplicates=True, ) acquisition.Experiment.Subject.insert( - [{"experiment_name": experiment_name, "subject": s["subject"]} for s in subject_list], + [ + {"experiment_name": experiment_name, "subject": s["subject"]} + for s in subject_list + ], skip_duplicates=True, ) diff --git a/aeon/dj_pipeline/create_experiments/create_presocial.py b/aeon/dj_pipeline/create_experiments/create_presocial.py index 05dc0dc8..855f4d00 100644 --- a/aeon/dj_pipeline/create_experiments/create_presocial.py +++ b/aeon/dj_pipeline/create_experiments/create_presocial.py @@ -1,3 +1,5 @@ +"""Function to create new experiments for presocial0.1""" + from aeon.dj_pipeline import acquisition, lab, subject experiment_type = "presocial0.1" @@ -9,7 +11,9 @@ def create_new_experiment(): lab.Location.insert1({"lab": "SWC", "location": location}, skip_duplicates=True) - acquisition.ExperimentType.insert1({"experiment_type": experiment_type}, skip_duplicates=True) + acquisition.ExperimentType.insert1( + {"experiment_type": experiment_type}, skip_duplicates=True + ) acquisition.Experiment.insert( [ diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py index 9a67eadc..501f6893 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py @@ -1,3 +1,5 @@ +"""Function to create new social experiments""" + from pathlib import Path from datetime import datetime from aeon.dj_pipeline import acquisition @@ -37,7 +39,9 @@ def create_new_social_experiment(experiment_name): "experiment_name": experiment_name, "repository_name": "ceph_aeon", "directory_type": dir_type, - "directory_path": (ceph_data_dir / dir_type / machine_name.upper() / exp_name) + "directory_path": ( + ceph_data_dir / dir_type / machine_name.upper() / exp_name + ) .relative_to(ceph_dir) .as_posix(), "load_order": load_order, @@ -50,9 +54,13 @@ def create_new_social_experiment(experiment_name): new_experiment_entry, skip_duplicates=True, ) - acquisition.Experiment.Directory.insert(experiment_directories, skip_duplicates=True) + acquisition.Experiment.Directory.insert( + experiment_directories, skip_duplicates=True + ) acquisition.Experiment.DevicesSchema.insert1( - {"experiment_name": experiment_name, "devices_schema_name": exp_name.replace(".", "")}, + { + "experiment_name": experiment_name, + "devices_schema_name": exp_name.replace(".", ""), + }, skip_duplicates=True, ) - diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py index b50c0a17..79c4cba0 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py @@ -1,7 +1,11 @@ +"""Function to create new experiments for social0-r1""" + import pathlib from aeon.dj_pipeline import acquisition, lab, subject -from aeon.dj_pipeline.create_experiments.create_experiment_01 import ingest_exp01_metadata +from aeon.dj_pipeline.create_experiments.create_experiment_01 import ( + ingest_exp01_metadata, +) # ============ Manual and automatic steps to for experiment 0.1 populate ============ experiment_name = "social0-r1" @@ -33,7 +37,10 @@ def create_new_experiment(): skip_duplicates=True, ) acquisition.Experiment.Subject.insert( - [{"experiment_name": experiment_name, "subject": s["subject"]} for s in subject_list], + [ + {"experiment_name": experiment_name, "subject": s["subject"]} + for s in subject_list + ], skip_duplicates=True, ) @@ -88,8 +95,12 @@ def add_arena_setup(): # manually update coordinates of foodpatch and nest patch_coordinates = {"Patch1": (1.13, 1.59, 0), "Patch2": (1.19, 0.50, 0)} - for patch_key in (acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name}).fetch("KEY"): - patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1("food_patch_description") + for patch_key in ( + acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name} + ).fetch("KEY"): + patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1( + "food_patch_description" + ) x, y, z = patch_coordinates[patch] acquisition.ExperimentFoodPatch.Position.update1( { @@ -147,15 +158,11 @@ def fixID(subjid, valid_ids=None, valid_id_file=None): # The subjid is a combo subjid. if ";" in subjid: subjidA, subjidB = subjid.split(";") - return ( - f"{fixID(subjidA.strip(), valid_ids=valid_ids)};{fixID(subjidB.strip(), valid_ids=valid_ids)}" - ) + return f"{fixID(subjidA.strip(), valid_ids=valid_ids)};{fixID(subjidB.strip(), valid_ids=valid_ids)}" if "vs" in subjid: subjidA, tmp, subjidB = subjid.split(" ")[1:] - return ( - f"{fixID(subjidA.strip(), valid_ids=valid_ids)};{fixID(subjidB.strip(), valid_ids=valid_ids)}" - ) + return f"{fixID(subjidA.strip(), valid_ids=valid_ids)};{fixID(subjidB.strip(), valid_ids=valid_ids)}" try: ld = [jl.levenshtein_distance(subjid, x[-len(subjid) :]) for x in valid_ids] diff --git a/aeon/dj_pipeline/lab.py b/aeon/dj_pipeline/lab.py index 2f10665f..203f4a47 100644 --- a/aeon/dj_pipeline/lab.py +++ b/aeon/dj_pipeline/lab.py @@ -1,3 +1,5 @@ +"""DataJoint schema for the lab pipeline.""" + import datajoint as dj from . import get_schema_name diff --git a/aeon/dj_pipeline/populate/worker.py b/aeon/dj_pipeline/populate/worker.py index 68f1803a..309eea36 100644 --- a/aeon/dj_pipeline/populate/worker.py +++ b/aeon/dj_pipeline/populate/worker.py @@ -1,5 +1,14 @@ +""" +This module defines the workers for the AEON pipeline. +""" + import datajoint as dj -from datajoint_utilities.dj_worker import DataJointWorker, ErrorLog, WorkerLog, RegisteredWorker +from datajoint_utilities.dj_worker import ( + DataJointWorker, + ErrorLog, + WorkerLog, + RegisteredWorker, +) from datajoint_utilities.dj_worker.worker_schema import is_djtable from aeon.dj_pipeline import db_prefix @@ -107,4 +116,6 @@ def ingest_epochs_chunks(): def get_workflow_operation_overview(): from datajoint_utilities.dj_worker.utils import get_workflow_operation_overview - return get_workflow_operation_overview(worker_schema_name=worker_schema_name, db_prefixes=[db_prefix]) + return get_workflow_operation_overview( + worker_schema_name=worker_schema_name, db_prefixes=[db_prefix] + ) diff --git a/aeon/dj_pipeline/qc.py b/aeon/dj_pipeline/qc.py index 7044da0e..cc52d48f 100644 --- a/aeon/dj_pipeline/qc.py +++ b/aeon/dj_pipeline/qc.py @@ -1,3 +1,5 @@ +"""DataJoint schema for the quality control pipeline.""" + import datajoint as dj import numpy as np import pandas as pd @@ -58,7 +60,9 @@ def key_source(self): return ( acquisition.Chunk * ( - streams.SpinnakerVideoSource.join(streams.SpinnakerVideoSource.RemovalTime, left=True) + streams.SpinnakerVideoSource.join( + streams.SpinnakerVideoSource.RemovalTime, left=True + ) & "spinnaker_video_source_name='CameraTop'" ) & "chunk_start >= spinnaker_video_source_install_time" @@ -66,16 +70,21 @@ def key_source(self): ) # CameraTop def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) - device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name") + device_name = (streams.SpinnakerVideoSource & key).fetch1( + "spinnaker_video_source_name" + ) data_dirs = acquisition.Experiment.get_data_directories(key) devices_schema = getattr( acquisition.aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(devices_schema, device_name).Video @@ -102,9 +111,11 @@ def make(self, key): **key, "drop_count": deltas.frame_offset.iloc[-1], "max_harp_delta": deltas.time_delta.max().total_seconds(), - "max_camera_delta": deltas.hw_timestamp_delta.max() / 1e9, # convert to seconds + "max_camera_delta": deltas.hw_timestamp_delta.max() + / 1e9, # convert to seconds "timestamps": videodata.index.values, - "time_delta": deltas.time_delta.values / np.timedelta64(1, "s"), # convert to seconds + "time_delta": deltas.time_delta.values + / np.timedelta64(1, "s"), # convert to seconds "frame_delta": deltas.frame_delta.values, "hw_counter_delta": deltas.hw_counter_delta.values, "hw_timestamp_delta": deltas.hw_timestamp_delta.values, diff --git a/aeon/dj_pipeline/report.py b/aeon/dj_pipeline/report.py index ec88ae7c..0a66dc75 100644 --- a/aeon/dj_pipeline/report.py +++ b/aeon/dj_pipeline/report.py @@ -1,3 +1,5 @@ +"""DataJoint schema dedicated for tables containing figures. """ + import datetime import json import os @@ -19,11 +21,6 @@ os.environ["DJ_SUPPORT_FILEPATH_MANAGEMENT"] = "TRUE" -""" - DataJoint schema dedicated for tables containing figures -""" - - @schema class InArenaSummaryPlot(dj.Computed): definition = """ @@ -33,7 +30,9 @@ class InArenaSummaryPlot(dj.Computed): summary_plot_png: attach """ - key_source = analysis.InArena & analysis.InArenaTimeDistribution & analysis.InArenaSummary + key_source = ( + analysis.InArena & analysis.InArenaTimeDistribution & analysis.InArenaSummary + ) color_code = { "Patch1": "b", @@ -44,15 +43,17 @@ class InArenaSummaryPlot(dj.Computed): } def make(self, key): - in_arena_start, in_arena_end = (analysis.InArena * analysis.InArenaEnd & key).fetch1( - "in_arena_start", "in_arena_end" - ) + in_arena_start, in_arena_end = ( + analysis.InArena * analysis.InArenaEnd & key + ).fetch1("in_arena_start", "in_arena_end") # subject's position data in the time_slices position = analysis.InArenaSubjectPosition.get_position(key) position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) - position_minutes_elapsed = (position.index - in_arena_start).total_seconds() / 60 + position_minutes_elapsed = ( + position.index - in_arena_start + ).total_seconds() / 60 # figure fig = plt.figure(figsize=(20, 9)) @@ -67,12 +68,16 @@ def make(self, key): # position plot non_nan = np.logical_and(~np.isnan(position.x), ~np.isnan(position.y)) - analysis_plotting.heatmap(position[non_nan], 50, ax=position_ax, bins=500, alpha=0.5) + analysis_plotting.heatmap( + position[non_nan], 50, ax=position_ax, bins=500, alpha=0.5 + ) # event rate plots in_arena_food_patches = ( analysis.InArena - * acquisition.ExperimentFoodPatch.join(acquisition.ExperimentFoodPatch.RemovalTime, left=True) + * acquisition.ExperimentFoodPatch.join( + acquisition.ExperimentFoodPatch.RemovalTime, left=True + ) & key & "in_arena_start >= food_patch_install_time" & 'in_arena_start < IFNULL(food_patch_remove_time, "2200-01-01")' @@ -139,7 +144,9 @@ def make(self, key): color=self.color_code[food_patch_key["food_patch_description"]], alpha=0.3, ) - threshold_change_ind = np.where(wheel_threshold[:-1] != wheel_threshold[1:])[0] + threshold_change_ind = np.where( + wheel_threshold[:-1] != wheel_threshold[1:] + )[0] threshold_ax.vlines( wheel_time[threshold_change_ind + 1], ymin=wheel_threshold[threshold_change_ind], @@ -151,17 +158,20 @@ def make(self, key): ) # ethogram - in_arena, in_corridor, arena_time, corridor_time = (analysis.InArenaTimeDistribution & key).fetch1( + in_arena, in_corridor, arena_time, corridor_time = ( + analysis.InArenaTimeDistribution & key + ).fetch1( "in_arena", "in_corridor", "time_fraction_in_arena", "time_fraction_in_corridor", ) - nest_keys, in_nests, nests_times = (analysis.InArenaTimeDistribution.Nest & key).fetch( - "KEY", "in_nest", "time_fraction_in_nest" - ) + nest_keys, in_nests, nests_times = ( + analysis.InArenaTimeDistribution.Nest & key + ).fetch("KEY", "in_nest", "time_fraction_in_nest") patch_names, in_patches, patches_times = ( - analysis.InArenaTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & key + analysis.InArenaTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch + & key ).fetch("food_patch_description", "in_patch", "time_fraction_in_patch") ethogram_ax.plot( @@ -192,7 +202,9 @@ def make(self, key): alpha=0.6, label="nest", ) - for patch_idx, (patch_name, in_patch) in enumerate(zip(patch_names, in_patches)): + for patch_idx, (patch_name, in_patch) in enumerate( + zip(patch_names, in_patches) + ): ethogram_ax.plot( position_minutes_elapsed[in_patch], np.full_like(position_minutes_elapsed[in_patch], (patch_idx + 3)), @@ -233,7 +245,9 @@ def make(self, key): rate_ax.set_title("foraging rate (bin size = 10 min)") distance_ax.set_ylabel("distance travelled (m)") threshold_ax.set_ylabel("threshold") - threshold_ax.set_ylim([threshold_ax.get_ylim()[0] - 100, threshold_ax.get_ylim()[1] + 100]) + threshold_ax.set_ylim( + [threshold_ax.get_ylim()[0] - 100, threshold_ax.get_ylim()[1] + 100] + ) ethogram_ax.set_xlabel("time (min)") analysis_plotting.set_ymargin(distance_ax, 0.2, 0.1) for ax in (rate_ax, distance_ax, pellet_ax, time_dist_ax, threshold_ax): @@ -262,7 +276,9 @@ def make(self, key): # ---- Save fig and insert ---- save_dir = _make_path(key) - fig_dict = _save_figs((fig,), ("summary_plot_png",), save_dir=save_dir, prefix=save_dir.name) + fig_dict = _save_figs( + (fig,), ("summary_plot_png",), save_dir=save_dir, prefix=save_dir.name + ) self.insert1({**key, **fig_dict}) @@ -427,7 +443,10 @@ class VisitDailySummaryPlot(dj.Computed): """ key_source = ( - Visit & analysis.VisitSummary & (VisitEnd & "visit_duration > 24") & "experiment_name= 'exp0.2-r0'" + Visit + & analysis.VisitSummary + & (VisitEnd & "visit_duration > 24") + & "experiment_name= 'exp0.2-r0'" ) def make(self, key): @@ -534,7 +553,12 @@ def _make_path(in_arena_key): experiment_name, subject, in_arena_start = (analysis.InArena & in_arena_key).fetch1( "experiment_name", "subject", "in_arena_start" ) - output_dir = store_stage / experiment_name / subject / in_arena_start.strftime("%y%m%d_%H%M%S_%f") + output_dir = ( + store_stage + / experiment_name + / subject + / in_arena_start.strftime("%y%m%d_%H%M%S_%f") + ) output_dir.mkdir(parents=True, exist_ok=True) return output_dir diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index 20ad8bef..4b79ce72 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -1,3 +1,5 @@ +"""DataJoint schema for animal subjects.""" + import json import os import time @@ -83,7 +85,9 @@ def make(self, key): ) return elif len(animal_resp) > 1: - raise ValueError(f"Found {len(animal_resp)} with eartag {eartag_or_id}, expect one") + raise ValueError( + f"Found {len(animal_resp)} with eartag {eartag_or_id}, expect one" + ) else: animal_resp = animal_resp[0] @@ -96,7 +100,10 @@ def make(self, key): } ) Strain.insert1( - {"strain_id": animal_resp["strain_id"], "strain_name": animal_resp["strain_id"]}, + { + "strain_id": animal_resp["strain_id"], + "strain_name": animal_resp["strain_id"], + }, skip_duplicates=True, ) entry = { @@ -108,7 +115,10 @@ def make(self, key): } if animal_resp["gen_bg_id"] is not None: GeneticBackground.insert1( - {"gen_bg_id": animal_resp["gen_bg_id"], "gen_bg": animal_resp["gen_bg"]}, + { + "gen_bg_id": animal_resp["gen_bg_id"], + "gen_bg": animal_resp["gen_bg"], + }, skip_duplicates=True, ) entry["gen_bg_id"] = animal_resp["gen_bg_id"] @@ -175,17 +185,21 @@ class SubjectReferenceWeight(dj.Manual): def get_reference_weight(cls, subject_name): subj_key = {"subject": subject_name} - food_restrict_query = SubjectProcedure & subj_key & "procedure_name = 'R02 - food restriction'" + food_restrict_query = ( + SubjectProcedure & subj_key & "procedure_name = 'R02 - food restriction'" + ) if food_restrict_query: - ref_date = food_restrict_query.fetch("procedure_date", order_by="procedure_date DESC", limit=1)[ - 0 - ] + ref_date = food_restrict_query.fetch( + "procedure_date", order_by="procedure_date DESC", limit=1 + )[0] else: ref_date = datetime.now().date() weight_query = SubjectWeight & subj_key & f"weight_time < '{ref_date}'" ref_weight = ( - weight_query.fetch("weight", order_by="weight_time DESC", limit=1)[0] if weight_query else -1 + weight_query.fetch("weight", order_by="weight_time DESC", limit=1)[0] + if weight_query + else -1 ) entry = { @@ -242,7 +256,9 @@ def _auto_schedule(self): ): return - PyratIngestionTask.insert1({"pyrat_task_scheduled_time": next_task_schedule_time}) + PyratIngestionTask.insert1( + {"pyrat_task_scheduled_time": next_task_schedule_time} + ) def make(self, key): execution_time = datetime.utcnow() @@ -250,11 +266,15 @@ def make(self, key): new_eartags = [] for responsible_id in lab.User.fetch("responsible_id"): # 1 - retrieve all animals from this user - animal_resp = get_pyrat_data(endpoint="animals", params={"responsible_id": responsible_id}) + animal_resp = get_pyrat_data( + endpoint="animals", params={"responsible_id": responsible_id} + ) for animal_entry in animal_resp: # 2 - find animal with comment - Project Aeon eartag_or_id = animal_entry["eartag_or_id"] - comment_resp = get_pyrat_data(endpoint=f"animals/{eartag_or_id}/comments") + comment_resp = get_pyrat_data( + endpoint=f"animals/{eartag_or_id}/comments" + ) for comment in comment_resp: if comment["attributes"]: first_attr = comment["attributes"][0] @@ -283,7 +303,9 @@ def make(self, key): { **key, "execution_time": execution_time, - "execution_duration": (completion_time - execution_time).total_seconds(), + "execution_duration": ( + completion_time - execution_time + ).total_seconds(), "new_pyrat_entry_count": new_entry_count, } ) @@ -328,7 +350,9 @@ def make(self, key): for cmt in comment_resp: cmt["subject"] = eartag_or_id cmt["attributes"] = json.dumps(cmt["attributes"], default=str) - SubjectComment.insert(comment_resp, skip_duplicates=True, allow_direct_insert=True) + SubjectComment.insert( + comment_resp, skip_duplicates=True, allow_direct_insert=True + ) weight_resp = get_pyrat_data(endpoint=f"animals/{eartag_or_id}/weights") SubjectWeight.insert( @@ -337,7 +361,9 @@ def make(self, key): allow_direct_insert=True, ) - procedure_resp = get_pyrat_data(endpoint=f"animals/{eartag_or_id}/procedures") + procedure_resp = get_pyrat_data( + endpoint=f"animals/{eartag_or_id}/procedures" + ) SubjectProcedure.insert( [{**v, "subject": eartag_or_id} for v in procedure_resp], skip_duplicates=True, @@ -352,7 +378,9 @@ def make(self, key): { **key, "execution_time": execution_time, - "execution_duration": (completion_time - execution_time).total_seconds(), + "execution_duration": ( + completion_time - execution_time + ).total_seconds(), } ) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 01b0a039..842d99f2 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -1,3 +1,5 @@ +"""DataJoint schema for tracking data.""" + from pathlib import Path import datajoint as dj @@ -5,7 +7,14 @@ import numpy as np import pandas as pd -from aeon.dj_pipeline import acquisition, dict_to_uuid, get_schema_name, lab, qc, streams +from aeon.dj_pipeline import ( + acquisition, + dict_to_uuid, + get_schema_name, + lab, + qc, + streams, +) from aeon.io import api as io_api from aeon.schema import schemas as aeon_schemas @@ -72,14 +81,18 @@ def insert_new_params( tracking_paramset_id: int = None, ): if tracking_paramset_id is None: - tracking_paramset_id = (dj.U().aggr(cls, n="max(tracking_paramset_id)").fetch1("n") or 0) + 1 + tracking_paramset_id = ( + dj.U().aggr(cls, n="max(tracking_paramset_id)").fetch1("n") or 0 + ) + 1 param_dict = { "tracking_method": tracking_method, "tracking_paramset_id": tracking_paramset_id, "paramset_description": paramset_description, "params": params, - "param_set_hash": dict_to_uuid({**params, "tracking_method": tracking_method}), + "param_set_hash": dict_to_uuid( + {**params, "tracking_method": tracking_method} + ), } param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} @@ -141,7 +154,9 @@ def key_source(self): return ( acquisition.Chunk * ( - streams.SpinnakerVideoSource.join(streams.SpinnakerVideoSource.RemovalTime, left=True) + streams.SpinnakerVideoSource.join( + streams.SpinnakerVideoSource.RemovalTime, left=True + ) & "spinnaker_video_source_name='CameraTop'" ) * (TrackingParamSet & "tracking_paramset_id = 1") @@ -150,17 +165,22 @@ def key_source(self): ) # SLEAP & CameraTop def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) - device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name") + device_name = (streams.SpinnakerVideoSource & key).fetch1( + "spinnaker_video_source_name" + ) devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(devices_schema, device_name).Pose @@ -172,7 +192,9 @@ def make(self, key): ) if not len(pose_data): - raise ValueError(f"No SLEAP data found for {key['experiment_name']} - {device_name}") + raise ValueError( + f"No SLEAP data found for {key['experiment_name']} - {device_name}" + ) # get identity names class_names = np.unique(pose_data.identity) @@ -205,7 +227,9 @@ def make(self, key): if part == anchor_part: identity_likelihood = part_position.identity_likelihood.values if isinstance(identity_likelihood[0], dict): - identity_likelihood = np.array([v[identity] for v in identity_likelihood]) + identity_likelihood = np.array( + [v[identity] for v in identity_likelihood] + ) pose_identity_entries.append( { @@ -247,7 +271,9 @@ def is_position_in_nest(position_df, nest_key, xcol="x", ycol="y") -> pd.Series: """Given the session key and the position data - arrays of x and y return an array of boolean indicating whether or not a position is inside the nest. """ - nest_vertices = list(zip(*(lab.ArenaNest.Vertex & nest_key).fetch("vertex_x", "vertex_y"))) + nest_vertices = list( + zip(*(lab.ArenaNest.Vertex & nest_key).fetch("vertex_x", "vertex_y")) + ) nest_path = matplotlib.path.Path(nest_vertices) position_df["in_nest"] = nest_path.contains_points(position_df[[xcol, ycol]]) return position_df["in_nest"] @@ -273,7 +299,9 @@ def _get_position( start_query = table & obj_restriction & start_restriction end_query = table & obj_restriction & end_restriction if not (start_query and end_query): - raise ValueError(f"No position data found for {object_name} between {start} and {end}") + raise ValueError( + f"No position data found for {object_name} between {start} and {end}" + ) time_restriction = ( f'{start_attr} >= "{min(start_query.fetch(start_attr))}"' @@ -281,10 +309,14 @@ def _get_position( ) # subject's position data in the time slice - fetched_data = (table & obj_restriction & time_restriction).fetch(*fetch_attrs, order_by=start_attr) + fetched_data = (table & obj_restriction & time_restriction).fetch( + *fetch_attrs, order_by=start_attr + ) if not len(fetched_data[0]): - raise ValueError(f"No position data found for {object_name} between {start} and {end}") + raise ValueError( + f"No position data found for {object_name} between {start} and {end}" + ) timestamp_attr = next(attr for attr in fetch_attrs if "timestamps" in attr) diff --git a/aeon/dj_pipeline/utils/load_metadata.py b/aeon/dj_pipeline/utils/load_metadata.py index ce1c2775..8468a8bd 100644 --- a/aeon/dj_pipeline/utils/load_metadata.py +++ b/aeon/dj_pipeline/utils/load_metadata.py @@ -1,3 +1,7 @@ +""" +Load metadata from the experiment and insert into streams schema. +""" + import datetime import inspect import json @@ -38,7 +42,9 @@ def insert_stream_types(): else: # If the existed stream type does not have the same name: # human error, trying to add the same content with different name - raise dj.DataJointError(f"The specified stream type already exists - name: {pname}") + raise dj.DataJointError( + f"The specified stream type already exists - name: {pname}" + ) else: streams.StreamType.insert1(entry) @@ -51,7 +57,9 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): streams = dj.VirtualModule("streams", streams_maker.schema_name) device_info: dict[dict] = get_device_info(devices_schema) - device_type_mapper, device_sn = get_device_mapper(devices_schema, metadata_yml_filepath) + device_type_mapper, device_sn = get_device_mapper( + devices_schema, metadata_yml_filepath + ) # Add device type to device_info. Only add if device types that are defined in Metadata.yml device_info = { @@ -88,7 +96,8 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): {"device_type": device_type, "stream_type": stream_type} for device_type, stream_list in device_stream_map.items() for stream_type in stream_list - if not streams.DeviceType.Stream & {"device_type": device_type, "stream_type": stream_type} + if not streams.DeviceType.Stream + & {"device_type": device_type, "stream_type": stream_type} ] new_devices = [ @@ -97,7 +106,8 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): "device_type": device_config["device_type"], } for device_name, device_config in device_info.items() - if device_sn[device_name] and not streams.Device & {"device_serial_number": device_sn[device_name]} + if device_sn[device_name] + and not streams.Device & {"device_serial_number": device_sn[device_name]} ] # Insert new entries. @@ -115,7 +125,9 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): streams.Device.insert(new_devices) -def extract_epoch_config(experiment_name: str, devices_schema: DotMap, metadata_yml_filepath: str) -> dict: +def extract_epoch_config( + experiment_name: str, devices_schema: DotMap, metadata_yml_filepath: str +) -> dict: """Parse experiment metadata YAML file and extract epoch configuration. Args: @@ -127,7 +139,9 @@ def extract_epoch_config(experiment_name: str, devices_schema: DotMap, metadata_ dict: epoch_config [dict] """ metadata_yml_filepath = pathlib.Path(metadata_yml_filepath) - epoch_start = datetime.datetime.strptime(metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S") + epoch_start = datetime.datetime.strptime( + metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S" + ) epoch_config: dict = ( io_api.load( metadata_yml_filepath.parent.as_posix(), @@ -144,12 +158,16 @@ def extract_epoch_config(experiment_name: str, devices_schema: DotMap, metadata_ assert commit, f'Neither "Commit" nor "Revision" found in {metadata_yml_filepath}' devices: list[dict] = json.loads( - json.dumps(epoch_config["metadata"]["Devices"], default=lambda x: x.__dict__, indent=4) + json.dumps( + epoch_config["metadata"]["Devices"], default=lambda x: x.__dict__, indent=4 + ) ) # Maintain backward compatibility - In exp02, it is a list of dict. From presocial onward, it's a dict of dict. if isinstance(devices, list): - devices: dict = {d.pop("Name"): d for d in devices} # {deivce_name: device_config} + devices: dict = { + d.pop("Name"): d for d in devices + } # {deivce_name: device_config} return { "experiment_name": experiment_name, @@ -173,15 +191,17 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath experiment_key = {"experiment_name": experiment_name} metadata_yml_filepath = pathlib.Path(metadata_yml_filepath) - epoch_config = extract_epoch_config(experiment_name, devices_schema, metadata_yml_filepath) + epoch_config = extract_epoch_config( + experiment_name, devices_schema, metadata_yml_filepath + ) previous_epoch = (acquisition.Experiment & experiment_key).aggr( acquisition.Epoch & f'epoch_start < "{epoch_config["epoch_start"]}"', epoch_start="MAX(epoch_start)", ) - if len(acquisition.EpochConfig.Meta & previous_epoch) and epoch_config["commit"] == ( - acquisition.EpochConfig.Meta & previous_epoch - ).fetch1("commit"): + if len(acquisition.EpochConfig.Meta & previous_epoch) and epoch_config[ + "commit" + ] == (acquisition.EpochConfig.Meta & previous_epoch).fetch1("commit"): # if identical commit -> no changes return @@ -213,7 +233,9 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath table_entry = { "experiment_name": experiment_name, **device_key, - f"{dj.utils.from_camel_case(table.__name__)}_install_time": epoch_config["epoch_start"], + f"{dj.utils.from_camel_case(table.__name__)}_install_time": epoch_config[ + "epoch_start" + ], f"{dj.utils.from_camel_case(table.__name__)}_name": device_name, } @@ -230,15 +252,21 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath { **table_entry, "attribute_name": "SamplingFrequency", - "attribute_value": video_controller[device_config["TriggerFrequency"]], + "attribute_value": video_controller[ + device_config["TriggerFrequency"] + ], } ) """Check if this device is currently installed. If the same device serial number is currently installed check for any changes in configuration. If not, skip this""" - current_device_query = table - table.RemovalTime & experiment_key & device_key + current_device_query = ( + table - table.RemovalTime & experiment_key & device_key + ) if current_device_query: - current_device_config: list[dict] = (table.Attribute & current_device_query).fetch( + current_device_config: list[dict] = ( + table.Attribute & current_device_query + ).fetch( "experiment_name", "device_serial_number", "attribute_name", @@ -246,7 +274,11 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath as_dict=True, ) new_device_config: list[dict] = [ - {k: v for k, v in entry.items() if dj.utils.from_camel_case(table.__name__) not in k} + { + k: v + for k, v in entry.items() + if dj.utils.from_camel_case(table.__name__) not in k + } for entry in table_attribute_entry ] @@ -256,7 +288,10 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath for config in current_device_config } ) == dict_to_uuid( - {config["attribute_name"]: config["attribute_value"] for config in new_device_config} + { + config["attribute_name"]: config["attribute_value"] + for config in new_device_config + } ): # Skip if none of the configuration has changed. continue @@ -373,10 +408,14 @@ def _get_class_path(obj): "aeon.schema.social", ]: device_info[device_name]["stream_type"].append(stream_type) - device_info[device_name]["stream_reader"].append(_get_class_path(stream_obj)) + device_info[device_name]["stream_reader"].append( + _get_class_path(stream_obj) + ) required_args = [ - k for k in inspect.signature(stream_obj.__init__).parameters if k != "self" + k + for k in inspect.signature(stream_obj.__init__).parameters + if k != "self" ] pattern = schema_dict[device_name][stream_type].get("pattern") schema_dict[device_name][stream_type]["pattern"] = pattern.replace( @@ -384,23 +423,35 @@ def _get_class_path(obj): ) kwargs = { - k: v for k, v in schema_dict[device_name][stream_type].items() if k in required_args + k: v + for k, v in schema_dict[device_name][stream_type].items() + if k in required_args } device_info[device_name]["stream_reader_kwargs"].append(kwargs) # Add hash device_info[device_name]["stream_hash"].append( - dict_to_uuid({**kwargs, "stream_reader": _get_class_path(stream_obj)}) + dict_to_uuid( + {**kwargs, "stream_reader": _get_class_path(stream_obj)} + ) ) else: stream_type = device.__class__.__name__ device_info[device_name]["stream_type"].append(stream_type) device_info[device_name]["stream_reader"].append(_get_class_path(device)) - required_args = {k: None for k in inspect.signature(device.__init__).parameters if k != "self"} + required_args = { + k: None + for k in inspect.signature(device.__init__).parameters + if k != "self" + } pattern = schema_dict[device_name].get("pattern") - schema_dict[device_name]["pattern"] = pattern.replace(device_name, "{pattern}") + schema_dict[device_name]["pattern"] = pattern.replace( + device_name, "{pattern}" + ) - kwargs = {k: v for k, v in schema_dict[device_name].items() if k in required_args} + kwargs = { + k: v for k, v in schema_dict[device_name].items() if k in required_args + } device_info[device_name]["stream_reader_kwargs"].append(kwargs) # Add hash device_info[device_name]["stream_hash"].append( @@ -490,7 +541,9 @@ def ingest_epoch_metadata_octagon(experiment_name, metadata_yml_filepath): ("Wall8", "Wall"), ] - epoch_start = datetime.datetime.strptime(metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S") + epoch_start = datetime.datetime.strptime( + metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S" + ) for device_idx, (device_name, device_type) in enumerate(oct01_devices): device_sn = f"oct01_{device_idx}" @@ -499,8 +552,13 @@ def ingest_epoch_metadata_octagon(experiment_name, metadata_yml_filepath): skip_duplicates=True, ) experiment_table = getattr(streams, f"Experiment{device_type}") - if not (experiment_table & {"experiment_name": experiment_name, "device_serial_number": device_sn}): - experiment_table.insert1((experiment_name, device_sn, epoch_start, device_name)) + if not ( + experiment_table + & {"experiment_name": experiment_name, "device_serial_number": device_sn} + ): + experiment_table.insert1( + (experiment_name, device_sn, epoch_start, device_name) + ) # endregion diff --git a/aeon/dj_pipeline/utils/paths.py b/aeon/dj_pipeline/utils/paths.py index 63a13a1f..1bdc7b7f 100644 --- a/aeon/dj_pipeline/utils/paths.py +++ b/aeon/dj_pipeline/utils/paths.py @@ -1,3 +1,7 @@ +""" +Utility functions for working with paths in the context of the DJ pipeline. +""" + from __future__ import annotations import pathlib @@ -63,5 +67,6 @@ def find_root_directory( except StopIteration: raise FileNotFoundError( - f"No valid root directory found (from {root_directories})" f" for {full_path}" + f"No valid root directory found (from {root_directories})" + f" for {full_path}" ) diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index 01cd14e7..0e3b29cb 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -1,3 +1,5 @@ +"""Utility functions for plotting visit data.""" + import datajoint as dj import numpy as np import pandas as pd @@ -34,13 +36,17 @@ def plot_reward_rate_differences(subject_keys): """ subj_names, sess_starts, rate_timestamps, rate_diffs = ( analysis.InArenaRewardRate & subject_keys - ).fetch("subject", "in_arena_start", "pellet_rate_timestamps", "patch2_patch1_rate_diff") + ).fetch( + "subject", "in_arena_start", "pellet_rate_timestamps", "patch2_patch1_rate_diff" + ) nSessions = len(sess_starts) longest_rateDiff = np.max([len(t) for t in rate_timestamps]) max_session_idx = np.argmax([len(t) for t in rate_timestamps]) - max_session_elapsed_times = rate_timestamps[max_session_idx] - rate_timestamps[max_session_idx][0] + max_session_elapsed_times = ( + rate_timestamps[max_session_idx] - rate_timestamps[max_session_idx][0] + ) x_labels = [t.total_seconds() / 60 for t in max_session_elapsed_times] y_labels = [ @@ -85,12 +91,15 @@ def plot_wheel_travelled_distance(session_keys): ``` """ distance_travelled_query = ( - analysis.InArenaSummary.FoodPatch * acquisition.ExperimentFoodPatch.proj("food_patch_description") + analysis.InArenaSummary.FoodPatch + * acquisition.ExperimentFoodPatch.proj("food_patch_description") & session_keys ) distance_travelled_df = ( - distance_travelled_query.proj("food_patch_description", "wheel_distance_travelled") + distance_travelled_query.proj( + "food_patch_description", "wheel_distance_travelled" + ) .fetch(format="frame") .reset_index() ) @@ -151,7 +160,8 @@ def plot_average_time_distribution(session_keys): & session_keys ) .aggr( - analysis.InArenaTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch, + analysis.InArenaTimeDistribution.FoodPatch + * acquisition.ExperimentFoodPatch, avg_in_patch="AVG(time_fraction_in_patch)", ) .fetch("subject", "food_patch_description", "avg_in_patch") @@ -229,11 +239,15 @@ def plot_visit_daily_summary( .reset_index() ) else: - visit_per_day_df = (VisitSummary & visit_key).fetch(format="frame").reset_index() + visit_per_day_df = ( + (VisitSummary & visit_key).fetch(format="frame").reset_index() + ) if not attr.startswith("total"): attr = "total_" + attr - visit_per_day_df["day"] = visit_per_day_df["visit_date"] - visit_per_day_df["visit_date"].min() + visit_per_day_df["day"] = ( + visit_per_day_df["visit_date"] - visit_per_day_df["visit_date"].min() + ) visit_per_day_df["day"] = visit_per_day_df["day"].dt.days fig = px.bar( @@ -324,10 +338,14 @@ def plot_foraging_bouts_count( else [foraging_bouts["bout_start"].dt.floor("D")] ) - foraging_bouts_count = foraging_bouts.groupby(group_by_attrs).size().reset_index(name="count") + foraging_bouts_count = ( + foraging_bouts.groupby(group_by_attrs).size().reset_index(name="count") + ) visit_start = (VisitEnd & visit_key).fetch1("visit_start") - foraging_bouts_count["day"] = (foraging_bouts_count["bout_start"].dt.date - visit_start.date()).dt.days + foraging_bouts_count["day"] = ( + foraging_bouts_count["bout_start"].dt.date - visit_start.date() + ).dt.days fig = px.bar( foraging_bouts_count, @@ -341,7 +359,10 @@ def plot_foraging_bouts_count( width=700, height=400, template="simple_white", - title=visit_key["subject"] + "
Foraging bouts: count (freq='" + freq + "')", + title=visit_key["subject"] + + "
Foraging bouts: count (freq='" + + freq + + "')", ) fig.update_layout( @@ -413,7 +434,9 @@ def plot_foraging_bouts_distribution( fig = go.Figure() if per_food_patch: - patch_names = (acquisition.ExperimentFoodPatch & visit_key).fetch("food_patch_description") + patch_names = (acquisition.ExperimentFoodPatch & visit_key).fetch( + "food_patch_description" + ) for patch in patch_names: bouts = foraging_bouts[foraging_bouts["food_patch_description"] == patch] fig.add_trace( @@ -440,7 +463,9 @@ def plot_foraging_bouts_distribution( ) fig.update_layout( - title_text=visit_key["subject"] + "
Foraging bouts: " + attr.replace("_", " "), + title_text=visit_key["subject"] + + "
Foraging bouts: " + + attr.replace("_", " "), xaxis_title="date", yaxis_title=attr.replace("_", " "), violingap=0, @@ -449,7 +474,13 @@ def plot_foraging_bouts_distribution( width=700, height=400, template="simple_white", - legend={"orientation": "h", "yanchor": "bottom", "y": 1, "xanchor": "right", "x": 1}, + legend={ + "orientation": "h", + "yanchor": "bottom", + "y": 1, + "xanchor": "right", + "x": 1, + }, ) return fig @@ -472,11 +503,17 @@ def plot_visit_time_distribution(visit_key, freq="D"): region = _get_region_data(visit_key) # Compute time spent per region - time_spent = region.groupby([region.index.floor(freq), "region"]).size().reset_index(name="count") - time_spent["time_fraction"] = time_spent["count"] / time_spent.groupby("timestamps")["count"].transform( - "sum" + time_spent = ( + region.groupby([region.index.floor(freq), "region"]) + .size() + .reset_index(name="count") ) - time_spent["day"] = (time_spent["timestamps"] - time_spent["timestamps"].min()).dt.days + time_spent["time_fraction"] = time_spent["count"] / time_spent.groupby( + "timestamps" + )["count"].transform("sum") + time_spent["day"] = ( + time_spent["timestamps"] - time_spent["timestamps"].min() + ).dt.days fig = px.bar( time_spent, @@ -488,7 +525,10 @@ def plot_visit_time_distribution(visit_key, freq="D"): "time_fraction": "time fraction", "timestamps": "date" if freq == "D" else "time", }, - title=visit_key["subject"] + "
Fraction of time spent in each region (freq='" + freq + "')", + title=visit_key["subject"] + + "
Fraction of time spent in each region (freq='" + + freq + + "')", width=700, height=400, template="simple_white", @@ -511,7 +551,9 @@ def plot_visit_time_distribution(visit_key, freq="D"): return fig -def _get_region_data(visit_key, attrs=["in_nest", "in_arena", "in_corridor", "in_patch"]): +def _get_region_data( + visit_key, attrs=["in_nest", "in_arena", "in_corridor", "in_patch"] +): """Retrieve region data from VisitTimeDistribution tables. Args: @@ -528,7 +570,9 @@ def _get_region_data(visit_key, attrs=["in_nest", "in_arena", "in_corridor", "in for attr in attrs: if attr == "in_nest": # Nest in_nest = np.concatenate( - (VisitTimeDistribution.Nest & visit_key).fetch(attr, order_by="visit_date") + (VisitTimeDistribution.Nest & visit_key).fetch( + attr, order_by="visit_date" + ) ) region = pd.concat( [ @@ -543,14 +587,16 @@ def _get_region_data(visit_key, attrs=["in_nest", "in_arena", "in_corridor", "in elif attr == "in_patch": # Food patch # Find all patches patches = np.unique( - (VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & visit_key).fetch( - "food_patch_description" - ) + ( + VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch + & visit_key + ).fetch("food_patch_description") ) for patch in patches: in_patch = np.concatenate( ( - VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch + VisitTimeDistribution.FoodPatch + * acquisition.ExperimentFoodPatch & visit_key & f"food_patch_description = '{patch}'" ).fetch("in_patch", order_by="visit_date") @@ -582,13 +628,19 @@ def _get_region_data(visit_key, attrs=["in_nest", "in_arena", "in_corridor", "in region = region.sort_index().rename_axis("timestamps") # Exclude data during maintenance - maintenance_period = get_maintenance_periods(visit_key["experiment_name"], visit_start, visit_end) - region = filter_out_maintenance_periods(region, maintenance_period, visit_end, dropna=True) + maintenance_period = get_maintenance_periods( + visit_key["experiment_name"], visit_start, visit_end + ) + region = filter_out_maintenance_periods( + region, maintenance_period, visit_end, dropna=True + ) return region -def plot_weight_patch_data(visit_key, freq="H", smooth_weight=True, min_weight=0, max_weight=35): +def plot_weight_patch_data( + visit_key, freq="H", smooth_weight=True, min_weight=0, max_weight=35 +): """Plot subject weight and patch data (pellet trigger count) per visit. Args: @@ -605,7 +657,9 @@ def plot_weight_patch_data(visit_key, freq="H", smooth_weight=True, min_weight=0 >>> fig = plot_weight_patch_data(visit_key, freq="H", smooth_weight=True) >>> fig = plot_weight_patch_data(visit_key, freq="D") """ - subject_weight = _get_filtered_subject_weight(visit_key, smooth_weight, min_weight, max_weight) + subject_weight = _get_filtered_subject_weight( + visit_key, smooth_weight, min_weight, max_weight + ) # Count pellet trigger per patch per day/hour/... patch = _get_patch_data(visit_key) @@ -633,8 +687,12 @@ def plot_weight_patch_data(visit_key, freq="H", smooth_weight=True, min_weight=0 for p in patch_names: fig.add_trace( go.Bar( - x=patch_summary[patch_summary["food_patch_description"] == p]["event_time"], - y=patch_summary[patch_summary["food_patch_description"] == p]["event_type"], + x=patch_summary[patch_summary["food_patch_description"] == p][ + "event_time" + ], + y=patch_summary[patch_summary["food_patch_description"] == p][ + "event_type" + ], name=p, ), secondary_y=False, @@ -659,7 +717,10 @@ def plot_weight_patch_data(visit_key, freq="H", smooth_weight=True, min_weight=0 fig.update_layout( barmode="stack", hovermode="x", - title_text=visit_key["subject"] + "
Weight and pellet count (freq='" + freq + "')", + title_text=visit_key["subject"] + + "
Weight and pellet count (freq='" + + freq + + "')", xaxis_title="date" if freq == "D" else "time", yaxis={"title": "pellet count"}, yaxis2={"title": "weight"}, @@ -680,7 +741,9 @@ def plot_weight_patch_data(visit_key, freq="H", smooth_weight=True, min_weight=0 return fig -def _get_filtered_subject_weight(visit_key, smooth_weight=True, min_weight=0, max_weight=35): +def _get_filtered_subject_weight( + visit_key, smooth_weight=True, min_weight=0, max_weight=35 +): """Retrieve subject weight from WeightMeasurementFiltered table. Args: @@ -719,7 +782,9 @@ def _get_filtered_subject_weight(visit_key, smooth_weight=True, min_weight=0, ma subject_weight = subject_weight.loc[visit_start:visit_end] # Exclude data during maintenance - maintenance_period = get_maintenance_periods(visit_key["experiment_name"], visit_start, visit_end) + maintenance_period = get_maintenance_periods( + visit_key["experiment_name"], visit_start, visit_end + ) subject_weight = filter_out_maintenance_periods( subject_weight, maintenance_period, visit_end, dropna=True ) @@ -736,7 +801,9 @@ def _get_filtered_subject_weight(visit_key, smooth_weight=True, min_weight=0, ma subject_weight = subject_weight.resample("1T").mean().dropna() if smooth_weight: - subject_weight["weight_subject"] = savgol_filter(subject_weight["weight_subject"], 10, 3) + subject_weight["weight_subject"] = savgol_filter( + subject_weight["weight_subject"], 10, 3 + ) return subject_weight @@ -757,7 +824,9 @@ def _get_patch_data(visit_key): ( dj.U("event_time", "event_type", "food_patch_description") & ( - acquisition.FoodPatchEvent * acquisition.EventType * acquisition.ExperimentFoodPatch + acquisition.FoodPatchEvent + * acquisition.EventType + * acquisition.ExperimentFoodPatch & f'event_time BETWEEN "{visit_start}" AND "{visit_end}"' & 'event_type = "TriggerPellet"' ) @@ -770,7 +839,11 @@ def _get_patch_data(visit_key): # TODO: handle repeat attempts (pellet delivery trigger and beam break) # Exclude data during maintenance - maintenance_period = get_maintenance_periods(visit_key["experiment_name"], visit_start, visit_end) - patch = filter_out_maintenance_periods(patch, maintenance_period, visit_end, dropna=True) + maintenance_period = get_maintenance_periods( + visit_key["experiment_name"], visit_start, visit_end + ) + patch = filter_out_maintenance_periods( + patch, maintenance_period, visit_end, dropna=True + ) return patch diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index 78e5ebaf..2d85c500 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -1,3 +1,5 @@ +"""Module for stream-related tables in the analysis schema.""" + import importlib import inspect import re @@ -103,14 +105,19 @@ def get_device_stream_template(device_type: str, stream_type: str, streams_modul # DeviceDataStream table(s) stream_detail = ( streams_module.StreamType - & (streams_module.DeviceType.Stream & {"device_type": device_type, "stream_type": stream_type}) + & ( + streams_module.DeviceType.Stream + & {"device_type": device_type, "stream_type": stream_type} + ) ).fetch1() for i, n in enumerate(stream_detail["stream_reader"].split(".")): reader = aeon if i == 0 else getattr(reader, n) if reader is aeon.io.reader.Pose: - logger.warning("Automatic generation of stream table for Pose reader is not supported. Skipping...") + logger.warning( + "Automatic generation of stream table for Pose reader is not supported. Skipping..." + ) return None, None stream = reader(**stream_detail["stream_reader_kwargs"]) @@ -140,24 +147,32 @@ def key_source(self): + Chunk(s) that started after {device_type} install time for {device_type} that are not yet removed """ return ( - acquisition.Chunk * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) + acquisition.Chunk + * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) & f"chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time" & f'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, "2200-01-01")' ) def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) - device_name = (ExperimentDevice & key).fetch1(f"{dj.utils.from_camel_case(device_type)}_name") + device_name = (ExperimentDevice & key).fetch1( + f"{dj.utils.from_camel_case(device_type)}_name" + ) devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), + ) + stream_reader = getattr( + getattr(devices_schema, device_name), "{stream_type}" ) - stream_reader = getattr(getattr(devices_schema, device_name), "{stream_type}") stream_data = io_api.load( root=data_dirs, diff --git a/aeon/dj_pipeline/utils/video.py b/aeon/dj_pipeline/utils/video.py index 63b64f24..16a2f27f 100644 --- a/aeon/dj_pipeline/utils/video.py +++ b/aeon/dj_pipeline/utils/video.py @@ -1,3 +1,5 @@ +"""Utility functions for video processing.""" + import base64 from pathlib import Path diff --git a/aeon/io/api.py b/aeon/io/api.py index 5d505ea6..5fe532c3 100644 --- a/aeon/io/api.py +++ b/aeon/io/api.py @@ -1,3 +1,5 @@ +"""API for reading Aeon data from disk.""" + import bisect import datetime from os import PathLike @@ -25,7 +27,9 @@ def chunk(time): return pd.to_datetime(time.dt.date) + pd.to_timedelta(hour, "h") else: hour = CHUNK_DURATION * (time.hour // CHUNK_DURATION) - return pd.to_datetime(datetime.datetime.combine(time.date(), datetime.time(hour=hour))) + return pd.to_datetime( + datetime.datetime.combine(time.date(), datetime.time(hour=hour)) + ) def chunk_range(start, end): @@ -35,7 +39,9 @@ def chunk_range(start, end): :param datetime end: The right bound of the time range. :return: A DatetimeIndex representing the acquisition chunk range. """ - return pd.date_range(chunk(start), chunk(end), freq=pd.DateOffset(hours=CHUNK_DURATION)) + return pd.date_range( + chunk(start), chunk(end), freq=pd.DateOffset(hours=CHUNK_DURATION) + ) def chunk_key(file): @@ -47,7 +53,9 @@ def chunk_key(file): except ValueError: epoch = file.parts[-2] date_str, time_str = epoch.split("T") - return epoch, datetime.datetime.fromisoformat(date_str + "T" + time_str.replace("-", ":")) + return epoch, datetime.datetime.fromisoformat( + date_str + "T" + time_str.replace("-", ":") + ) def _set_index(data): @@ -60,7 +68,9 @@ def _empty(columns): return pd.DataFrame(columns=columns, index=pd.DatetimeIndex([], name="time")) -def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=None, **kwargs): +def load( + root, reader, start=None, end=None, time=None, tolerance=None, epoch=None, **kwargs +): """Extracts chunk data from the root path of an Aeon dataset. Reads all chunk data using the specified data stream reader. A subset of the data can be loaded @@ -87,7 +97,9 @@ def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=No fileset = { chunk_key(fname): fname for path in root - for fname in Path(path).glob(f"{epoch_pattern}/**/{reader.pattern}.{reader.extension}") + for fname in Path(path).glob( + f"{epoch_pattern}/**/{reader.pattern}.{reader.extension}" + ) } files = sorted(fileset.items()) @@ -132,7 +144,9 @@ def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=No if start is not None or end is not None: chunk_start = chunk(start) if start is not None else pd.Timestamp.min chunk_end = chunk(end) if end is not None else pd.Timestamp.max - files = list(filter(lambda item: chunk_start <= chunk(item[0][1]) <= chunk_end, files)) + files = list( + filter(lambda item: chunk_start <= chunk(item[0][1]) <= chunk_end, files) + ) if len(files) == 0: return _empty(reader.columns) @@ -147,11 +161,15 @@ def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=No if not data.index.has_duplicates: warnings.warn( - f"data index for {reader.pattern} contains out-of-order timestamps!", stacklevel=2 + f"data index for {reader.pattern} contains out-of-order timestamps!", + stacklevel=2, ) data = data.sort_index() else: - warnings.warn(f"data index for {reader.pattern} contains duplicate keys!", stacklevel=2) + warnings.warn( + f"data index for {reader.pattern} contains duplicate keys!", + stacklevel=2, + ) data = data[~data.index.duplicated(keep="first")] return data.loc[start:end] return data diff --git a/aeon/io/device.py b/aeon/io/device.py index d7707fb0..62e8b3e1 100644 --- a/aeon/io/device.py +++ b/aeon/io/device.py @@ -1,3 +1,5 @@ +""" Deprecated Device class for grouping multiple Readers into a logical device. """ + import inspect from typing_extensions import deprecated diff --git a/aeon/io/reader.py b/aeon/io/reader.py index bf6e8c23..a4a0738a 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -1,3 +1,5 @@ +""" Module for reading data from raw files in an Aeon dataset.""" + from __future__ import annotations import datetime @@ -66,15 +68,25 @@ def read(self, file): payloadtype = _payloadtypes[data[4] & ~0x10] elementsize = payloadtype.itemsize payloadshape = (length, payloadsize // elementsize) - seconds = np.ndarray(length, dtype=np.uint32, buffer=data, offset=5, strides=stride) - ticks = np.ndarray(length, dtype=np.uint16, buffer=data, offset=9, strides=stride) + seconds = np.ndarray( + length, dtype=np.uint32, buffer=data, offset=5, strides=stride + ) + ticks = np.ndarray( + length, dtype=np.uint16, buffer=data, offset=9, strides=stride + ) seconds = ticks * _SECONDS_PER_TICK + seconds payload = np.ndarray( - payloadshape, dtype=payloadtype, buffer=data, offset=11, strides=(stride, elementsize) + payloadshape, + dtype=payloadtype, + buffer=data, + offset=11, + strides=(stride, elementsize), ) if self.columns is not None and payloadshape[1] < len(self.columns): - data = pd.DataFrame(payload, index=seconds, columns=self.columns[: payloadshape[1]]) + data = pd.DataFrame( + payload, index=seconds, columns=self.columns[: payloadshape[1]] + ) data[self.columns[payloadshape[1] :]] = math.nan return data else: @@ -101,13 +113,17 @@ class Metadata(Reader): """Extracts metadata information from all epochs in the dataset.""" def __init__(self, pattern="Metadata"): - super().__init__(pattern, columns=["workflow", "commit", "metadata"], extension="yml") + super().__init__( + pattern, columns=["workflow", "commit", "metadata"], extension="yml" + ) def read(self, file): """Returns metadata for the specified epoch.""" epoch_str = file.parts[-2] date_str, time_str = epoch_str.split("T") - time = datetime.datetime.fromisoformat(date_str + "T" + time_str.replace("-", ":")) + time = datetime.datetime.fromisoformat( + date_str + "T" + time_str.replace("-", ":") + ) with open(file) as fp: metadata = json.load(fp) workflow = metadata.pop("Workflow") @@ -242,7 +258,9 @@ class Position(Harp): """ def __init__(self, pattern): - super().__init__(pattern, columns=["x", "y", "angle", "major", "minor", "area", "id"]) + super().__init__( + pattern, columns=["x", "y", "angle", "major", "minor", "area", "id"] + ) class BitmaskEvent(Harp): @@ -298,7 +316,9 @@ class Video(Csv): """ def __init__(self, pattern): - super().__init__(pattern, columns=["hw_counter", "hw_timestamp", "_frame", "_path", "_epoch"]) + super().__init__( + pattern, columns=["hw_counter", "hw_timestamp", "_frame", "_path", "_epoch"] + ) self._rawcolumns = ["time"] + self.columns[0:2] def read(self, file): @@ -323,7 +343,9 @@ class (int): Int ID of a subject in the environment. y (float): Y-coordinate of the bodypart. """ - def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/processed"): + def __init__( + self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/processed" + ): """Pose reader constructor.""" # `pattern` for this reader should typically be '_*' super().__init__(pattern, columns=None) @@ -362,10 +384,16 @@ def read(self, file: Path) -> pd.DataFrame: # Drop any repeat parts. unique_parts, unique_idxs = np.unique(parts, return_index=True) repeat_idxs = np.setdiff1d(np.arange(len(parts)), unique_idxs) - if repeat_idxs: # drop x, y, and likelihood cols for repeat parts (skip first 5 cols) + if ( + repeat_idxs + ): # drop x, y, and likelihood cols for repeat parts (skip first 5 cols) init_rep_part_col_idx = (repeat_idxs - 1) * 3 + 5 - rep_part_col_idxs = np.concatenate([np.arange(i, i + 3) for i in init_rep_part_col_idx]) - keep_part_col_idxs = np.setdiff1d(np.arange(len(data.columns)), rep_part_col_idxs) + rep_part_col_idxs = np.concatenate( + [np.arange(i, i + 3) for i in init_rep_part_col_idx] + ) + keep_part_col_idxs = np.setdiff1d( + np.arange(len(data.columns)), rep_part_col_idxs + ) data = data.iloc[:, keep_part_col_idxs] parts = unique_parts @@ -373,22 +401,36 @@ def read(self, file: Path) -> pd.DataFrame: data = self.class_int2str(data, config_file) n_parts = len(parts) part_data_list = [pd.DataFrame()] * n_parts - new_columns = pd.Series(["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]) + new_columns = pd.Series( + ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"] + ) new_data = pd.DataFrame(columns=new_columns) for i, part in enumerate(parts): part_columns = ( - columns[0 : (len(identities) + 1)] if bonsai_sleap_v == BONSAI_SLEAP_V3 else columns[0:2] + columns[0 : (len(identities) + 1)] + if bonsai_sleap_v == BONSAI_SLEAP_V3 + else columns[0:2] ) part_columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"]) part_data = pd.DataFrame(data[part_columns]) if bonsai_sleap_v == BONSAI_SLEAP_V3: # combine all identity_likelihood cols into a single col as dict part_data["identity_likelihood"] = part_data.apply( - lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, axis=1 + lambda row: { + identity: row[f"{identity}_likelihood"] + for identity in identities + }, + axis=1, ) part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True) part_data = part_data[ # reorder columns - ["identity", "identity_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"] + [ + "identity", + "identity_likelihood", + f"{part}_x", + f"{part}_y", + f"{part}_likelihood", + ] ] part_data.insert(2, "part", part) part_data.columns = new_columns @@ -442,10 +484,14 @@ def class_int2str(data: pd.DataFrame, config_file: Path) -> pd.DataFrame: return data @classmethod - def get_config_file(cls, config_file_dir: Path, config_file_names: None | list[str] = None) -> Path: + def get_config_file( + cls, config_file_dir: Path, config_file_names: None | list[str] = None + ) -> Path: """Returns the config file from a model's config directory.""" if config_file_names is None: - config_file_names = ["confmap_config.json"] # SLEAP (add for other trackers to this list) + config_file_names = [ + "confmap_config.json" + ] # SLEAP (add for other trackers to this list) config_file = None for f in config_file_names: if (config_file_dir / f).exists(): @@ -464,14 +510,21 @@ def from_dict(data, pattern=None): return globals()[reader_type](pattern=pattern, **kwargs) return DotMap( - {k: from_dict(v, f"{pattern}_{k}" if pattern is not None else k) for k, v in data.items()} + { + k: from_dict(v, f"{pattern}_{k}" if pattern is not None else k) + for k, v in data.items() + } ) def to_dict(dotmap): """Converts a DotMap object to a dictionary.""" if isinstance(dotmap, Reader): - kwargs = {k: v for k, v in vars(dotmap).items() if k not in ["pattern"] and not k.startswith("_")} + kwargs = { + k: v + for k, v in vars(dotmap).items() + if k not in ["pattern"] and not k.startswith("_") + } kwargs["type"] = type(dotmap).__name__ return kwargs return {k: to_dict(v) for k, v in dotmap.items()} diff --git a/aeon/io/video.py b/aeon/io/video.py index 26c49827..dbdc173b 100644 --- a/aeon/io/video.py +++ b/aeon/io/video.py @@ -1,3 +1,5 @@ +"""This module provides functions to read and write video files using OpenCV.""" + import cv2 @@ -27,7 +29,9 @@ def frames(data): index = frameidx success, frame = capture.read() if not success: - raise ValueError(f'Unable to read frame {frameidx} from video path "{path}".') + raise ValueError( + f'Unable to read frame {frameidx} from video path "{path}".' + ) yield frame index = index + 1 finally: @@ -50,7 +54,9 @@ def export(frames, file, fps, fourcc=None): if writer is None: if fourcc is None: fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v") # type: ignore - writer = cv2.VideoWriter(file, fourcc, fps, (frame.shape[1], frame.shape[0])) + writer = cv2.VideoWriter( + file, fourcc, fps, (frame.shape[1], frame.shape[0]) + ) writer.write(frame) finally: if writer is not None: diff --git a/aeon/schema/core.py b/aeon/schema/core.py index 6f70c8b4..63c403a6 100644 --- a/aeon/schema/core.py +++ b/aeon/schema/core.py @@ -1,3 +1,5 @@ +"""Schema definition for core Harp data streams.""" + import aeon.io.reader as _reader from aeon.schema.streams import Stream, StreamGroup diff --git a/aeon/schema/dataset.py b/aeon/schema/dataset.py index 0facd64f..225cf335 100644 --- a/aeon/schema/dataset.py +++ b/aeon/schema/dataset.py @@ -1,3 +1,5 @@ +""" Dataset schema definitions. """ + from dotmap import DotMap import aeon.schema.core as stream diff --git a/aeon/schema/foraging.py b/aeon/schema/foraging.py index 0eaf593c..40c4e7f6 100644 --- a/aeon/schema/foraging.py +++ b/aeon/schema/foraging.py @@ -1,3 +1,5 @@ +"""Schema definition for foraging experiments.""" + from enum import Enum import pandas as pd @@ -22,7 +24,9 @@ def __init__(self, pattern): def read(self, file): data = super().read(file) - categorical = pd.Categorical(data.region, categories=range(len(Area._member_names_))) + categorical = pd.Categorical( + data.region, categories=range(len(Area._member_names_)) + ) data["region"] = categorical.rename_categories(Area._member_names_) return data @@ -78,7 +82,9 @@ class BeamBreak(Stream): """Beam break events for pellet detection.""" def __init__(self, pattern): - super().__init__(_reader.BitmaskEvent(f"{pattern}_32_*", 0x22, "PelletDetected")) + super().__init__( + _reader.BitmaskEvent(f"{pattern}_32_*", 0x22, "PelletDetected") + ) class DeliverPellet(Stream): @@ -127,4 +133,6 @@ class SessionData(Stream): """Session metadata for Experiment 0.1.""" def __init__(self, pattern): - super().__init__(_reader.Csv(f"{pattern}_2*", columns=["id", "weight", "event"])) + super().__init__( + _reader.Csv(f"{pattern}_2*", columns=["id", "weight", "event"]) + ) diff --git a/aeon/schema/octagon.py b/aeon/schema/octagon.py index 2ea85b4e..064cae9a 100644 --- a/aeon/schema/octagon.py +++ b/aeon/schema/octagon.py @@ -1,3 +1,5 @@ +""" Octagon schema definition. """ + import aeon.io.reader as _reader from aeon.schema.streams import Stream, StreamGroup @@ -14,24 +16,33 @@ def __init__(self, path): class BackgroundColor(Stream): def __init__(self, pattern): super().__init__( - _reader.Csv(f"{pattern}_backgroundcolor_*", columns=["typetag", "r", "g", "b", "a"]) + _reader.Csv( + f"{pattern}_backgroundcolor_*", + columns=["typetag", "r", "g", "b", "a"], + ) ) class ChangeSubjectState(Stream): def __init__(self, pattern): super().__init__( - _reader.Csv(f"{pattern}_changesubjectstate_*", columns=["typetag", "id", "weight", "event"]) + _reader.Csv( + f"{pattern}_changesubjectstate_*", + columns=["typetag", "id", "weight", "event"], + ) ) class EndTrial(Stream): def __init__(self, pattern): - super().__init__(_reader.Csv(f"{pattern}_endtrial_*", columns=["typetag", "value"])) + super().__init__( + _reader.Csv(f"{pattern}_endtrial_*", columns=["typetag", "value"]) + ) class Slice(Stream): def __init__(self, pattern): super().__init__( _reader.Csv( - f"{pattern}_octagonslice_*", columns=["typetag", "wall_id", "r", "g", "b", "a", "delay"] + f"{pattern}_octagonslice_*", + columns=["typetag", "wall_id", "r", "g", "b", "a", "delay"], ) ) @@ -74,7 +85,8 @@ class Response(Stream): def __init__(self, pattern): super().__init__( _reader.Csv( - f"{pattern}_response_*", columns=["typetag", "wall_id", "poke_id", "response_time"] + f"{pattern}_response_*", + columns=["typetag", "wall_id", "poke_id", "response_time"], ) ) @@ -96,7 +108,9 @@ def __init__(self, pattern): class StartNewSession(Stream): def __init__(self, pattern): - super().__init__(_reader.Csv(f"{pattern}_startnewsession_*", columns=["typetag", "path"])) + super().__init__( + _reader.Csv(f"{pattern}_startnewsession_*", columns=["typetag", "path"]) + ) class TaskLogic(StreamGroup): @@ -109,7 +123,9 @@ def __init__(self, pattern): class Response(Stream): def __init__(self, pattern): - super().__init__(_reader.Harp(f"{pattern}_2_*", columns=["wall_id", "poke_id"])) + super().__init__( + _reader.Harp(f"{pattern}_2_*", columns=["wall_id", "poke_id"]) + ) class PreTrialState(Stream): def __init__(self, pattern): @@ -138,15 +154,21 @@ def __init__(self, path): class BeamBreak0(Stream): def __init__(self, pattern): - super().__init__(_reader.DigitalBitmask(f"{pattern}_32_*", 0x1, columns=["state"])) + super().__init__( + _reader.DigitalBitmask(f"{pattern}_32_*", 0x1, columns=["state"]) + ) class BeamBreak1(Stream): def __init__(self, pattern): - super().__init__(_reader.DigitalBitmask(f"{pattern}_32_*", 0x2, columns=["state"])) + super().__init__( + _reader.DigitalBitmask(f"{pattern}_32_*", 0x2, columns=["state"]) + ) class BeamBreak2(Stream): def __init__(self, pattern): - super().__init__(_reader.DigitalBitmask(f"{pattern}_32_*", 0x4, columns=["state"])) + super().__init__( + _reader.DigitalBitmask(f"{pattern}_32_*", 0x4, columns=["state"]) + ) class SetLed0(Stream): def __init__(self, pattern): diff --git a/aeon/schema/schemas.py b/aeon/schema/schemas.py index 0da2f1bf..06f8598c 100644 --- a/aeon/schema/schemas.py +++ b/aeon/schema/schemas.py @@ -1,3 +1,5 @@ +""" Schemas for different experiments. """ + from dotmap import DotMap import aeon.schema.core as stream @@ -115,7 +117,12 @@ social03 = DotMap( [ Device("Metadata", stream.Metadata), - Device("Environment", social_02.Environment, social_02.SubjectData, social_03.EnvironmentActiveConfiguration), + Device( + "Environment", + social_02.Environment, + social_02.SubjectData, + social_03.EnvironmentActiveConfiguration, + ), Device("CameraTop", stream.Video, social_03.Pose), Device("CameraNorth", stream.Video), Device("CameraSouth", stream.Video), @@ -146,7 +153,12 @@ social04 = DotMap( [ Device("Metadata", stream.Metadata), - Device("Environment", social_02.Environment, social_02.SubjectData, social_03.EnvironmentActiveConfiguration), + Device( + "Environment", + social_02.Environment, + social_02.SubjectData, + social_03.EnvironmentActiveConfiguration, + ), Device("CameraTop", stream.Video, social_03.Pose), Device("CameraNorth", stream.Video), Device("CameraSouth", stream.Video), @@ -174,4 +186,12 @@ ) -__all__ = ["exp01", "exp02", "octagon01", "social01", "social02", "social03", "social04"] +__all__ = [ + "exp01", + "exp02", + "octagon01", + "social01", + "social02", + "social03", + "social04", +] diff --git a/aeon/schema/social_01.py b/aeon/schema/social_01.py index 7f6e2ab0..719ef9a3 100644 --- a/aeon/schema/social_01.py +++ b/aeon/schema/social_01.py @@ -1,3 +1,5 @@ +""" This module contains the schema for the social_01 dataset. """ + import aeon.io.reader as _reader from aeon.schema.streams import Stream diff --git a/aeon/schema/social_02.py b/aeon/schema/social_02.py index 04946679..ce98f081 100644 --- a/aeon/schema/social_02.py +++ b/aeon/schema/social_02.py @@ -1,3 +1,5 @@ +""" This module defines the schema for the social_02 dataset. """ + import aeon.io.reader as _reader from aeon.schema import core, foraging from aeon.schema.streams import Stream, StreamGroup @@ -12,12 +14,17 @@ def __init__(self, path): class BlockState(Stream): def __init__(self, path): super().__init__( - _reader.Csv(f"{path}_BlockState_*", columns=["pellet_ct", "pellet_ct_thresh", "due_time"]) + _reader.Csv( + f"{path}_BlockState_*", + columns=["pellet_ct", "pellet_ct_thresh", "due_time"], + ) ) class LightEvents(Stream): def __init__(self, path): - super().__init__(_reader.Csv(f"{path}_LightEvents_*", columns=["channel", "value"])) + super().__init__( + _reader.Csv(f"{path}_LightEvents_*", columns=["channel", "value"]) + ) MessageLog = core.MessageLog @@ -28,17 +35,22 @@ def __init__(self, path): class SubjectState(Stream): def __init__(self, path): - super().__init__(_reader.Csv(f"{path}_SubjectState_*", columns=["id", "weight", "type"])) + super().__init__( + _reader.Csv(f"{path}_SubjectState_*", columns=["id", "weight", "type"]) + ) class SubjectVisits(Stream): def __init__(self, path): - super().__init__(_reader.Csv(f"{path}_SubjectVisits_*", columns=["id", "type", "region"])) + super().__init__( + _reader.Csv(f"{path}_SubjectVisits_*", columns=["id", "type", "region"]) + ) class SubjectWeight(Stream): def __init__(self, path): super().__init__( _reader.Csv( - f"{path}_SubjectWeight_*", columns=["weight", "confidence", "subject_id", "int_id"] + f"{path}_SubjectWeight_*", + columns=["weight", "confidence", "subject_id", "int_id"], ) ) @@ -64,7 +76,9 @@ def __init__(self, path): class DepletionState(Stream): def __init__(self, path): - super().__init__(_reader.Csv(f"{path}_State_*", columns=["threshold", "offset", "rate"])) + super().__init__( + _reader.Csv(f"{path}_State_*", columns=["threshold", "offset", "rate"]) + ) Encoder = core.Encoder diff --git a/aeon/schema/social_03.py b/aeon/schema/social_03.py index fdb1f7df..c624a06d 100644 --- a/aeon/schema/social_03.py +++ b/aeon/schema/social_03.py @@ -1,3 +1,5 @@ +""" This module contains the schema for the social_03 dataset. """ + import json import pandas as pd import aeon.io.reader as _reader @@ -12,4 +14,6 @@ def __init__(self, path): class EnvironmentActiveConfiguration(Stream): def __init__(self, path): - super().__init__(_reader.JsonList(f"{path}_ActiveConfiguration_*", columns=["name"])) + super().__init__( + _reader.JsonList(f"{path}_ActiveConfiguration_*", columns=["name"]) + ) diff --git a/aeon/schema/streams.py b/aeon/schema/streams.py index 2c5d57b2..a6656f4b 100644 --- a/aeon/schema/streams.py +++ b/aeon/schema/streams.py @@ -1,3 +1,5 @@ +""" Contains classes for defining data streams and devices. """ + import inspect from itertools import chain from warnings import warn diff --git a/pyproject.toml b/pyproject.toml index 658abf5b..2bf2ae8d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,7 +97,6 @@ lint.select = [ ] line-length = 108 lint.ignore = [ - "D100", # skip adding docstrings for module "D104", # ignore missing docstring in public package "D105", # skip adding docstrings for magic methods "D107", # skip adding docstrings for __init__ diff --git a/tests/dj_pipeline/test_acquisition.py b/tests/dj_pipeline/test_acquisition.py index 28f39d12..d54a491e 100644 --- a/tests/dj_pipeline/test_acquisition.py +++ b/tests/dj_pipeline/test_acquisition.py @@ -1,3 +1,5 @@ +""" Tests for the acquisition pipeline. """ + from pytest import mark @@ -16,18 +18,29 @@ def test_epoch_chunk_ingestion(test_params, pipeline, epoch_chunk_ingestion): @mark.ingestion -def test_experimentlog_ingestion(test_params, pipeline, epoch_chunk_ingestion, experimentlog_ingestion): +def test_experimentlog_ingestion( + test_params, pipeline, epoch_chunk_ingestion, experimentlog_ingestion +): acquisition = pipeline["acquisition"] assert ( - len(acquisition.ExperimentLog.Message & {"experiment_name": test_params["experiment_name"]}) + len( + acquisition.ExperimentLog.Message + & {"experiment_name": test_params["experiment_name"]} + ) == test_params["experiment_log_message_count"] ) assert ( - len(acquisition.SubjectEnterExit.Time & {"experiment_name": test_params["experiment_name"]}) + len( + acquisition.SubjectEnterExit.Time + & {"experiment_name": test_params["experiment_name"]} + ) == test_params["subject_enter_exit_count"] ) assert ( - len(acquisition.SubjectWeight.WeightTime & {"experiment_name": test_params["experiment_name"]}) + len( + acquisition.SubjectWeight.WeightTime + & {"experiment_name": test_params["experiment_name"]} + ) == test_params["subject_weight_time_count"] ) diff --git a/tests/dj_pipeline/test_pipeline_instantiation.py b/tests/dj_pipeline/test_pipeline_instantiation.py index cb3b51fb..d6bdc96f 100644 --- a/tests/dj_pipeline/test_pipeline_instantiation.py +++ b/tests/dj_pipeline/test_pipeline_instantiation.py @@ -1,3 +1,5 @@ +""" Tests for pipeline instantiation and experiment creation """ + from pytest import mark @@ -18,9 +20,12 @@ def test_experiment_creation(test_params, pipeline, experiment_creation): experiment_name = test_params["experiment_name"] assert acquisition.Experiment.fetch1("experiment_name") == experiment_name raw_dir = ( - acquisition.Experiment.Directory & {"experiment_name": experiment_name, "directory_type": "raw"} + acquisition.Experiment.Directory + & {"experiment_name": experiment_name, "directory_type": "raw"} ).fetch1("directory_path") assert raw_dir == test_params["raw_dir"] - exp_subjects = (acquisition.Experiment.Subject & {"experiment_name": experiment_name}).fetch("subject") + exp_subjects = ( + acquisition.Experiment.Subject & {"experiment_name": experiment_name} + ).fetch("subject") assert len(exp_subjects) == test_params["subject_count"] assert "BAA-1100701" in exp_subjects diff --git a/tests/dj_pipeline/test_qc.py b/tests/dj_pipeline/test_qc.py index 9815031e..c0ced19f 100644 --- a/tests/dj_pipeline/test_qc.py +++ b/tests/dj_pipeline/test_qc.py @@ -1,3 +1,5 @@ +""" Tests for the QC pipeline. """ + from pytest import mark diff --git a/tests/dj_pipeline/test_tracking.py b/tests/dj_pipeline/test_tracking.py index 973e0741..42e1ede3 100644 --- a/tests/dj_pipeline/test_tracking.py +++ b/tests/dj_pipeline/test_tracking.py @@ -1,3 +1,5 @@ +""" Test tracking pipeline. """ + import datetime import pathlib @@ -6,9 +8,7 @@ index = 0 column_name = "position_x" # data column to run test on -file_name = ( - "exp0.2-r0-20220524090000-21053810-20220524082942-0-0.npy" # test file to be saved with save_test_data -) +file_name = "exp0.2-r0-20220524090000-21053810-20220524082942-0-0.npy" # test file to be saved with save_test_data def save_test_data(pipeline, test_params): @@ -19,7 +19,11 @@ def save_test_data(pipeline, test_params): file_name = ( "-".join( [ - v.strftime("%Y%m%d%H%M%S") if isinstance(v, datetime.datetime) else str(v) + ( + v.strftime("%Y%m%d%H%M%S") + if isinstance(v, datetime.datetime) + else str(v) + ) for v in key.values() ] ) @@ -38,13 +42,20 @@ def save_test_data(pipeline, test_params): def test_camera_tracking_ingestion(test_params, pipeline, camera_tracking_ingestion): tracking = pipeline["tracking"] - assert len(tracking.CameraTracking.Object()) == test_params["camera_tracking_object_count"] + assert ( + len(tracking.CameraTracking.Object()) + == test_params["camera_tracking_object_count"] + ) key = tracking.CameraTracking.Object().fetch("KEY")[index] file_name = ( "-".join( [ - v.strftime("%Y%m%d%H%M%S") if isinstance(v, datetime.datetime) else str(v) + ( + v.strftime("%Y%m%d%H%M%S") + if isinstance(v, datetime.datetime) + else str(v) + ) for v in key.values() ] ) diff --git a/tests/io/test_api.py b/tests/io/test_api.py index 095439de..71018e72 100644 --- a/tests/io/test_api.py +++ b/tests/io/test_api.py @@ -1,3 +1,5 @@ +""" Tests for the aeon API """ + from pathlib import Path import pandas as pd @@ -14,7 +16,10 @@ @mark.api def test_load_start_only(): data = aeon.load( - nonmonotonic_path, exp02.Patch2.Encoder, start=pd.Timestamp("2022-06-06T13:00:49"), downsample=None + nonmonotonic_path, + exp02.Patch2.Encoder, + start=pd.Timestamp("2022-06-06T13:00:49"), + downsample=None, ) assert len(data) > 0 @@ -22,14 +27,19 @@ def test_load_start_only(): @mark.api def test_load_end_only(): data = aeon.load( - nonmonotonic_path, exp02.Patch2.Encoder, end=pd.Timestamp("2022-06-06T13:00:49"), downsample=None + nonmonotonic_path, + exp02.Patch2.Encoder, + end=pd.Timestamp("2022-06-06T13:00:49"), + downsample=None, ) assert len(data) > 0 @mark.api def test_load_filter_nonchunked(): - data = aeon.load(nonmonotonic_path, exp02.Metadata, start=pd.Timestamp("2022-06-06T09:00:00")) + data = aeon.load( + nonmonotonic_path, exp02.Metadata, start=pd.Timestamp("2022-06-06T09:00:00") + ) assert len(data) > 0 From edb30e8d765e2f6f4dfcd91d5e600e0305ea074b Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 11:17:45 +0000 Subject: [PATCH 002/143] fix: fix: resolve D104 error --- aeon/__init__.py | 2 ++ aeon/analysis/__init__.py | 1 + aeon/dj_pipeline/__init__.py | 17 ++++++++++++++--- aeon/dj_pipeline/analysis/__init__.py | 1 + aeon/dj_pipeline/populate/__init__.py | 1 + aeon/dj_pipeline/utils/__init__.py | 1 + aeon/io/__init__.py | 1 + aeon/schema/__init__.py | 1 + pyproject.toml | 1 - 9 files changed, 22 insertions(+), 4 deletions(-) diff --git a/aeon/__init__.py b/aeon/__init__.py index 2a691c53..0dc5ee9d 100644 --- a/aeon/__init__.py +++ b/aeon/__init__.py @@ -1,3 +1,5 @@ +""" Top-level package for aeon. """ + from importlib.metadata import PackageNotFoundError, version try: diff --git a/aeon/analysis/__init__.py b/aeon/analysis/__init__.py index e69de29b..b48aecd3 100644 --- a/aeon/analysis/__init__.py +++ b/aeon/analysis/__init__.py @@ -0,0 +1 @@ +""" Utilities for analyzing data. """ diff --git a/aeon/dj_pipeline/__init__.py b/aeon/dj_pipeline/__init__.py index 9bb1128e..44e0d498 100644 --- a/aeon/dj_pipeline/__init__.py +++ b/aeon/dj_pipeline/__init__.py @@ -1,3 +1,5 @@ +""" DataJoint pipeline for Aeon. """ + import hashlib import os import uuid @@ -13,7 +15,9 @@ db_prefix = dj.config["custom"].get("database.prefix", _default_database_prefix) -repository_config = dj.config["custom"].get("repository_config", _default_repository_config) +repository_config = dj.config["custom"].get( + "repository_config", _default_repository_config +) def get_schema_name(name) -> str: @@ -38,7 +42,9 @@ def fetch_stream(query, drop_pk=True): """ df = (query & "sample_count > 0").fetch(format="frame").reset_index() cols2explode = [ - c for c in query.heading.secondary_attributes if query.heading.attributes[c].type == "longblob" + c + for c in query.heading.secondary_attributes + if query.heading.attributes[c].type == "longblob" ] df = df.explode(column=cols2explode) cols2drop = ["sample_count"] + (query.primary_key if drop_pk else []) @@ -46,7 +52,12 @@ def fetch_stream(query, drop_pk=True): df.rename(columns={"timestamps": "time"}, inplace=True) df.set_index("time", inplace=True) df.sort_index(inplace=True) - df = df.convert_dtypes(convert_string=False, convert_integer=False, convert_boolean=False, convert_floating=False) + df = df.convert_dtypes( + convert_string=False, + convert_integer=False, + convert_boolean=False, + convert_floating=False, + ) return df diff --git a/aeon/dj_pipeline/analysis/__init__.py b/aeon/dj_pipeline/analysis/__init__.py index e69de29b..b48aecd3 100644 --- a/aeon/dj_pipeline/analysis/__init__.py +++ b/aeon/dj_pipeline/analysis/__init__.py @@ -0,0 +1 @@ +""" Utilities for analyzing data. """ diff --git a/aeon/dj_pipeline/populate/__init__.py b/aeon/dj_pipeline/populate/__init__.py index e69de29b..ca091e15 100644 --- a/aeon/dj_pipeline/populate/__init__.py +++ b/aeon/dj_pipeline/populate/__init__.py @@ -0,0 +1 @@ +""" Utilities for the workflow orchestration. """ diff --git a/aeon/dj_pipeline/utils/__init__.py b/aeon/dj_pipeline/utils/__init__.py index e69de29b..82bbb4bd 100644 --- a/aeon/dj_pipeline/utils/__init__.py +++ b/aeon/dj_pipeline/utils/__init__.py @@ -0,0 +1 @@ +""" Helper functions and utilities for the Aeon project. """ diff --git a/aeon/io/__init__.py b/aeon/io/__init__.py index e69de29b..f481ec8e 100644 --- a/aeon/io/__init__.py +++ b/aeon/io/__init__.py @@ -0,0 +1 @@ +""" Utilities for I/O operations. """ diff --git a/aeon/schema/__init__.py b/aeon/schema/__init__.py index e69de29b..3de266c2 100644 --- a/aeon/schema/__init__.py +++ b/aeon/schema/__init__.py @@ -0,0 +1 @@ +""" Utilities for the schemas. """ diff --git a/pyproject.toml b/pyproject.toml index 2bf2ae8d..bc325ab9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,7 +97,6 @@ lint.select = [ ] line-length = 108 lint.ignore = [ - "D104", # ignore missing docstring in public package "D105", # skip adding docstrings for magic methods "D107", # skip adding docstrings for __init__ "E201", From 46b832b04c504e5cab932c8106676d79b17b531e Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 11:21:07 +0000 Subject: [PATCH 003/143] fix: fix: resolve D105 error --- aeon/io/device.py | 1 + aeon/schema/streams.py | 3 +++ pyproject.toml | 1 - 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/aeon/io/device.py b/aeon/io/device.py index 62e8b3e1..40704edc 100644 --- a/aeon/io/device.py +++ b/aeon/io/device.py @@ -39,6 +39,7 @@ def __init__(self, name, *args, pattern=None): self.registry = compositeStream(name if pattern is None else pattern, *args) def __iter__(self): + """Iterates over the device registry.""" if len(self.registry) == 1: singleton = self.registry.get(self.name, None) if singleton: diff --git a/aeon/schema/streams.py b/aeon/schema/streams.py index a6656f4b..cbbfc8b2 100644 --- a/aeon/schema/streams.py +++ b/aeon/schema/streams.py @@ -16,6 +16,7 @@ def __init__(self, reader): self.reader = reader def __iter__(self): + """Yields the stream name and reader.""" yield (self.__class__.__name__, self.reader) @@ -37,6 +38,7 @@ def __init__(self, path, *args): ) def __iter__(self): + """Yields the stream name and reader for each data stream in the group.""" for factory in chain(self._nested, self._args): yield from iter(factory(self.path)) @@ -79,6 +81,7 @@ def _createStreams(path, args): return streams def __iter__(self): + """Iterates over the device streams.""" if len(self._streams) == 1: singleton = self._streams.get(self.name, None) if singleton: diff --git a/pyproject.toml b/pyproject.toml index bc325ab9..5fd60607 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,7 +97,6 @@ lint.select = [ ] line-length = 108 lint.ignore = [ - "D105", # skip adding docstrings for magic methods "D107", # skip adding docstrings for __init__ "E201", "E202", From 7cfd46cdb89d429bb43c951ddaa8d962da2ce215 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 11:44:13 +0000 Subject: [PATCH 004/143] fix: resolve D107 error --- aeon/io/device.py | 1 + aeon/io/reader.py | 14 ++++++++++++++ aeon/schema/core.py | 9 +++++++++ aeon/schema/foraging.py | 14 ++++++++++++++ aeon/schema/octagon.py | 35 +++++++++++++++++++++++++++++++++++ aeon/schema/social_01.py | 2 ++ aeon/schema/social_02.py | 16 ++++++++++++++++ aeon/schema/social_03.py | 2 ++ aeon/schema/streams.py | 3 +++ pyproject.toml | 1 - 10 files changed, 96 insertions(+), 1 deletion(-) diff --git a/aeon/io/device.py b/aeon/io/device.py index 40704edc..56d3dacb 100644 --- a/aeon/io/device.py +++ b/aeon/io/device.py @@ -35,6 +35,7 @@ class Device: """ def __init__(self, name, *args, pattern=None): + """Initializes the Device class.""" self.name = name self.registry = compositeStream(name if pattern is None else pattern, *args) diff --git a/aeon/io/reader.py b/aeon/io/reader.py index a4a0738a..f631ab43 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -41,6 +41,7 @@ class Reader: """ def __init__(self, pattern, columns, extension): + """Initialize the object with specified pattern, columns, and file extension.""" self.pattern = pattern self.columns = columns self.extension = extension @@ -54,6 +55,7 @@ class Harp(Reader): """Extracts data from raw binary files encoded using the Harp protocol.""" def __init__(self, pattern, columns, extension="bin"): + """Initialize the object.""" super().__init__(pattern, columns, extension) def read(self, file): @@ -97,6 +99,7 @@ class Chunk(Reader): """Extracts path and epoch information from chunk files in the dataset.""" def __init__(self, reader=None, pattern=None, extension=None): + """Initialize the object with optional reader, pattern, and file extension.""" if isinstance(reader, Reader): pattern = reader.pattern extension = reader.extension @@ -113,6 +116,7 @@ class Metadata(Reader): """Extracts metadata information from all epochs in the dataset.""" def __init__(self, pattern="Metadata"): + """Initialize the object with the specified pattern.""" super().__init__( pattern, columns=["workflow", "commit", "metadata"], extension="yml" ) @@ -139,6 +143,7 @@ class Csv(Reader): """ def __init__(self, pattern, columns, dtype=None, extension="csv"): + """Initialize the object with the specified pattern, columns, and data type.""" super().__init__(pattern, columns, extension) self.dtype = dtype @@ -159,6 +164,7 @@ class JsonList(Reader): """ def __init__(self, pattern, columns=(), root_key="value", extension="jsonl"): + """Initialize the object with the specified pattern, columns, and root key.""" super().__init__(pattern, columns, extension) self.columns = columns self.root_key = root_key @@ -184,6 +190,7 @@ class Subject(Csv): """ def __init__(self, pattern): + """Initialize the object with a specified pattern.""" super().__init__(pattern, columns=["id", "weight", "event"]) @@ -198,6 +205,7 @@ class Log(Csv): """ def __init__(self, pattern): + """Initialize the object with a specified pattern and columns.""" super().__init__(pattern, columns=["priority", "type", "message"]) @@ -209,6 +217,7 @@ class Heartbeat(Harp): """ def __init__(self, pattern): + """Initialize the object with a specified pattern.""" super().__init__(pattern, columns=["second"]) @@ -221,6 +230,7 @@ class Encoder(Harp): """ def __init__(self, pattern): + """Initialize the object with a specified pattern and columns.""" super().__init__(pattern, columns=["angle", "intensity"]) def read(self, file, downsample=True): @@ -258,6 +268,7 @@ class Position(Harp): """ def __init__(self, pattern): + """Initialize the object with a specified pattern and columns.""" super().__init__( pattern, columns=["x", "y", "angle", "major", "minor", "area", "id"] ) @@ -271,6 +282,7 @@ class BitmaskEvent(Harp): """ def __init__(self, pattern, value, tag): + """Initialize the object with specified pattern, value, and tag.""" super().__init__(pattern, columns=["event"]) self.value = value self.tag = tag @@ -294,6 +306,7 @@ class DigitalBitmask(Harp): """ def __init__(self, pattern, mask, columns): + """Initialize the object with specified pattern, mask, and columns.""" super().__init__(pattern, columns) self.mask = mask @@ -316,6 +329,7 @@ class Video(Csv): """ def __init__(self, pattern): + """Initialize the object with a specified pattern.""" super().__init__( pattern, columns=["hw_counter", "hw_timestamp", "_frame", "_path", "_epoch"] ) diff --git a/aeon/schema/core.py b/aeon/schema/core.py index 63c403a6..02703e74 100644 --- a/aeon/schema/core.py +++ b/aeon/schema/core.py @@ -8,6 +8,7 @@ class Heartbeat(Stream): """Heartbeat event for Harp devices.""" def __init__(self, pattern): + """Initializes the Heartbeat stream.""" super().__init__(_reader.Heartbeat(f"{pattern}_8_*")) @@ -15,6 +16,7 @@ class Video(Stream): """Video frame metadata.""" def __init__(self, pattern): + """Initializes the Video stream.""" super().__init__(_reader.Video(f"{pattern}_*")) @@ -22,6 +24,7 @@ class Position(Stream): """Position tracking data for the specified camera.""" def __init__(self, pattern): + """Initializes the Position stream.""" super().__init__(_reader.Position(f"{pattern}_200_*")) @@ -29,6 +32,7 @@ class Encoder(Stream): """Wheel magnetic encoder data.""" def __init__(self, pattern): + """Initializes the Encoder stream.""" super().__init__(_reader.Encoder(f"{pattern}_90_*")) @@ -36,6 +40,7 @@ class Environment(StreamGroup): """Metadata for environment mode and subjects.""" def __init__(self, pattern): + """Initializes the Environment stream group.""" super().__init__(pattern, EnvironmentState, SubjectState) @@ -43,6 +48,7 @@ class EnvironmentState(Stream): """Environment state log.""" def __init__(self, pattern): + """Initializes the EnvironmentState stream.""" super().__init__(_reader.Csv(f"{pattern}_EnvironmentState_*", ["state"])) @@ -50,6 +56,7 @@ class SubjectState(Stream): """Subject state log.""" def __init__(self, pattern): + """Initialises the SubjectState stream.""" super().__init__(_reader.Subject(f"{pattern}_SubjectState_*")) @@ -57,6 +64,7 @@ class MessageLog(Stream): """Message log data.""" def __init__(self, pattern): + """Initializes the MessageLog stream.""" super().__init__(_reader.Log(f"{pattern}_MessageLog_*")) @@ -64,4 +72,5 @@ class Metadata(Stream): """Metadata for acquisition epochs.""" def __init__(self, pattern): + """Initializes the Metadata stream.""" super().__init__(_reader.Metadata(pattern)) diff --git a/aeon/schema/foraging.py b/aeon/schema/foraging.py index 40c4e7f6..05ce480a 100644 --- a/aeon/schema/foraging.py +++ b/aeon/schema/foraging.py @@ -20,6 +20,7 @@ class Area(Enum): class _RegionReader(_reader.Harp): def __init__(self, pattern): + """Initializes the RegionReader class.""" super().__init__(pattern, columns=["region"]) def read(self, file): @@ -41,6 +42,7 @@ class _PatchState(_reader.Csv): """ def __init__(self, pattern): + """Initializes the PatchState class.""" super().__init__(pattern, columns=["threshold", "d1", "delta"]) @@ -54,6 +56,7 @@ class _Weight(_reader.Harp): """ def __init__(self, pattern): + """Initializes the Weight class.""" super().__init__(pattern, columns=["value", "stable"]) @@ -61,6 +64,7 @@ class Region(Stream): """Region tracking data for the specified camera.""" def __init__(self, pattern): + """Initializes the Region stream.""" super().__init__(_RegionReader(f"{pattern}_201_*")) @@ -68,6 +72,7 @@ class DepletionFunction(Stream): """State of the linear depletion function for foraging patches.""" def __init__(self, pattern): + """Initializes the DepletionFunction stream.""" super().__init__(_PatchState(f"{pattern}_State_*")) @@ -75,6 +80,7 @@ class Feeder(StreamGroup): """Feeder commands and events.""" def __init__(self, pattern): + """Initializes the Feeder stream group.""" super().__init__(pattern, BeamBreak, DeliverPellet) @@ -82,6 +88,7 @@ class BeamBreak(Stream): """Beam break events for pellet detection.""" def __init__(self, pattern): + """Initializes the BeamBreak stream.""" super().__init__( _reader.BitmaskEvent(f"{pattern}_32_*", 0x22, "PelletDetected") ) @@ -91,6 +98,7 @@ class DeliverPellet(Stream): """Pellet delivery commands.""" def __init__(self, pattern): + """Initializes the DeliverPellet stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x01, "TriggerPellet")) @@ -98,6 +106,7 @@ class Patch(StreamGroup): """Data streams for a patch.""" def __init__(self, pattern): + """Initializes the Patch stream group.""" super().__init__(pattern, DepletionFunction, _stream.Encoder, Feeder) @@ -105,6 +114,7 @@ class Weight(StreamGroup): """Weight measurement data streams for a specific nest.""" def __init__(self, pattern): + """Initializes the Weight stream group.""" super().__init__(pattern, WeightRaw, WeightFiltered, WeightSubject) @@ -112,6 +122,7 @@ class WeightRaw(Stream): """Raw weight measurement for a specific nest.""" def __init__(self, pattern): + """Initializes the WeightRaw stream.""" super().__init__(_Weight(f"{pattern}_200_*")) @@ -119,6 +130,7 @@ class WeightFiltered(Stream): """Filtered weight measurement for a specific nest.""" def __init__(self, pattern): + """Initializes the WeightFiltered stream.""" super().__init__(_Weight(f"{pattern}_202_*")) @@ -126,6 +138,7 @@ class WeightSubject(Stream): """Subject weight measurement for a specific nest.""" def __init__(self, pattern): + """Initializes the WeightSubject stream.""" super().__init__(_Weight(f"{pattern}_204_*")) @@ -133,6 +146,7 @@ class SessionData(Stream): """Session metadata for Experiment 0.1.""" def __init__(self, pattern): + """Initializes the SessionData stream.""" super().__init__( _reader.Csv(f"{pattern}_2*", columns=["id", "weight", "event"]) ) diff --git a/aeon/schema/octagon.py b/aeon/schema/octagon.py index 064cae9a..f031d53b 100644 --- a/aeon/schema/octagon.py +++ b/aeon/schema/octagon.py @@ -6,15 +6,18 @@ class Photodiode(Stream): def __init__(self, path): + """Initializes the Photodiode stream.""" super().__init__(_reader.Harp(f"{path}_44_*", columns=["adc", "encoder"])) class OSC(StreamGroup): def __init__(self, path): + """Initializes the OSC stream group.""" super().__init__(path) class BackgroundColor(Stream): def __init__(self, pattern): + """Initializes the BackgroundColor stream.""" super().__init__( _reader.Csv( f"{pattern}_backgroundcolor_*", @@ -24,6 +27,7 @@ def __init__(self, pattern): class ChangeSubjectState(Stream): def __init__(self, pattern): + """Initializes the ChangeSubjectState stream.""" super().__init__( _reader.Csv( f"{pattern}_changesubjectstate_*", @@ -33,12 +37,14 @@ def __init__(self, pattern): class EndTrial(Stream): def __init__(self, pattern): + """Initialises the EndTrial stream.""" super().__init__( _reader.Csv(f"{pattern}_endtrial_*", columns=["typetag", "value"]) ) class Slice(Stream): def __init__(self, pattern): + """Initialises the Slice.""" super().__init__( _reader.Csv( f"{pattern}_octagonslice_*", @@ -48,6 +54,7 @@ def __init__(self, pattern): class GratingsSlice(Stream): def __init__(self, pattern): + """Initialises the GratingsSlice stream.""" super().__init__( _reader.Csv( f"{pattern}_octagongratingsslice_*", @@ -66,6 +73,7 @@ def __init__(self, pattern): class Poke(Stream): def __init__(self, pattern): + """Initializes the Poke class.""" super().__init__( _reader.Csv( f"{pattern}_poke_*", @@ -83,6 +91,7 @@ def __init__(self, pattern): class Response(Stream): def __init__(self, pattern): + """Initialises the Response class.""" super().__init__( _reader.Csv( f"{pattern}_response_*", @@ -92,6 +101,7 @@ def __init__(self, pattern): class RunPreTrialNoPoke(Stream): def __init__(self, pattern): + """Initialises the RunPreTrialNoPoke class.""" super().__init__( _reader.Csv( f"{pattern}_run_pre_no_poke_*", @@ -108,6 +118,7 @@ def __init__(self, pattern): class StartNewSession(Stream): def __init__(self, pattern): + """Initializes the StartNewSession class.""" super().__init__( _reader.Csv(f"{pattern}_startnewsession_*", columns=["typetag", "path"]) ) @@ -115,105 +126,129 @@ def __init__(self, pattern): class TaskLogic(StreamGroup): def __init__(self, path): + """Initialises the TaskLogic stream group.""" super().__init__(path) class TrialInitiation(Stream): def __init__(self, pattern): + """Initializes the TrialInitiation stream.""" super().__init__(_reader.Harp(f"{pattern}_1_*", columns=["trial_type"])) class Response(Stream): def __init__(self, pattern): + """Initializes the Response stream.""" super().__init__( _reader.Harp(f"{pattern}_2_*", columns=["wall_id", "poke_id"]) ) class PreTrialState(Stream): def __init__(self, pattern): + """Initializes the PreTrialState stream.""" super().__init__(_reader.Harp(f"{pattern}_3_*", columns=["state"])) class InterTrialInterval(Stream): def __init__(self, pattern): + """Initializes the InterTrialInterval stream.""" super().__init__(_reader.Harp(f"{pattern}_4_*", columns=["state"])) class SliceOnset(Stream): def __init__(self, pattern): + """Initializes the SliceOnset stream.""" super().__init__(_reader.Harp(f"{pattern}_10_*", columns=["wall_id"])) class DrawBackground(Stream): def __init__(self, pattern): + """Initializes the DrawBackground stream.""" super().__init__(_reader.Harp(f"{pattern}_11_*", columns=["state"])) class GratingsSliceOnset(Stream): def __init__(self, pattern): + """Initializes the GratingsSliceOnset stream.""" super().__init__(_reader.Harp(f"{pattern}_12_*", columns=["wall_id"])) class Wall(StreamGroup): def __init__(self, path): + """Initialises the Wall stream group.""" super().__init__(path) class BeamBreak0(Stream): def __init__(self, pattern): + """Initialises the BeamBreak0 stream.""" super().__init__( _reader.DigitalBitmask(f"{pattern}_32_*", 0x1, columns=["state"]) ) class BeamBreak1(Stream): def __init__(self, pattern): + """Initialises the BeamBreak1 stream.""" super().__init__( _reader.DigitalBitmask(f"{pattern}_32_*", 0x2, columns=["state"]) ) class BeamBreak2(Stream): def __init__(self, pattern): + """Initialises the BeamBreak2 stream.""" super().__init__( _reader.DigitalBitmask(f"{pattern}_32_*", 0x4, columns=["state"]) ) class SetLed0(Stream): def __init__(self, pattern): + """Initialises the SetLed0 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_34_*", 0x1, "Set")) class SetLed1(Stream): def __init__(self, pattern): + """Initialises the SetLed1 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_34_*", 0x2, "Set")) class SetLed2(Stream): def __init__(self, pattern): + """Initialises the SetLed2 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_34_*", 0x4, "Set")) class SetValve0(Stream): def __init__(self, pattern): + """Initialises the SetValve0 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_34_*", 0x8, "Set")) class SetValve1(Stream): def __init__(self, pattern): + """Initialises the SetValve1 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_34_*", 0x10, "Set")) class SetValve2(Stream): def __init__(self, pattern): + """Initialises the SetValve2 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_34_*", 0x20, "Set")) class ClearLed0(Stream): def __init__(self, pattern): + """Initialises the ClearLed0 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x1, "Clear")) class ClearLed1(Stream): def __init__(self, pattern): + """Initializes the ClearLed1 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x2, "Clear")) class ClearLed2(Stream): def __init__(self, pattern): + """Initializes the ClearLed2 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x4, "Clear")) class ClearValve0(Stream): def __init__(self, pattern): + """Initializes the ClearValve0 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x8, "Clear")) class ClearValve1(Stream): def __init__(self, pattern): + """Initializes the ClearValve1 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x10, "Clear")) class ClearValve2(Stream): def __init__(self, pattern): + """Initializes the ClearValve2 stream.""" super().__init__(_reader.BitmaskEvent(f"{pattern}_35_*", 0x20, "Clear")) diff --git a/aeon/schema/social_01.py b/aeon/schema/social_01.py index 719ef9a3..4edaec9f 100644 --- a/aeon/schema/social_01.py +++ b/aeon/schema/social_01.py @@ -6,6 +6,7 @@ class RfidEvents(Stream): def __init__(self, path): + """Initializes the RfidEvents stream.""" path = path.replace("Rfid", "") if path.startswith("Events"): path = path.replace("Events", "") @@ -15,4 +16,5 @@ def __init__(self, path): class Pose(Stream): def __init__(self, path): + """Initializes the Pose stream.""" super().__init__(_reader.Pose(f"{path}_node-0*")) diff --git a/aeon/schema/social_02.py b/aeon/schema/social_02.py index ce98f081..8a5183dd 100644 --- a/aeon/schema/social_02.py +++ b/aeon/schema/social_02.py @@ -7,12 +7,14 @@ class Environment(StreamGroup): def __init__(self, path): + """Initializes the Environment stream group.""" super().__init__(path) EnvironmentState = core.EnvironmentState class BlockState(Stream): def __init__(self, path): + """Initializes the BlockState stream.""" super().__init__( _reader.Csv( f"{path}_BlockState_*", @@ -22,6 +24,7 @@ def __init__(self, path): class LightEvents(Stream): def __init__(self, path): + """Initializes the LightEvents stream.""" super().__init__( _reader.Csv(f"{path}_LightEvents_*", columns=["channel", "value"]) ) @@ -31,22 +34,26 @@ def __init__(self, path): class SubjectData(StreamGroup): def __init__(self, path): + """Initializes the SubjectData stream group.""" super().__init__(path) class SubjectState(Stream): def __init__(self, path): + """Initializes the SubjectState stream.""" super().__init__( _reader.Csv(f"{path}_SubjectState_*", columns=["id", "weight", "type"]) ) class SubjectVisits(Stream): def __init__(self, path): + """Initializes the SubjectVisits stream.""" super().__init__( _reader.Csv(f"{path}_SubjectVisits_*", columns=["id", "type", "region"]) ) class SubjectWeight(Stream): def __init__(self, path): + """Initializes the SubjectWeight stream.""" super().__init__( _reader.Csv( f"{path}_SubjectWeight_*", @@ -57,25 +64,30 @@ def __init__(self, path): class Pose(Stream): def __init__(self, path): + """Initializes the Pose stream.""" super().__init__(_reader.Pose(f"{path}_test-node1*")) class WeightRaw(Stream): def __init__(self, path): + """Initialize the WeightRaw stream.""" super().__init__(_reader.Harp(f"{path}_200_*", ["weight(g)", "stability"])) class WeightFiltered(Stream): def __init__(self, path): + """Initializes the WeightFiltered stream.""" super().__init__(_reader.Harp(f"{path}_202_*", ["weight(g)", "stability"])) class Patch(StreamGroup): def __init__(self, path): + """Initializes the Patch stream group.""" super().__init__(path) class DepletionState(Stream): def __init__(self, path): + """Initializes the DepletionState stream.""" super().__init__( _reader.Csv(f"{path}_State_*", columns=["threshold", "offset", "rate"]) ) @@ -86,17 +98,21 @@ def __init__(self, path): class ManualDelivery(Stream): def __init__(self, path): + """Initializes the ManualDelivery stream.""" super().__init__(_reader.Harp(f"{path}_201_*", ["manual_delivery"])) class MissedPellet(Stream): def __init__(self, path): + """Initializes the MissedPellet stream.""" super().__init__(_reader.Harp(f"{path}_202_*", ["missed_pellet"])) class RetriedDelivery(Stream): def __init__(self, path): + """Initializes the RetriedDelivery stream.""" super().__init__(_reader.Harp(f"{path}_203_*", ["retried_delivery"])) class RfidEvents(Stream): def __init__(self, path): + """Initializes the RfidEvents stream.""" super().__init__(_reader.Harp(f"{path}_32*", ["rfid"])) diff --git a/aeon/schema/social_03.py b/aeon/schema/social_03.py index c624a06d..8ff87066 100644 --- a/aeon/schema/social_03.py +++ b/aeon/schema/social_03.py @@ -8,12 +8,14 @@ class Pose(Stream): def __init__(self, path): + """Initializes the Pose stream.""" super().__init__(_reader.Pose(f"{path}_202_*")) class EnvironmentActiveConfiguration(Stream): def __init__(self, path): + """Initializes the EnvironmentActiveConfiguration stream.""" super().__init__( _reader.JsonList(f"{path}_ActiveConfiguration_*", columns=["name"]) ) diff --git a/aeon/schema/streams.py b/aeon/schema/streams.py index cbbfc8b2..2c1cb94a 100644 --- a/aeon/schema/streams.py +++ b/aeon/schema/streams.py @@ -13,6 +13,7 @@ class Stream: """ def __init__(self, reader): + """Initializes the stream with a reader.""" self.reader = reader def __iter__(self): @@ -29,6 +30,7 @@ class StreamGroup: """ def __init__(self, path, *args): + """Initializes the stream group with a path and a list of data streams.""" self.path = path self._args = args self._nested = ( @@ -57,6 +59,7 @@ class Device: """ def __init__(self, name, *args, path=None): + """Initializes the device with a name and a list of data streams.""" if name is None: raise ValueError("name cannot be None.") diff --git a/pyproject.toml b/pyproject.toml index 5fd60607..9883959f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,7 +97,6 @@ lint.select = [ ] line-length = 108 lint.ignore = [ - "D107", # skip adding docstrings for __init__ "E201", "E202", "E203", From 7c596c3741fef8acc2344e2bf862a76d0ad65d21 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 11:48:45 +0000 Subject: [PATCH 005/143] fix: resolve E201, E202, and E203 errors by applying Black formatting --- aeon/analysis/block_plotting.py | 8 +- .../social_experiments_block_analysis.ipynb | 5 +- aeon/dj_pipeline/populate/process.py | 4 +- .../scripts/clone_and_freeze_exp01.py | 10 +- .../scripts/clone_and_freeze_exp02.py | 13 +- .../scripts/update_timestamps_longblob.py | 8 +- aeon/dj_pipeline/streams.py | 1053 +++++++++-------- pyproject.toml | 3 - tests/dj_pipeline/conftest.py | 14 +- 9 files changed, 599 insertions(+), 519 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index 87c548e2..ac06eaa2 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -29,10 +29,14 @@ def gen_hex_grad(hex_col, vals, min_l=0.3): """Generates an array of hex color values based on a gradient defined by unit-normalized values.""" # Convert hex to rgb to hls - h, l, s = rgb_to_hls(*[int(hex_col.lstrip("#")[i : i + 2], 16) / 255 for i in (0, 2, 4)]) # noqa: E741 + h, l, s = rgb_to_hls( + *[int(hex_col.lstrip("#")[i : i + 2], 16) / 255 for i in (0, 2, 4)] + ) # noqa: E741 grad = np.empty(shape=(len(vals),), dtype=" acquisition.Experiment -> Device @@ -66,16 +66,16 @@ class RfidReader(dj.Manual): rfid_reader_name : varchar(36) """ - class Attribute(dj.Part): - definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device + class Attribute(dj.Part): + definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device -> master attribute_name : varchar(32) --- attribute_value=null : longblob """ - class RemovalTime(dj.Part): - definition = f""" + class RemovalTime(dj.Part): + definition = f""" -> master --- rfid_reader_removal_time: datetime(6) # time of the rfid_reader being removed @@ -84,7 +84,7 @@ class RemovalTime(dj.Part): @schema class SpinnakerVideoSource(dj.Manual): - definition = f""" + definition = f""" # spinnaker_video_source placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-unknown) -> acquisition.Experiment -> Device @@ -93,16 +93,16 @@ class SpinnakerVideoSource(dj.Manual): spinnaker_video_source_name : varchar(36) """ - class Attribute(dj.Part): - definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device + class Attribute(dj.Part): + definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device -> master attribute_name : varchar(32) --- attribute_value=null : longblob """ - class RemovalTime(dj.Part): - definition = f""" + class RemovalTime(dj.Part): + definition = f""" -> master --- spinnaker_video_source_removal_time: datetime(6) # time of the spinnaker_video_source being removed @@ -111,7 +111,7 @@ class RemovalTime(dj.Part): @schema class UndergroundFeeder(dj.Manual): - definition = f""" + definition = f""" # underground_feeder placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-unknown) -> acquisition.Experiment -> Device @@ -120,16 +120,16 @@ class UndergroundFeeder(dj.Manual): underground_feeder_name : varchar(36) """ - class Attribute(dj.Part): - definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device + class Attribute(dj.Part): + definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device -> master attribute_name : varchar(32) --- attribute_value=null : longblob """ - class RemovalTime(dj.Part): - definition = f""" + class RemovalTime(dj.Part): + definition = f""" -> master --- underground_feeder_removal_time: datetime(6) # time of the underground_feeder being removed @@ -138,7 +138,7 @@ class RemovalTime(dj.Part): @schema class WeightScale(dj.Manual): - definition = f""" + definition = f""" # weight_scale placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-unknown) -> acquisition.Experiment -> Device @@ -147,16 +147,16 @@ class WeightScale(dj.Manual): weight_scale_name : varchar(36) """ - class Attribute(dj.Part): - definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device + class Attribute(dj.Part): + definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device -> master attribute_name : varchar(32) --- attribute_value=null : longblob """ - class RemovalTime(dj.Part): - definition = f""" + class RemovalTime(dj.Part): + definition = f""" -> master --- weight_scale_removal_time: datetime(6) # time of the weight_scale being removed @@ -165,7 +165,7 @@ class RemovalTime(dj.Part): @schema class RfidReaderRfidEvents(dj.Imported): - definition = """ # Raw per-chunk RfidEvents data stream from RfidReader (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk RfidEvents data stream from RfidReader (auto-generated with aeon_mecha-unknown) -> RfidReader -> acquisition.Chunk --- @@ -174,59 +174,62 @@ class RfidReaderRfidEvents(dj.Imported): rfid: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and RfidReader with overlapping time + Chunk(s) that started after RfidReader install time and ended before RfidReader remove time + Chunk(s) that started after RfidReader install time for RfidReader that are not yet removed """ - return ( - acquisition.Chunk * RfidReader.join(RfidReader.RemovalTime, left=True) - & 'chunk_start >= rfid_reader_install_time' - & 'chunk_start < IFNULL(rfid_reader_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (RfidReader & key).fetch1('rfid_reader_name') - - devices_schema = getattr( - aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "RfidEvents") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") - }, + return ( + acquisition.Chunk * RfidReader.join(RfidReader.RemovalTime, left=True) + & "chunk_start >= rfid_reader_install_time" + & 'chunk_start < IFNULL(rfid_reader_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (RfidReader & key).fetch1("rfid_reader_name") + + devices_schema = getattr( + aeon_schemas, + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "RfidEvents") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") }, - ignore_extra_fields=True, - ) + }, + ignore_extra_fields=True, + ) @schema class SpinnakerVideoSourceVideo(dj.Imported): - definition = """ # Raw per-chunk Video data stream from SpinnakerVideoSource (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk Video data stream from SpinnakerVideoSource (auto-generated with aeon_mecha-unknown) -> SpinnakerVideoSource -> acquisition.Chunk --- @@ -236,59 +239,63 @@ class SpinnakerVideoSourceVideo(dj.Imported): hw_timestamp: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and SpinnakerVideoSource with overlapping time + Chunk(s) that started after SpinnakerVideoSource install time and ended before SpinnakerVideoSource remove time + Chunk(s) that started after SpinnakerVideoSource install time for SpinnakerVideoSource that are not yet removed """ - return ( - acquisition.Chunk * SpinnakerVideoSource.join(SpinnakerVideoSource.RemovalTime, left=True) - & 'chunk_start >= spinnaker_video_source_install_time' - & 'chunk_start < IFNULL(spinnaker_video_source_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (SpinnakerVideoSource & key).fetch1('spinnaker_video_source_name') - - devices_schema = getattr( - aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "Video") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") - }, + return ( + acquisition.Chunk + * SpinnakerVideoSource.join(SpinnakerVideoSource.RemovalTime, left=True) + & "chunk_start >= spinnaker_video_source_install_time" + & 'chunk_start < IFNULL(spinnaker_video_source_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name") + + devices_schema = getattr( + aeon_schemas, + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "Video") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") }, - ignore_extra_fields=True, - ) + }, + ignore_extra_fields=True, + ) @schema class UndergroundFeederBeamBreak(dj.Imported): - definition = """ # Raw per-chunk BeamBreak data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk BeamBreak data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -297,59 +304,63 @@ class UndergroundFeederBeamBreak(dj.Imported): event: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & 'chunk_start >= underground_feeder_install_time' - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') - - devices_schema = getattr( - aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "BeamBreak") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") - }, + return ( + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & "chunk_start >= underground_feeder_install_time" + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") + + devices_schema = getattr( + aeon_schemas, + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "BeamBreak") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") }, - ignore_extra_fields=True, - ) + }, + ignore_extra_fields=True, + ) @schema class UndergroundFeederDeliverPellet(dj.Imported): - definition = """ # Raw per-chunk DeliverPellet data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk DeliverPellet data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -358,59 +369,63 @@ class UndergroundFeederDeliverPellet(dj.Imported): event: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & 'chunk_start >= underground_feeder_install_time' - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') - - devices_schema = getattr( - aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "DeliverPellet") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") - }, + return ( + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & "chunk_start >= underground_feeder_install_time" + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") + + devices_schema = getattr( + aeon_schemas, + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "DeliverPellet") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") }, - ignore_extra_fields=True, - ) + }, + ignore_extra_fields=True, + ) @schema class UndergroundFeederDepletionState(dj.Imported): - definition = """ # Raw per-chunk DepletionState data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk DepletionState data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -421,59 +436,63 @@ class UndergroundFeederDepletionState(dj.Imported): rate: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & 'chunk_start >= underground_feeder_install_time' - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') - - devices_schema = getattr( - aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "DepletionState") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") - }, + return ( + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & "chunk_start >= underground_feeder_install_time" + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") + + devices_schema = getattr( + aeon_schemas, + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "DepletionState") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") }, - ignore_extra_fields=True, - ) + }, + ignore_extra_fields=True, + ) @schema class UndergroundFeederEncoder(dj.Imported): - definition = """ # Raw per-chunk Encoder data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk Encoder data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -483,59 +502,63 @@ class UndergroundFeederEncoder(dj.Imported): intensity: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & 'chunk_start >= underground_feeder_install_time' - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') - - devices_schema = getattr( - aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "Encoder") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") - }, + return ( + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & "chunk_start >= underground_feeder_install_time" + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") + + devices_schema = getattr( + aeon_schemas, + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "Encoder") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") }, - ignore_extra_fields=True, - ) + }, + ignore_extra_fields=True, + ) @schema class UndergroundFeederManualDelivery(dj.Imported): - definition = """ # Raw per-chunk ManualDelivery data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk ManualDelivery data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -544,59 +567,63 @@ class UndergroundFeederManualDelivery(dj.Imported): manual_delivery: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & 'chunk_start >= underground_feeder_install_time' - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') - - devices_schema = getattr( - aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "ManualDelivery") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") - }, + return ( + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & "chunk_start >= underground_feeder_install_time" + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") + + devices_schema = getattr( + aeon_schemas, + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "ManualDelivery") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") }, - ignore_extra_fields=True, - ) + }, + ignore_extra_fields=True, + ) @schema class UndergroundFeederMissedPellet(dj.Imported): - definition = """ # Raw per-chunk MissedPellet data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk MissedPellet data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -605,59 +632,63 @@ class UndergroundFeederMissedPellet(dj.Imported): missed_pellet: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & 'chunk_start >= underground_feeder_install_time' - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') - - devices_schema = getattr( - aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "MissedPellet") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") - }, + return ( + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & "chunk_start >= underground_feeder_install_time" + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") + + devices_schema = getattr( + aeon_schemas, + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "MissedPellet") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") }, - ignore_extra_fields=True, - ) + }, + ignore_extra_fields=True, + ) @schema class UndergroundFeederRetriedDelivery(dj.Imported): - definition = """ # Raw per-chunk RetriedDelivery data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk RetriedDelivery data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -666,59 +697,63 @@ class UndergroundFeederRetriedDelivery(dj.Imported): retried_delivery: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & 'chunk_start >= underground_feeder_install_time' - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') - - devices_schema = getattr( - aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "RetriedDelivery") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") - }, + return ( + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & "chunk_start >= underground_feeder_install_time" + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") + + devices_schema = getattr( + aeon_schemas, + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "RetriedDelivery") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") }, - ignore_extra_fields=True, - ) + }, + ignore_extra_fields=True, + ) @schema class WeightScaleWeightFiltered(dj.Imported): - definition = """ # Raw per-chunk WeightFiltered data stream from WeightScale (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk WeightFiltered data stream from WeightScale (auto-generated with aeon_mecha-unknown) -> WeightScale -> acquisition.Chunk --- @@ -728,59 +763,62 @@ class WeightScaleWeightFiltered(dj.Imported): stability: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and WeightScale with overlapping time + Chunk(s) that started after WeightScale install time and ended before WeightScale remove time + Chunk(s) that started after WeightScale install time for WeightScale that are not yet removed """ - return ( - acquisition.Chunk * WeightScale.join(WeightScale.RemovalTime, left=True) - & 'chunk_start >= weight_scale_install_time' - & 'chunk_start < IFNULL(weight_scale_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (WeightScale & key).fetch1('weight_scale_name') - - devices_schema = getattr( - aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "WeightFiltered") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") - }, + return ( + acquisition.Chunk * WeightScale.join(WeightScale.RemovalTime, left=True) + & "chunk_start >= weight_scale_install_time" + & 'chunk_start < IFNULL(weight_scale_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (WeightScale & key).fetch1("weight_scale_name") + + devices_schema = getattr( + aeon_schemas, + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "WeightFiltered") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") }, - ignore_extra_fields=True, - ) + }, + ignore_extra_fields=True, + ) @schema class WeightScaleWeightRaw(dj.Imported): - definition = """ # Raw per-chunk WeightRaw data stream from WeightScale (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk WeightRaw data stream from WeightScale (auto-generated with aeon_mecha-unknown) -> WeightScale -> acquisition.Chunk --- @@ -790,51 +828,54 @@ class WeightScaleWeightRaw(dj.Imported): stability: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and WeightScale with overlapping time + Chunk(s) that started after WeightScale install time and ended before WeightScale remove time + Chunk(s) that started after WeightScale install time for WeightScale that are not yet removed """ - return ( - acquisition.Chunk * WeightScale.join(WeightScale.RemovalTime, left=True) - & 'chunk_start >= weight_scale_install_time' - & 'chunk_start < IFNULL(weight_scale_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (WeightScale & key).fetch1('weight_scale_name') - - devices_schema = getattr( - aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "WeightRaw") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") - }, + return ( + acquisition.Chunk * WeightScale.join(WeightScale.RemovalTime, left=True) + & "chunk_start >= weight_scale_install_time" + & 'chunk_start < IFNULL(weight_scale_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (WeightScale & key).fetch1("weight_scale_name") + + devices_schema = getattr( + aeon_schemas, + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "WeightRaw") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") }, - ignore_extra_fields=True, - ) + }, + ignore_extra_fields=True, + ) diff --git a/pyproject.toml b/pyproject.toml index 9883959f..152537d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,9 +97,6 @@ lint.select = [ ] line-length = 108 lint.ignore = [ - "E201", - "E202", - "E203", "E231", "E731", "E702", diff --git a/tests/dj_pipeline/conftest.py b/tests/dj_pipeline/conftest.py index 3604890f..f4ec3a39 100644 --- a/tests/dj_pipeline/conftest.py +++ b/tests/dj_pipeline/conftest.py @@ -53,11 +53,21 @@ def dj_config(): dj.config.load(dj_config_fp) dj.config["safemode"] = False assert "custom" in dj.config - dj.config["custom"]["database.prefix"] = f"u_{dj.config['database.user']}_testsuite_" + dj.config["custom"][ + "database.prefix" + ] = f"u_{dj.config['database.user']}_testsuite_" def load_pipeline(): - from aeon.dj_pipeline import acquisition, analysis, lab, qc, report, subject, tracking + from aeon.dj_pipeline import ( + acquisition, + analysis, + lab, + qc, + report, + subject, + tracking, + ) return { "subject": subject, From dabe579a6c234a9f33cc7757c4d478d1490baaf5 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 11:50:45 +0000 Subject: [PATCH 006/143] fix: resolve E231 error through previous commit too --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 152537d7..6680a8e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,7 +97,6 @@ lint.select = [ ] line-length = 108 lint.ignore = [ - "E231", "E731", "E702", "S101", From 7c40d82c4e3c649b4f2258300fbe60de1360bab5 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 11:51:55 +0000 Subject: [PATCH 007/143] fix: resolve E702 --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6680a8e5..bcd991ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,6 @@ lint.select = [ line-length = 108 lint.ignore = [ "E731", - "E702", "S101", "PT004", # Rule `PT004` is deprecated and will be removed in a future release. "PT013", From 6c9cb2ee6c8706f031798238b1c0d8883ed7960a Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 12:38:12 +0000 Subject: [PATCH 008/143] fix: resolve S101: substitute `assert` detected by `if...raise` and add `dj.logger` in `dj_pipeline` files --- aeon/dj_pipeline/acquisition.py | 3 +- aeon/dj_pipeline/lab.py | 1 + aeon/dj_pipeline/qc.py | 1 + aeon/dj_pipeline/report.py | 2 + aeon/dj_pipeline/streams.py | 1 + aeon/dj_pipeline/tracking.py | 4 +- tests/dj_pipeline/conftest.py | 9 ++- tests/dj_pipeline/test_acquisition.py | 48 ++++++++++------ .../test_pipeline_instantiation.py | 57 +++++++++++++++---- tests/dj_pipeline/test_qc.py | 11 +++- tests/dj_pipeline/test_tracking.py | 27 +++++---- tests/io/test_api.py | 52 ++++++++++++----- 12 files changed, 160 insertions(+), 56 deletions(-) diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py index c5307ecc..c4d9b606 100644 --- a/aeon/dj_pipeline/acquisition.py +++ b/aeon/dj_pipeline/acquisition.py @@ -147,7 +147,8 @@ def get_data_directory(cls, experiment_key, directory_type="raw", as_posix=False dir_path = pathlib.Path(dir_path) if dir_path.exists(): - assert dir_path.is_relative_to(paths.get_repository_path(repo_name)) + if not dir_path.is_relative_to(paths.get_repository_path(repo_name)): + raise ValueError(f"f{dir_path} is not relative to the repository path.") data_directory = dir_path else: data_directory = paths.get_repository_path(repo_name) / dir_path diff --git a/aeon/dj_pipeline/lab.py b/aeon/dj_pipeline/lab.py index 203f4a47..b0c6204b 100644 --- a/aeon/dj_pipeline/lab.py +++ b/aeon/dj_pipeline/lab.py @@ -5,6 +5,7 @@ from . import get_schema_name schema = dj.schema(get_schema_name("lab")) +logger = dj.logger # ------------------- GENERAL LAB INFORMATION -------------------- diff --git a/aeon/dj_pipeline/qc.py b/aeon/dj_pipeline/qc.py index cc52d48f..200285da 100644 --- a/aeon/dj_pipeline/qc.py +++ b/aeon/dj_pipeline/qc.py @@ -10,6 +10,7 @@ from aeon.dj_pipeline import acquisition, streams schema = dj.schema(get_schema_name("qc")) +logger = dj.logger # -------------- Quality Control --------------------- diff --git a/aeon/dj_pipeline/report.py b/aeon/dj_pipeline/report.py index 0a66dc75..ac929625 100644 --- a/aeon/dj_pipeline/report.py +++ b/aeon/dj_pipeline/report.py @@ -16,6 +16,8 @@ from . import acquisition, analysis, get_schema_name +logger = dj.logger + # schema = dj.schema(get_schema_name("report")) schema = dj.schema() os.environ["DJ_SUPPORT_FILEPATH_MANAGEMENT"] = "TRUE" diff --git a/aeon/dj_pipeline/streams.py b/aeon/dj_pipeline/streams.py index 8a639f5e..b692a5aa 100644 --- a/aeon/dj_pipeline/streams.py +++ b/aeon/dj_pipeline/streams.py @@ -12,6 +12,7 @@ from aeon.schema import schemas as aeon_schemas schema = dj.Schema(get_schema_name("streams")) +logger = dj.logger @schema diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 842d99f2..f53a6350 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -19,6 +19,7 @@ from aeon.schema import schemas as aeon_schemas schema = dj.schema(get_schema_name("tracking")) +logger = dj.logger pixel_scale = 0.00192 # 1 px = 1.92 mm arena_center_x, arena_center_y = 1.475, 1.075 # center @@ -250,7 +251,8 @@ def make(self, key): def compute_distance(position_df, target, xcol="x", ycol="y"): - assert len(target) == 2 + if len(target) != 2: + raise ValueError("Target must be a list of tuple of length 2.") return np.sqrt(np.square(position_df[[xcol, ycol]] - target).sum(axis=1)) diff --git a/tests/dj_pipeline/conftest.py b/tests/dj_pipeline/conftest.py index f4ec3a39..af0dfc56 100644 --- a/tests/dj_pipeline/conftest.py +++ b/tests/dj_pipeline/conftest.py @@ -18,6 +18,7 @@ _tear_down = True # always set to True since most fixtures are session-scoped _populate_settings = {"suppress_errors": True} +logger = dj.logger def data_dir(): @@ -49,10 +50,14 @@ def test_params(): def dj_config(): """Configures DataJoint connection and loads custom settings.""" dj_config_fp = pathlib.Path("dj_local_conf.json") - assert dj_config_fp.exists() + if not dj_config_fp.exists(): + raise FileNotFoundError( + f"DataJoint configuration file not found: {dj_config_fp}" + ) dj.config.load(dj_config_fp) dj.config["safemode"] = False - assert "custom" in dj.config + if "custom" not in dj.config: + raise KeyError("'custom' not found in DataJoint configuration.") dj.config["custom"][ "database.prefix" ] = f"u_{dj.config['database.user']}_testsuite_" diff --git a/tests/dj_pipeline/test_acquisition.py b/tests/dj_pipeline/test_acquisition.py index d54a491e..6432d06f 100644 --- a/tests/dj_pipeline/test_acquisition.py +++ b/tests/dj_pipeline/test_acquisition.py @@ -1,20 +1,28 @@ """ Tests for the acquisition pipeline. """ from pytest import mark +import datajoint as dj +logger = dj.logger @mark.ingestion def test_epoch_chunk_ingestion(test_params, pipeline, epoch_chunk_ingestion): acquisition = pipeline["acquisition"] - - assert ( - len(acquisition.Epoch & {"experiment_name": test_params["experiment_name"]}) - == test_params["epoch_count"] + epoch_count = len( + acquisition.Epoch & {"experiment_name": test_params["experiment_name"]} ) - assert ( - len(acquisition.Chunk & {"experiment_name": test_params["experiment_name"]}) - == test_params["chunk_count"] + chunk_count = len( + acquisition.Chunk & {"experiment_name": test_params["experiment_name"]} ) + if epoch_count != test_params["epoch_count"]: + raise AssertionError( + f"Expected {test_params['epoch_count']} epochs, but got {epoch_count}." + ) + + if chunk_count != test_params["chunk_count"]: + raise AssertionError( + f"Expected {test_params['chunk_count']} chunks, but got {chunk_count}." + ) @mark.ingestion @@ -23,24 +31,30 @@ def test_experimentlog_ingestion( ): acquisition = pipeline["acquisition"] - assert ( + experiment_log_message_count = ( len( acquisition.ExperimentLog.Message & {"experiment_name": test_params["experiment_name"]} ) - == test_params["experiment_log_message_count"] + if experiment_log_message_count != test_params["experiment_log_message_count"]: + raise AssertionError( + f"Expected {test_params['experiment_log_message_count']} experiment log messages, but got {experiment_log_message_count}." + ) ) - assert ( - len( + subject_enter_exit_count = len( acquisition.SubjectEnterExit.Time & {"experiment_name": test_params["experiment_name"]} ) - == test_params["subject_enter_exit_count"] - ) - assert ( - len( + if subject_enter_exit_count != test_params["subject_enter_exit_count"]: + raise AssertionError( + f"Expected {test_params['subject_enter_exit_count']} subject enter/exit events, but got {subject_enter_exit_count}." + ) + + subject_weight_time_count = len( acquisition.SubjectWeight.WeightTime & {"experiment_name": test_params["experiment_name"]} ) - == test_params["subject_weight_time_count"] - ) + if subject_weight_time_count != test_params["subject_weight_time_count"]: + raise AssertionError( + f"Expected {test_params['subject_weight_time_count']} subject weight events, but got {subject_weight_time_count}." + ) diff --git a/tests/dj_pipeline/test_pipeline_instantiation.py b/tests/dj_pipeline/test_pipeline_instantiation.py index d6bdc96f..88cc31b0 100644 --- a/tests/dj_pipeline/test_pipeline_instantiation.py +++ b/tests/dj_pipeline/test_pipeline_instantiation.py @@ -1,16 +1,37 @@ """ Tests for pipeline instantiation and experiment creation """ +import datajoint as dj + +logger = dj.logger + from pytest import mark @mark.instantiation def test_pipeline_instantiation(pipeline): - assert hasattr(pipeline["acquisition"], "FoodPatchEvent") - assert hasattr(pipeline["lab"], "Arena") - assert hasattr(pipeline["qc"], "CameraQC") - assert hasattr(pipeline["report"], "InArenaSummaryPlot") - assert hasattr(pipeline["subject"], "Subject") - assert hasattr(pipeline["tracking"], "CameraTracking") + if not hasattr(pipeline["acquisition"], "FoodPatchEvent"): + raise AssertionError( + "Pipeline acquisition does not have 'FoodPatchEvent' attribute." + ) + + if not hasattr(pipeline["lab"], "Arena"): + raise AssertionError("Pipeline lab does not have 'Arena' attribute.") + + if not hasattr(pipeline["qc"], "CameraQC"): + raise AssertionError("Pipeline qc does not have 'CameraQC' attribute.") + + if not hasattr(pipeline["report"], "InArenaSummaryPlot"): + raise AssertionError( + "Pipeline report does not have 'InArenaSummaryPlot' attribute." + ) + + if not hasattr(pipeline["subject"], "Subject"): + raise AssertionError("Pipeline subject does not have 'Subject' attribute.") + + if not hasattr(pipeline["tracking"], "CameraTracking"): + raise AssertionError( + "Pipeline tracking does not have 'CameraTracking' attribute." + ) @mark.instantiation @@ -18,14 +39,30 @@ def test_experiment_creation(test_params, pipeline, experiment_creation): acquisition = pipeline["acquisition"] experiment_name = test_params["experiment_name"] - assert acquisition.Experiment.fetch1("experiment_name") == experiment_name + fetched_experiment_name = acquisition.Experiment.fetch1("experiment_name") + if fetched_experiment_name != experiment_name: + raise AssertionError( + f"Expected experiment name '{experiment_name}', but got '{fetched_experiment_name}'." + ) + raw_dir = ( acquisition.Experiment.Directory & {"experiment_name": experiment_name, "directory_type": "raw"} ).fetch1("directory_path") - assert raw_dir == test_params["raw_dir"] + if raw_dir != test_params["raw_dir"]: + raise AssertionError( + f"Expected raw directory '{test_params['raw_dir']}', but got '{raw_dir}'." + ) + exp_subjects = ( acquisition.Experiment.Subject & {"experiment_name": experiment_name} ).fetch("subject") - assert len(exp_subjects) == test_params["subject_count"] - assert "BAA-1100701" in exp_subjects + if len(exp_subjects) != test_params["subject_count"]: + raise AssertionError( + f"Expected subject count {test_params['subject_count']}, but got {len(exp_subjects)}." + ) + + if "BAA-1100701" not in exp_subjects: + raise AssertionError( + "Expected subject 'BAA-1100701' not found in experiment subjects." + ) diff --git a/tests/dj_pipeline/test_qc.py b/tests/dj_pipeline/test_qc.py index c0ced19f..85c1772f 100644 --- a/tests/dj_pipeline/test_qc.py +++ b/tests/dj_pipeline/test_qc.py @@ -1,10 +1,19 @@ """ Tests for the QC pipeline. """ from pytest import mark +import datajoint as dj + +logger = dj.logger @mark.qc def test_camera_qc_ingestion(test_params, pipeline, camera_qc_ingestion): qc = pipeline["qc"] - assert len(qc.CameraQC()) == test_params["camera_qc_count"] + camera_qc_count = len(qc.CameraQC()) + expected_camera_qc_count = test_params["camera_qc_count"] + + if camera_qc_count != expected_camera_qc_count: + raise AssertionError( + f"Expected camera QC count {expected_camera_qc_count}, but got {camera_qc_count}." + ) diff --git a/tests/dj_pipeline/test_tracking.py b/tests/dj_pipeline/test_tracking.py index 42e1ede3..7a388cae 100644 --- a/tests/dj_pipeline/test_tracking.py +++ b/tests/dj_pipeline/test_tracking.py @@ -5,6 +5,10 @@ import numpy as np from pytest import mark +import datajoint as dj + +logger = dj.logger + index = 0 column_name = "position_x" # data column to run test on @@ -42,10 +46,11 @@ def save_test_data(pipeline, test_params): def test_camera_tracking_ingestion(test_params, pipeline, camera_tracking_ingestion): tracking = pipeline["tracking"] - assert ( - len(tracking.CameraTracking.Object()) - == test_params["camera_tracking_object_count"] - ) + camera_tracking_object_count = len(tracking.CameraTracking.Object()) + if camera_tracking_object_count != test_params["camera_tracking_object_count"]: + raise AssertionError( + f"Expected camera tracking object count {test_params['camera_tracking_object_count']},but got {camera_tracking_object_count}." + ) key = tracking.CameraTracking.Object().fetch("KEY")[index] file_name = ( @@ -63,13 +68,15 @@ def test_camera_tracking_ingestion(test_params, pipeline, camera_tracking_ingest ) test_file = pathlib.Path(test_params["test_dir"] + "/" + file_name) - assert test_file.exists() + if not test_file.exists(): + raise AssertionError(f"Test file '{test_file}' does not exist.") print(f"\nTesting {file_name}") data = np.load(test_file) - assert np.allclose( - data, - (tracking.CameraTracking.Object() & key).fetch(column_name)[0], - equal_nan=True, - ) + expected_data = (tracking.CameraTracking.Object() & key).fetch(column_name)[0] + + if not np.allclose(data, expected_data, equal_nan=True): + raise AssertionError( + f"Loaded data does not match the expected data.nExpected: {expected_data}, but got: {data}." + ) diff --git a/tests/io/test_api.py b/tests/io/test_api.py index 71018e72..5f70e5f6 100644 --- a/tests/io/test_api.py +++ b/tests/io/test_api.py @@ -21,7 +21,8 @@ def test_load_start_only(): start=pd.Timestamp("2022-06-06T13:00:49"), downsample=None, ) - assert len(data) > 0 + if len(data) <= 0: + raise AssertionError("Loaded data is empty. Expected non-empty data.") @mark.api @@ -32,7 +33,8 @@ def test_load_end_only(): end=pd.Timestamp("2022-06-06T13:00:49"), downsample=None, ) - assert len(data) > 0 + if len(data) <= 0: + raise AssertionError("Loaded data is empty. Expected non-empty data.") @mark.api @@ -40,20 +42,27 @@ def test_load_filter_nonchunked(): data = aeon.load( nonmonotonic_path, exp02.Metadata, start=pd.Timestamp("2022-06-06T09:00:00") ) - assert len(data) > 0 + if len(data) <= 0: + raise AssertionError("Loaded data is empty. Expected non-empty data.") @mark.api def test_load_monotonic(): data = aeon.load(monotonic_path, exp02.Patch2.Encoder, downsample=None) - assert len(data) > 0 - assert data.index.is_monotonic_increasing + if len(data) <= 0: + raise AssertionError("Loaded data is empty. Expected non-empty data.") + + if not data.index.is_monotonic_increasing: + raise AssertionError("Data index is not monotonic increasing.") @mark.api def test_load_nonmonotonic(): data = aeon.load(nonmonotonic_path, exp02.Patch2.Encoder, downsample=None) - assert not data.index.is_monotonic_increasing + if data.index.is_monotonic_increasing: + raise AssertionError( + "Data index is monotonic increasing, but it should not be." + ) @mark.api @@ -63,20 +72,35 @@ def test_load_encoder_with_downsampling(): raw_data = aeon.load(monotonic_path, exp02.Patch2.Encoder, downsample=None) # Check that the length of the downsampled data is less than the raw data - assert len(data) < len(raw_data) + if len(data) >= len(raw_data): + raise AssertionError( + "Downsampled data length should be less than raw data length." + ) # Check that the first timestamp of the downsampled data is within 20ms of the raw data - assert abs(data.index[0] - raw_data.index[0]).total_seconds() <= DOWNSAMPLE_PERIOD + if abs(data.index[0] - raw_data.index[0]).total_seconds() > DOWNSAMPLE_PERIOD: + raise AssertionError( + "The first timestamp of downsampled data is not within 20ms of raw data." + ) # Check that the last timestamp of the downsampled data is within 20ms of the raw data - assert abs(data.index[-1] - raw_data.index[-1]).total_seconds() <= DOWNSAMPLE_PERIOD - - # Check that the minimum difference between consecutive timestamps in the downsampled data - # is at least 20ms (50Hz) - assert data.index.to_series().diff().dt.total_seconds().min() >= DOWNSAMPLE_PERIOD + if abs(data.index[-1] - raw_data.index[-1]).total_seconds() > DOWNSAMPLE_PERIOD: + raise AssertionError( + f"The last timestamp of downsampled data is not within {DOWNSAMPLE_PERIOD*1000} ms of raw data." + ) + + # Check that the minimum difference between consecutive timestamps in the downsampled data is at least 20ms (50Hz) + min_diff = data.index.to_series().diff().dt.total_seconds().min() + if min_diff < DOWNSAMPLE_PERIOD: + raise AssertionError( + f"Minimum difference between consecutive timestamps is less than {DOWNSAMPLE_PERIOD} seconds." + ) # Check that the timestamps in the downsampled data are strictly increasing - assert data.index.is_monotonic_increasing + if not data.index.is_monotonic_increasing: + raise AssertionError( + "Timestamps in downsampled data are not strictly increasing." + ) if __name__ == "__main__": From 047d8ec9ecca83c6bbe68775a50f1ecfd875a073 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 12:47:58 +0000 Subject: [PATCH 009/143] fix: resolve additional S101: replaced `assert` with explicit `if...raise` in compliance with ruff guidelines --- aeon/dj_pipeline/analysis/block_analysis.py | 5 ++++- aeon/dj_pipeline/analysis/visit_analysis.py | 5 ++++- .../scripts/update_timestamps_longblob.py | 13 +++++++++++-- aeon/dj_pipeline/utils/load_metadata.py | 5 ++++- aeon/dj_pipeline/utils/video.py | 6 ++++-- 5 files changed, 27 insertions(+), 7 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 735ebfd3..ff710f48 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -1985,7 +1985,10 @@ def get_foraging_bouts( if bout_start_indxs[-1] >= len(wheel_ts): bout_start_indxs = bout_start_indxs[:-1] bout_end_indxs = bout_end_indxs[:-1] - assert len(bout_start_indxs) == len(bout_end_indxs) + if len(bout_start_indxs) != len(bout_end_indxs): + raise ValueError( + "Mismatch between the lengths of bout_start_indxs and bout_end_indxs." + ) bout_durations = ( wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs] ).astype( # in seconds diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 8a882f48..1eb192c7 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -201,7 +201,10 @@ def make(self, key): def get_position(cls, visit_key=None, subject=None, start=None, end=None): """Given a key to a single Visit, return a Pandas DataFrame for the position data of the subject for the specified Visit time period.""" if visit_key is not None: - assert len(Visit & visit_key) == 1 + if len(Visit & visit_key) != 1: + raise ValueError( + "The `visit_key` must correspond to exactly one Visit." + ) start, end = ( Visit.join(VisitEnd, left=True).proj( visit_end="IFNULL(visit_end, NOW())" diff --git a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py index a2d70c7e..ba8b7fa5 100644 --- a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py +++ b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py @@ -5,10 +5,16 @@ from datetime import datetime import datajoint as dj + +logger = dj.logger + import numpy as np from tqdm import tqdm -assert dj.__version__ >= "0.13.7" +if dj.__version__ < "0.13.7": + raise ImportError( + f"DataJoint version must be at least 0.13.7, but found {dj.__version__}." + ) schema = dj.schema("u_thinh_aeonfix") @@ -60,7 +66,10 @@ def main(): if not len(ts) or isinstance(ts[0], np.datetime64): TimestampFix.insert1(fix_key) continue - assert isinstance(ts[0], datetime) + if not isinstance(ts[0], datetime): + raise TypeError( + f"Expected ts[0] to be of type 'datetime', but got {type(ts[0])}." + ) with table.connection.transaction: table.update1( { diff --git a/aeon/dj_pipeline/utils/load_metadata.py b/aeon/dj_pipeline/utils/load_metadata.py index 8468a8bd..bd8a2f14 100644 --- a/aeon/dj_pipeline/utils/load_metadata.py +++ b/aeon/dj_pipeline/utils/load_metadata.py @@ -155,7 +155,10 @@ def extract_epoch_config( if isinstance(commit, float) and np.isnan(commit): commit = epoch_config["metadata"]["Revision"] - assert commit, f'Neither "Commit" nor "Revision" found in {metadata_yml_filepath}' + if not commit: + raise ValueError( + f'Neither "Commit" nor "Revision" found in {metadata_yml_filepath}' + ) devices: list[dict] = json.loads( json.dumps( diff --git a/aeon/dj_pipeline/utils/video.py b/aeon/dj_pipeline/utils/video.py index 16a2f27f..2e728171 100644 --- a/aeon/dj_pipeline/utils/video.py +++ b/aeon/dj_pipeline/utils/video.py @@ -24,8 +24,10 @@ def retrieve_video_frames( ): """Retrive video trames from the raw data directory.""" raw_data_dir = Path(raw_data_dir) - assert raw_data_dir.exists() - + if not raw_data_dir.exists(): + raise FileNotFoundError( + f"The specified raw data directory does not exist: {raw_data_dir}" + ) # Load video data videodata = io_api.load( root=raw_data_dir.as_posix(), From c14d7dea57e0a3242df2e5b70d8152f4244d7542 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 12:51:49 +0000 Subject: [PATCH 010/143] fix: eliminate extra parenthesis introduced in the last commit --- tests/dj_pipeline/test_acquisition.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/tests/dj_pipeline/test_acquisition.py b/tests/dj_pipeline/test_acquisition.py index 6432d06f..19939539 100644 --- a/tests/dj_pipeline/test_acquisition.py +++ b/tests/dj_pipeline/test_acquisition.py @@ -2,6 +2,7 @@ from pytest import mark import datajoint as dj + logger = dj.logger @@ -31,29 +32,27 @@ def test_experimentlog_ingestion( ): acquisition = pipeline["acquisition"] - experiment_log_message_count = ( - len( - acquisition.ExperimentLog.Message - & {"experiment_name": test_params["experiment_name"]} - ) + experiment_log_message_count = len( + acquisition.ExperimentLog.Message + & {"experiment_name": test_params["experiment_name"]} + ) if experiment_log_message_count != test_params["experiment_log_message_count"]: raise AssertionError( f"Expected {test_params['experiment_log_message_count']} experiment log messages, but got {experiment_log_message_count}." ) - ) subject_enter_exit_count = len( - acquisition.SubjectEnterExit.Time - & {"experiment_name": test_params["experiment_name"]} - ) + acquisition.SubjectEnterExit.Time + & {"experiment_name": test_params["experiment_name"]} + ) if subject_enter_exit_count != test_params["subject_enter_exit_count"]: raise AssertionError( f"Expected {test_params['subject_enter_exit_count']} subject enter/exit events, but got {subject_enter_exit_count}." ) - + subject_weight_time_count = len( - acquisition.SubjectWeight.WeightTime - & {"experiment_name": test_params["experiment_name"]} - ) + acquisition.SubjectWeight.WeightTime + & {"experiment_name": test_params["experiment_name"]} + ) if subject_weight_time_count != test_params["subject_weight_time_count"]: raise AssertionError( f"Expected {test_params['subject_weight_time_count']} subject weight events, but got {subject_weight_time_count}." From e5daa6d7b9a809f271721e54746134e36af40b9e Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 12:52:45 +0000 Subject: [PATCH 011/143] feat: addressed all S101 checks - delete from pyproject.toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bcd991ed..64514705 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,7 +98,6 @@ lint.select = [ line-length = 108 lint.ignore = [ "E731", - "S101", "PT004", # Rule `PT004` is deprecated and will be removed in a future release. "PT013", "PLR0912", From b50ca4265c515607c40ea23ee69760de15a2611f Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 13:08:01 +0000 Subject: [PATCH 012/143] fix: resolve PT013 errors: Incorrect import of `pytest` --- tests/dj_pipeline/test_acquisition.py | 6 +++--- tests/dj_pipeline/test_pipeline_instantiation.py | 6 +++--- tests/dj_pipeline/test_qc.py | 4 ++-- tests/dj_pipeline/test_tracking.py | 6 +++--- tests/io/test_api.py | 14 +++++++------- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/dj_pipeline/test_acquisition.py b/tests/dj_pipeline/test_acquisition.py index 19939539..02a60e54 100644 --- a/tests/dj_pipeline/test_acquisition.py +++ b/tests/dj_pipeline/test_acquisition.py @@ -1,12 +1,12 @@ """ Tests for the acquisition pipeline. """ -from pytest import mark +import pytest import datajoint as dj logger = dj.logger -@mark.ingestion +@pytest.mark.ingestion def test_epoch_chunk_ingestion(test_params, pipeline, epoch_chunk_ingestion): acquisition = pipeline["acquisition"] epoch_count = len( @@ -26,7 +26,7 @@ def test_epoch_chunk_ingestion(test_params, pipeline, epoch_chunk_ingestion): ) -@mark.ingestion +@pytest.mark.ingestion def test_experimentlog_ingestion( test_params, pipeline, epoch_chunk_ingestion, experimentlog_ingestion ): diff --git a/tests/dj_pipeline/test_pipeline_instantiation.py b/tests/dj_pipeline/test_pipeline_instantiation.py index 88cc31b0..c7321b09 100644 --- a/tests/dj_pipeline/test_pipeline_instantiation.py +++ b/tests/dj_pipeline/test_pipeline_instantiation.py @@ -4,10 +4,10 @@ logger = dj.logger -from pytest import mark +import pytest -@mark.instantiation +@pytest.mark.instantiation def test_pipeline_instantiation(pipeline): if not hasattr(pipeline["acquisition"], "FoodPatchEvent"): raise AssertionError( @@ -34,7 +34,7 @@ def test_pipeline_instantiation(pipeline): ) -@mark.instantiation +@pytest.mark.instantiation def test_experiment_creation(test_params, pipeline, experiment_creation): acquisition = pipeline["acquisition"] diff --git a/tests/dj_pipeline/test_qc.py b/tests/dj_pipeline/test_qc.py index 85c1772f..9ea93334 100644 --- a/tests/dj_pipeline/test_qc.py +++ b/tests/dj_pipeline/test_qc.py @@ -1,12 +1,12 @@ """ Tests for the QC pipeline. """ -from pytest import mark +import pytest import datajoint as dj logger = dj.logger -@mark.qc +@pytest.mark.qc def test_camera_qc_ingestion(test_params, pipeline, camera_qc_ingestion): qc = pipeline["qc"] diff --git a/tests/dj_pipeline/test_tracking.py b/tests/dj_pipeline/test_tracking.py index 7a388cae..f8c7bfe7 100644 --- a/tests/dj_pipeline/test_tracking.py +++ b/tests/dj_pipeline/test_tracking.py @@ -4,7 +4,7 @@ import pathlib import numpy as np -from pytest import mark +import pytest import datajoint as dj logger = dj.logger @@ -41,8 +41,8 @@ def save_test_data(pipeline, test_params): return test_file -@mark.ingestion -@mark.tracking +@pytest.mark.ingestion +@pytest.mark.tracking def test_camera_tracking_ingestion(test_params, pipeline, camera_tracking_ingestion): tracking = pipeline["tracking"] diff --git a/tests/io/test_api.py b/tests/io/test_api.py index 5f70e5f6..cbe5223b 100644 --- a/tests/io/test_api.py +++ b/tests/io/test_api.py @@ -4,7 +4,7 @@ import pandas as pd import pytest -from pytest import mark +import pytest import aeon from aeon.schema.schemas import exp02 @@ -13,7 +13,7 @@ monotonic_path = Path(__file__).parent.parent / "data" / "monotonic" -@mark.api +@pytest.mark.api def test_load_start_only(): data = aeon.load( nonmonotonic_path, @@ -25,7 +25,7 @@ def test_load_start_only(): raise AssertionError("Loaded data is empty. Expected non-empty data.") -@mark.api +@pytest.mark.api def test_load_end_only(): data = aeon.load( nonmonotonic_path, @@ -37,7 +37,7 @@ def test_load_end_only(): raise AssertionError("Loaded data is empty. Expected non-empty data.") -@mark.api +@pytest.mark.api def test_load_filter_nonchunked(): data = aeon.load( nonmonotonic_path, exp02.Metadata, start=pd.Timestamp("2022-06-06T09:00:00") @@ -46,7 +46,7 @@ def test_load_filter_nonchunked(): raise AssertionError("Loaded data is empty. Expected non-empty data.") -@mark.api +@pytest.mark.api def test_load_monotonic(): data = aeon.load(monotonic_path, exp02.Patch2.Encoder, downsample=None) if len(data) <= 0: @@ -56,7 +56,7 @@ def test_load_monotonic(): raise AssertionError("Data index is not monotonic increasing.") -@mark.api +@pytest.mark.api def test_load_nonmonotonic(): data = aeon.load(nonmonotonic_path, exp02.Patch2.Encoder, downsample=None) if data.index.is_monotonic_increasing: @@ -65,7 +65,7 @@ def test_load_nonmonotonic(): ) -@mark.api +@pytest.mark.api def test_load_encoder_with_downsampling(): DOWNSAMPLE_PERIOD = 0.02 data = aeon.load(monotonic_path, exp02.Patch2.Encoder, downsample=True) From d5e3fe65e2ed576b6bdb087a708ffefe9ba3ec2b Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 13:08:58 +0000 Subject: [PATCH 013/143] fix: remove PT013 from pyproject.toml and add docstrings for `dj_config()` --- pyproject.toml | 1 - tests/dj_pipeline/conftest.py | 9 ++++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 64514705..e11166a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,6 @@ line-length = 108 lint.ignore = [ "E731", "PT004", # Rule `PT004` is deprecated and will be removed in a future release. - "PT013", "PLR0912", "PLR0913", "PLR0915", diff --git a/tests/dj_pipeline/conftest.py b/tests/dj_pipeline/conftest.py index af0dfc56..1faf316e 100644 --- a/tests/dj_pipeline/conftest.py +++ b/tests/dj_pipeline/conftest.py @@ -48,7 +48,14 @@ def test_params(): @pytest.fixture(autouse=True, scope="session") def dj_config(): - """Configures DataJoint connection and loads custom settings.""" + """ + Configures DataJoint connection and loads custom settings. + + This fixture sets up the DataJoint configuration using the + 'dj_local_conf.json' file. It raises FileNotFoundError if the file + does not exist, and KeyError if 'custom' is not found in the + DataJoint configuration. + """ dj_config_fp = pathlib.Path("dj_local_conf.json") if not dj_config_fp.exists(): raise FileNotFoundError( From a9cacf263c73219dc614229b5d7e27d094a77980 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 13:22:14 +0000 Subject: [PATCH 014/143] fix: resolve E501 errors --- pyproject.toml | 9 +-------- tests/dj_pipeline/test_acquisition.py | 14 +++++++++----- tests/dj_pipeline/test_tracking.py | 6 ++++-- tests/io/test_api.py | 3 ++- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e11166a2..23117142 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -99,7 +99,7 @@ line-length = 108 lint.ignore = [ "E731", "PT004", # Rule `PT004` is deprecated and will be removed in a future release. - "PLR0912", + "PLR0912", "PLR0913", "PLR0915", ] @@ -120,13 +120,6 @@ extend-exclude = [ "D106", # skip adding docstrings for nested streams ] "aeon/dj_pipeline/*" = [ - "B006", - "B021", - "D101", # skip adding docstrings for table class since it is added inside definition - "D102", # skip adding docstrings for make function - "D103", # skip adding docstrings for public functions - "D106", # skip adding docstrings for Part tables - "E501", "F401", # ignore unused import errors "B905", # ignore unused import errors "E999", diff --git a/tests/dj_pipeline/test_acquisition.py b/tests/dj_pipeline/test_acquisition.py index 02a60e54..ab6055e5 100644 --- a/tests/dj_pipeline/test_acquisition.py +++ b/tests/dj_pipeline/test_acquisition.py @@ -32,21 +32,24 @@ def test_experimentlog_ingestion( ): acquisition = pipeline["acquisition"] - experiment_log_message_count = len( + exp_log_message_count = len( acquisition.ExperimentLog.Message & {"experiment_name": test_params["experiment_name"]} ) - if experiment_log_message_count != test_params["experiment_log_message_count"]: + if exp_log_message_count != test_params["experiment_log_message_count"]: raise AssertionError( - f"Expected {test_params['experiment_log_message_count']} experiment log messages, but got {experiment_log_message_count}." + f"Expected {test_params['experiment_log_message_count']} log messages," + f"but got {exp_log_message_count}." ) + subject_enter_exit_count = len( acquisition.SubjectEnterExit.Time & {"experiment_name": test_params["experiment_name"]} ) if subject_enter_exit_count != test_params["subject_enter_exit_count"]: raise AssertionError( - f"Expected {test_params['subject_enter_exit_count']} subject enter/exit events, but got {subject_enter_exit_count}." + f"Expected {test_params['subject_enter_exit_count']} subject enter/exit events," + f"but got {subject_enter_exit_count}." ) subject_weight_time_count = len( @@ -55,5 +58,6 @@ def test_experimentlog_ingestion( ) if subject_weight_time_count != test_params["subject_weight_time_count"]: raise AssertionError( - f"Expected {test_params['subject_weight_time_count']} subject weight events, but got {subject_weight_time_count}." + f"Expected {test_params['subject_weight_time_count']} subject weight events," + f"but got {subject_weight_time_count}." ) diff --git a/tests/dj_pipeline/test_tracking.py b/tests/dj_pipeline/test_tracking.py index f8c7bfe7..d36888cf 100644 --- a/tests/dj_pipeline/test_tracking.py +++ b/tests/dj_pipeline/test_tracking.py @@ -12,7 +12,8 @@ index = 0 column_name = "position_x" # data column to run test on -file_name = "exp0.2-r0-20220524090000-21053810-20220524082942-0-0.npy" # test file to be saved with save_test_data +file_name = "exp0.2-r0-20220524090000-21053810-20220524082942-0-0.npy" +# test file to be saved with save_test_data def save_test_data(pipeline, test_params): @@ -49,7 +50,8 @@ def test_camera_tracking_ingestion(test_params, pipeline, camera_tracking_ingest camera_tracking_object_count = len(tracking.CameraTracking.Object()) if camera_tracking_object_count != test_params["camera_tracking_object_count"]: raise AssertionError( - f"Expected camera tracking object count {test_params['camera_tracking_object_count']},but got {camera_tracking_object_count}." + f"Expected camera tracking object count {test_params['camera_tracking_object_count']}," + f"but got {camera_tracking_object_count}." ) key = tracking.CameraTracking.Object().fetch("KEY")[index] diff --git a/tests/io/test_api.py b/tests/io/test_api.py index cbe5223b..63739a6d 100644 --- a/tests/io/test_api.py +++ b/tests/io/test_api.py @@ -89,7 +89,8 @@ def test_load_encoder_with_downsampling(): f"The last timestamp of downsampled data is not within {DOWNSAMPLE_PERIOD*1000} ms of raw data." ) - # Check that the minimum difference between consecutive timestamps in the downsampled data is at least 20ms (50Hz) + # Check that the minimum difference between consecutive timestamps in the downsampled data + # is at least 20ms (50Hz) min_diff = data.index.to_series().diff().dt.total_seconds().min() if min_diff < DOWNSAMPLE_PERIOD: raise AssertionError( From bcb4c8200575bc758c5cbf947cc1eded3131053c Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 13:23:24 +0000 Subject: [PATCH 015/143] fix: resolve E401 error --- aeon/schema/social_03.py | 2 -- pyproject.toml | 1 - 2 files changed, 3 deletions(-) diff --git a/aeon/schema/social_03.py b/aeon/schema/social_03.py index 8ff87066..99fe0d17 100644 --- a/aeon/schema/social_03.py +++ b/aeon/schema/social_03.py @@ -1,7 +1,5 @@ """ This module contains the schema for the social_03 dataset. """ -import json -import pandas as pd import aeon.io.reader as _reader from aeon.schema.streams import Stream diff --git a/pyproject.toml b/pyproject.toml index 23117142..e640c538 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,6 @@ extend-exclude = [ "D106", # skip adding docstrings for nested streams ] "aeon/dj_pipeline/*" = [ - "F401", # ignore unused import errors "B905", # ignore unused import errors "E999", "S324", From e805caf49940370c0fc761b1e5bc6d2395a932ad Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 13:27:40 +0000 Subject: [PATCH 016/143] fix: resolve B905 by setting default value of strict to False --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e640c538..f6bba02f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,6 @@ extend-exclude = [ "D106", # skip adding docstrings for nested streams ] "aeon/dj_pipeline/*" = [ - "B905", # ignore unused import errors "E999", "S324", "E722", From 35e0096fb656ef6018d79db4699fe88a90d30641 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 13:28:11 +0000 Subject: [PATCH 017/143] fix: resolve B905 by setting default value of strict to False --- aeon/analysis/block_plotting.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index ac06eaa2..a9d22af9 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -59,7 +59,7 @@ def conv2d(arr, kernel): def gen_subject_colors_dict(subject_names): """Generates a dictionary of subject colors based on a list of subjects.""" - return {s: c for s, c in zip(subject_names, subject_colors)} + return {s: c for s, c in zip(subject_names, subject_colors, strict=False)} def gen_patch_style_dict(patch_names): @@ -70,8 +70,12 @@ def gen_patch_style_dict(patch_names): - patch_linestyles_dict: patch name to linestyle """ return { - "colors": {p: c for p, c in zip(patch_names, patch_colors)}, - "markers": {p: m for p, m in zip(patch_names, patch_markers)}, - "symbols": {p: s for p, s in zip(patch_names, patch_markers_symbols)}, - "linestyles": {p: ls for p, ls in zip(patch_names, patch_markers_linestyles)}, + "colors": {p: c for p, c in zip(patch_names, patch_colors, strict=False)}, + "markers": {p: m for p, m in zip(patch_names, patch_markers, strict=False)}, + "symbols": { + p: s for p, s in zip(patch_names, patch_markers_symbols, strict=False) + }, + "linestyles": { + p: ls for p, ls in zip(patch_names, patch_markers_linestyles, strict=False) + }, } From dff8a87ab7c44a4c6097a43739e87d64fff2ae50 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 13:38:00 +0000 Subject: [PATCH 018/143] fix: resolve D205 error --- aeon/analysis/block_plotting.py | 7 ++++++- aeon/io/reader.py | 4 ++-- pyproject.toml | 10 ---------- 3 files changed, 8 insertions(+), 13 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index a9d22af9..bed0713b 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -63,12 +63,17 @@ def gen_subject_colors_dict(subject_names): def gen_patch_style_dict(patch_names): - """Based on a list of patches, generates a dictionary of: + """ + + Based on a list of patches, generates a dictionary of: + - patch_colors_dict: patch name to color - patch_markers_dict: patch name to marker - patch_symbols_dict: patch name to symbol - patch_linestyles_dict: patch name to linestyle + """ + return { "colors": {p: c for p, c in zip(patch_names, patch_colors, strict=False)}, "markers": {p: m for p, m in zip(patch_names, patch_markers, strict=False)}, diff --git a/aeon/io/reader.py b/aeon/io/reader.py index f631ab43..09a7a9e0 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -159,8 +159,8 @@ def read(self, file): class JsonList(Reader): - """Extracts data from json list (.jsonl) files, where the key "seconds" - stores the Aeon timestamp, in seconds. + """ + Extracts data from json list (.jsonl) files, where the key "seconds" stores the Aeon timestamp, in seconds. """ def __init__(self, pattern, columns=(), root_key="value", extension="jsonl"): diff --git a/pyproject.toml b/pyproject.toml index f6bba02f..a60bc692 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,16 +120,6 @@ extend-exclude = [ "D106", # skip adding docstrings for nested streams ] "aeon/dj_pipeline/*" = [ - "E999", - "S324", - "E722", - "S110", - "F821", - "B904", - "UP038", - "S607", - "S605", - "D205", "D202", "F403", "PLR2004", From 0b86908c7340bdfbdf5d12f8583bb2b52ca3afd9 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 13:39:14 +0000 Subject: [PATCH 019/143] fix: resolve D202 error --- aeon/analysis/block_plotting.py | 1 - pyproject.toml | 1 - 2 files changed, 2 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index bed0713b..ac002814 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -73,7 +73,6 @@ def gen_patch_style_dict(patch_names): - patch_linestyles_dict: patch name to linestyle """ - return { "colors": {p: c for p, c in zip(patch_names, patch_colors, strict=False)}, "markers": {p: m for p, m in zip(patch_names, patch_markers, strict=False)}, diff --git a/pyproject.toml b/pyproject.toml index a60bc692..0e2fc5d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -120,7 +120,6 @@ extend-exclude = [ "D106", # skip adding docstrings for nested streams ] "aeon/dj_pipeline/*" = [ - "D202", "F403", "PLR2004", "SIM108", From de642c848e77fa5387b3c0df152804ccbd30c9ad Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 13:41:30 +0000 Subject: [PATCH 020/143] fix: resolve I001 error --- aeon/analysis/plotting.py | 1 + pyproject.toml | 9 +-------- tests/dj_pipeline/test_acquisition.py | 2 +- tests/dj_pipeline/test_qc.py | 2 +- tests/dj_pipeline/test_tracking.py | 2 +- tests/io/test_api.py | 1 - 6 files changed, 5 insertions(+), 12 deletions(-) diff --git a/aeon/analysis/plotting.py b/aeon/analysis/plotting.py index ed82a519..fb9a1ff4 100644 --- a/aeon/analysis/plotting.py +++ b/aeon/analysis/plotting.py @@ -1,6 +1,7 @@ """Helper functions for plotting data.""" import math + import matplotlib.pyplot as plt import numpy as np import pandas as pd diff --git a/pyproject.toml b/pyproject.toml index 0e2fc5d1..6ede470b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,14 +119,7 @@ extend-exclude = [ "D101", # skip adding docstrings for schema classes "D106", # skip adding docstrings for nested streams ] -"aeon/dj_pipeline/*" = [ - "F403", - "PLR2004", - "SIM108", - "PLW0127", - "PLR2004", - "I001", -] +"aeon/dj_pipeline/*" = [] [tool.ruff.lint.pydocstyle] convention = "google" diff --git a/tests/dj_pipeline/test_acquisition.py b/tests/dj_pipeline/test_acquisition.py index ab6055e5..37dc7f47 100644 --- a/tests/dj_pipeline/test_acquisition.py +++ b/tests/dj_pipeline/test_acquisition.py @@ -1,7 +1,7 @@ """ Tests for the acquisition pipeline. """ -import pytest import datajoint as dj +import pytest logger = dj.logger diff --git a/tests/dj_pipeline/test_qc.py b/tests/dj_pipeline/test_qc.py index 9ea93334..c2750c99 100644 --- a/tests/dj_pipeline/test_qc.py +++ b/tests/dj_pipeline/test_qc.py @@ -1,7 +1,7 @@ """ Tests for the QC pipeline. """ -import pytest import datajoint as dj +import pytest logger = dj.logger diff --git a/tests/dj_pipeline/test_tracking.py b/tests/dj_pipeline/test_tracking.py index d36888cf..5692c8cf 100644 --- a/tests/dj_pipeline/test_tracking.py +++ b/tests/dj_pipeline/test_tracking.py @@ -3,9 +3,9 @@ import datetime import pathlib +import datajoint as dj import numpy as np import pytest -import datajoint as dj logger = dj.logger diff --git a/tests/io/test_api.py b/tests/io/test_api.py index 63739a6d..14e7d10f 100644 --- a/tests/io/test_api.py +++ b/tests/io/test_api.py @@ -4,7 +4,6 @@ import pandas as pd import pytest -import pytest import aeon from aeon.schema.schemas import exp02 From a7f72ee0d1167f8b75a4d7bbd9213c486fbbbf3c Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 14:30:35 +0000 Subject: [PATCH 021/143] fix: resolve B006 in `plotting.py` by fixing mutable default argument --- aeon/dj_pipeline/utils/plotting.py | 10 ++++++---- pyproject.toml | 5 ++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index 0e3b29cb..564dcca2 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -551,18 +551,20 @@ def plot_visit_time_distribution(visit_key, freq="D"): return fig -def _get_region_data( - visit_key, attrs=["in_nest", "in_arena", "in_corridor", "in_patch"] -): +def _get_region_data(visit_key, attrs=None): """Retrieve region data from VisitTimeDistribution tables. Args: visit_key (dict): Key from the Visit table - attrs (list, optional): List of column names (in VisitTimeDistribution tables) to retrieve. Defaults to all. + attrs (list, optional): List of column names (in VisitTimeDistribution tables) to retrieve. + Defaults is None, which will create a new list with the desired default values inside the function. Returns: region (pd.DataFrame): Timestamped region info """ + if attrs is None: + attrs = ["in_nest", "in_arena", "in_corridor", "in_patch"] + visit_start, visit_end = (VisitEnd & visit_key).fetch1("visit_start", "visit_end") region = pd.DataFrame() diff --git a/pyproject.toml b/pyproject.toml index 6ede470b..fd85d178 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -119,7 +119,10 @@ extend-exclude = [ "D101", # skip adding docstrings for schema classes "D106", # skip adding docstrings for nested streams ] -"aeon/dj_pipeline/*" = [] +"aeon/dj_pipeline/*" = [ + "D101", # skip adding docstrings for schema classes + "D106", # skip adding docstrings for nested streams +] [tool.ruff.lint.pydocstyle] convention = "google" From 0b631dc15a12634d4fd4e0e6c9a07d13e6b536d7 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 14:39:38 +0000 Subject: [PATCH 022/143] fix: resolve B021 error removed dynamic references and kept the docstring static while ensuring it remains informative --- aeon/dj_pipeline/utils/streams_maker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index 2d85c500..734ad88d 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -141,10 +141,10 @@ class DeviceDataStream(dj.Imported): @property def key_source(self): - f""" - Only the combination of Chunk and {device_type} with overlapping time - + Chunk(s) that started after {device_type} install time and ended before {device_type} remove time - + Chunk(s) that started after {device_type} install time for {device_type} that are not yet removed + """ + Only the combination of Chunk and device_type with overlapping time + + Chunk(s) that started after device_type install time and ended before device_type remove time + + Chunk(s) that started after device_type install time for device_type that are not yet removed """ return ( acquisition.Chunk From 7dd455a81ada636f245fe3d64939d61f2fabdd4a Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 16:07:37 +0000 Subject: [PATCH 023/143] fix: update deprecation of `datetime.utcnow()` --- aeon/dj_pipeline/analysis/block_analysis.py | 4 ++-- aeon/dj_pipeline/subject.py | 20 +++++++++++--------- 2 files changed, 13 insertions(+), 11 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index ff710f48..144ff37b 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -3,7 +3,7 @@ import itertools import json from collections import defaultdict -from datetime import datetime +from datetime import datetime, timezone import datajoint as dj import numpy as np @@ -264,7 +264,7 @@ def make(self, key): # multiple patch rates per block is unexpected, log a note and pick the first rate to move forward AnalysisNote.insert1( { - "note_timestamp": datetime.utcnow(), + "note_timestamp": datetime.now(timezone.utc), "note_type": "Multiple patch rates", "note": f"Found multiple patch rates for block {key} - patch: {patch_name} - rates: {depletion_state_df.rate.unique()}", } diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index 4b79ce72..fceb3042 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -3,7 +3,7 @@ import json import os import time -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone import datajoint as dj import requests @@ -193,7 +193,7 @@ def get_reference_weight(cls, subject_name): "procedure_date", order_by="procedure_date DESC", limit=1 )[0] else: - ref_date = datetime.now().date() + ref_date = datetime.now(timezone.utc).date() weight_query = SubjectWeight & subj_key & f"weight_time < '{ref_date}'" ref_weight = ( @@ -205,7 +205,7 @@ def get_reference_weight(cls, subject_name): entry = { "subject": subject_name, "reference_weight": ref_weight, - "last_updated_time": datetime.utcnow(), + "last_updated_time": datetime.now(timezone.utc), } cls.update1(entry) if cls & {"subject": subject_name} else cls.insert1(entry) @@ -247,7 +247,7 @@ class PyratIngestion(dj.Imported): schedule_interval = 12 # schedule interval in number of hours def _auto_schedule(self): - utc_now = datetime.utcnow() + utc_now = datetime.now(timezone.utc) next_task_schedule_time = utc_now + timedelta(hours=self.schedule_interval) if ( @@ -261,8 +261,8 @@ def _auto_schedule(self): ) def make(self, key): - execution_time = datetime.utcnow() """Automatically import or update entries in the Subject table.""" + execution_time = datetime.now(timezone.utc) new_eartags = [] for responsible_id in lab.User.fetch("responsible_id"): # 1 - retrieve all animals from this user @@ -298,7 +298,7 @@ def make(self, key): new_entry_count += 1 logger.info(f"Inserting {new_entry_count} new subject(s) from Pyrat") - completion_time = datetime.utcnow() + completion_time = datetime.now(timezone.utc) self.insert1( { **key, @@ -330,7 +330,7 @@ class PyratCommentWeightProcedure(dj.Imported): key_source = (PyratIngestion * SubjectDetail) & "available = 1" def make(self, key): - execution_time = datetime.utcnow() + execution_time = datetime.now(timezone.utc) logger.info("Extracting weights/comments/procedures") eartag_or_id = key["subject"] @@ -373,7 +373,7 @@ def make(self, key): # compute/update reference weight SubjectReferenceWeight.get_reference_weight(eartag_or_id) finally: - completion_time = datetime.utcnow() + completion_time = datetime.now(timezone.utc) self.insert1( { **key, @@ -393,7 +393,9 @@ class CreatePyratIngestionTask(dj.Computed): def make(self, key): """Create one new PyratIngestionTask for every newly added users.""" - PyratIngestionTask.insert1({"pyrat_task_scheduled_time": datetime.utcnow()}) + PyratIngestionTask.insert1( + {"pyrat_task_scheduled_time": datetime.now(timezone.utc)} + ) time.sleep(1) self.insert1(key) From 7b614c9f2d8fb005883818b6ea3c78c624b7de7a Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 16:23:41 +0000 Subject: [PATCH 024/143] fix: resolve D102 by adding docstring in public methods --- aeon/dj_pipeline/acquisition.py | 8 ++ aeon/dj_pipeline/analysis/block_analysis.py | 11 +- aeon/dj_pipeline/analysis/visit.py | 4 + aeon/dj_pipeline/analysis/visit_analysis.py | 4 + aeon/dj_pipeline/qc.py | 2 + aeon/dj_pipeline/report.py | 8 ++ aeon/dj_pipeline/streams.py | 111 +++++++++++--------- aeon/dj_pipeline/subject.py | 5 + aeon/dj_pipeline/tracking.py | 3 + 9 files changed, 104 insertions(+), 52 deletions(-) diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py index c4d9b606..6211bfd8 100644 --- a/aeon/dj_pipeline/acquisition.py +++ b/aeon/dj_pipeline/acquisition.py @@ -138,6 +138,7 @@ class Note(dj.Part): @classmethod def get_data_directory(cls, experiment_key, directory_type="raw", as_posix=False): + """Get the data directory for the specified ``experiment_key`` and ``directory_type``.""" try: repo_name, dir_path = ( cls.Directory & experiment_key & {"directory_type": directory_type} @@ -158,6 +159,7 @@ def get_data_directory(cls, experiment_key, directory_type="raw", as_posix=False @classmethod def get_data_directories(cls, experiment_key, directory_types=None, as_posix=False): + """Get the data directories for the specified ``experiment_key`` and ``directory_types``.""" if directory_types is None: directory_types = (cls.Directory & experiment_key).fetch( "directory_type", order_by="load_order" @@ -318,6 +320,7 @@ class ActiveRegion(dj.Part): """ def make(self, key): + """Ingest metadata into EpochConfig.""" from aeon.dj_pipeline.utils import streams_maker from aeon.dj_pipeline.utils.load_metadata import ( extract_epoch_config, @@ -398,6 +401,7 @@ class File(dj.Part): @classmethod def ingest_chunks(cls, experiment_name): + """Ingest chunks for the specified ``experiment_name``.""" device_name = _ref_device_mapping.get(experiment_name, "CameraTop") all_chunks, raw_data_dirs = _get_all_chunks(experiment_name, device_name) @@ -553,6 +557,7 @@ class SubjectWeight(dj.Part): """ def make(self, key): + """Ingest environment data into Environment table.""" chunk_start, chunk_end = (Chunk & key).fetch1("chunk_start", "chunk_end") # Populate the part table @@ -616,6 +621,7 @@ class Name(dj.Part): """ def make(self, key): + """Ingest active configuration data into EnvironmentActiveConfiguration table.""" chunk_start, chunk_end = (Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = Experiment.get_data_directories(key) devices_schema = getattr( @@ -647,6 +653,7 @@ def make(self, key): def _get_all_chunks(experiment_name, device_name): + """Get all chunks for the specified ``experiment_name`` and ``device_name``.""" directory_types = ["quality-control", "raw"] raw_data_dirs = { dir_type: Experiment.get_data_directory( @@ -672,6 +679,7 @@ def _get_all_chunks(experiment_name, device_name): def _match_experiment_directory(experiment_name, path, directories): + """Match the path to the experiment directory.""" for k, v in directories.items(): raw_data_dir = v if pathlib.Path(raw_data_dir) in list(path.parents): diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 144ff37b..8a7a05fb 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -137,8 +137,10 @@ class BlockAnalysis(dj.Computed): @property def key_source(self): - # Ensure that the chunk ingestion has caught up with this block before processing - # (there exists a chunk that ends after the block end time) + """ + Ensure that the chunk ingestion has caught up with this block before processing + (there exists a chunk that ends after the block end time) + """ ks = Block.aggr(acquisition.Chunk, latest_chunk_end="MAX(chunk_end)") ks = ks * Block & "latest_chunk_end >= block_end" & "block_end IS NOT NULL" return ks @@ -430,6 +432,7 @@ class Preference(dj.Part): key_source = BlockAnalysis & BlockAnalysis.Patch & BlockAnalysis.Subject def make(self, key): + """Compute preference scores for each subject at each patch within a block.""" block_patches = (BlockAnalysis.Patch & key).fetch(as_dict=True) block_subjects = (BlockAnalysis.Subject & key).fetch(as_dict=True) subject_names = [s["subject_name"] for s in block_subjects] @@ -720,6 +723,7 @@ class BlockPatchPlots(dj.Computed): """ def make(self, key): + """Compute and plot various block-level statistics and visualizations.""" # Define subject colors and patch styling for plotting exp_subject_names = (acquisition.Experiment.Subject & key).fetch( "subject", order_by="subject" @@ -1461,6 +1465,8 @@ class BlockSubjectPositionPlots(dj.Computed): """ def make(self, key): + """Compute and plot various block-level statistics and visualizations.""" + # Get some block info block_start, block_end = (Block & key).fetch1("block_start", "block_end") chunk_restriction = acquisition.create_chunk_restriction( @@ -1737,6 +1743,7 @@ class Bout(dj.Part): """ def make(self, key): + """ Compute and store foraging bouts for each subject in the block. """ foraging_bout_df = get_foraging_bouts(key) foraging_bout_df.rename( columns={ diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index 5dcea011..b8440be9 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -69,11 +69,13 @@ class Visit(dj.Part): @property def key_source(self): + """Key source for OverlapVisit.""" return dj.U("experiment_name", "place", "overlap_start") & ( Visit & VisitEnd ).proj(overlap_start="visit_start") def make(self, key): + """Populate OverlapVisit table with overlapping visits.""" visit_starts, visit_ends = ( Visit * VisitEnd & key & {"visit_start": key["overlap_start"]} ).fetch("visit_start", "visit_end") @@ -194,6 +196,7 @@ def ingest_environment_visits(experiment_names: list | None = None): def get_maintenance_periods(experiment_name, start, end): + """Get maintenance periods for the specified experiment and time range.""" # get states from acquisition.Environment.EnvironmentState chunk_restriction = acquisition.create_chunk_restriction( experiment_name, start, end @@ -242,6 +245,7 @@ def get_maintenance_periods(experiment_name, start, end): def filter_out_maintenance_periods(data_df, maintenance_period, end_time, dropna=False): + """Filter out maintenance periods from the data_df.""" maint_period = maintenance_period.copy() while maint_period: (maintenance_start, maintenance_end) = maint_period[0] diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 1eb192c7..c36cccbe 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -103,6 +103,7 @@ def key_source(self): ) def make(self, key): + """Populate VisitSubjectPosition for each visit""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -275,6 +276,7 @@ class FoodPatch(dj.Part): ) def make(self, key): + """Populate VisitTimeDistribution for each visit""" visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) @@ -430,6 +432,7 @@ class FoodPatch(dj.Part): ) def make(self, key): + """Populate VisitSummary for each visit""" visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) @@ -574,6 +577,7 @@ class VisitForagingBout(dj.Computed): ) * acquisition.ExperimentFoodPatch def make(self, key): + """Populate VisitForagingBout for each visit.""" visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") # get in_patch timestamps diff --git a/aeon/dj_pipeline/qc.py b/aeon/dj_pipeline/qc.py index 200285da..9cc6bc63 100644 --- a/aeon/dj_pipeline/qc.py +++ b/aeon/dj_pipeline/qc.py @@ -58,6 +58,7 @@ class CameraQC(dj.Imported): @property def key_source(self): + """Return the keys for the CameraQC table""" return ( acquisition.Chunk * ( @@ -71,6 +72,7 @@ def key_source(self): ) # CameraTop def make(self, key): + """Perform quality control checks on the CameraTop stream""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) diff --git a/aeon/dj_pipeline/report.py b/aeon/dj_pipeline/report.py index ac929625..03f8f89c 100644 --- a/aeon/dj_pipeline/report.py +++ b/aeon/dj_pipeline/report.py @@ -45,6 +45,7 @@ class InArenaSummaryPlot(dj.Computed): } def make(self, key): + """Make method for InArenaSummaryPlot table""" in_arena_start, in_arena_end = ( analysis.InArena * analysis.InArenaEnd & key ).fetch1("in_arena_start", "in_arena_end") @@ -300,6 +301,7 @@ class SubjectRewardRateDifference(dj.Computed): key_source = acquisition.Experiment.Subject & analysis.InArenaRewardRate def make(self, key): + """Insert reward rate differences plot in SubjectRewardRateDifference table.""" from aeon.dj_pipeline.utils.plotting import plot_reward_rate_differences fig = plot_reward_rate_differences(key) @@ -343,6 +345,7 @@ class SubjectWheelTravelledDistance(dj.Computed): key_source = acquisition.Experiment.Subject & analysis.InArenaSummary def make(self, key): + """Insert wheel travelled distance plot in SubjectWheelTravelledDistance table.""" from aeon.dj_pipeline.utils.plotting import plot_wheel_travelled_distance in_arena_keys = (analysis.InArenaSummary & key).fetch("KEY") @@ -386,6 +389,7 @@ class ExperimentTimeDistribution(dj.Computed): """ def make(self, key): + """Insert average time distribution plot into ExperimentTimeDistribution table.""" from aeon.dj_pipeline.utils.plotting import plot_average_time_distribution in_arena_keys = (analysis.InArenaTimeDistribution & key).fetch("KEY") @@ -420,6 +424,7 @@ def delete_outdated_entries(cls): def delete_outdated_plot_entries(): + """Delete outdated entries in the tables that store plots.""" for tbl in ( SubjectRewardRateDifference, SubjectWheelTravelledDistance, @@ -452,6 +457,7 @@ class VisitDailySummaryPlot(dj.Computed): ) def make(self, key): + """Make method for VisitDailySummaryPlot table""" from aeon.dj_pipeline.utils.plotting import ( plot_foraging_bouts_count, plot_foraging_bouts_distribution, @@ -551,6 +557,7 @@ def make(self, key): def _make_path(in_arena_key): + """Make path for saving figures""" store_stage = pathlib.Path(dj.config["stores"]["djstore"]["stage"]) experiment_name, subject, in_arena_start = (analysis.InArena & in_arena_key).fetch1( "experiment_name", "subject", "in_arena_start" @@ -566,6 +573,7 @@ def _make_path(in_arena_key): def _save_figs(figs, fig_names, save_dir, prefix, extension=".png"): + """Save figures and return a dictionary with figure names and file paths""" fig_dict = {} for fig, figname in zip(figs, fig_names): fig_fp = save_dir / (prefix + "_" + figname + extension) diff --git a/aeon/dj_pipeline/streams.py b/aeon/dj_pipeline/streams.py index b692a5aa..225e7198 100644 --- a/aeon/dj_pipeline/streams.py +++ b/aeon/dj_pipeline/streams.py @@ -189,6 +189,7 @@ def key_source(self): ) def make(self, key): + """Load and insert RfidEvents data stream for a given chunk and RfidReader.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -242,11 +243,11 @@ class SpinnakerVideoSourceVideo(dj.Imported): @property def key_source(self): - f""" - Only the combination of Chunk and SpinnakerVideoSource with overlapping time - + Chunk(s) that started after SpinnakerVideoSource install time and ended before SpinnakerVideoSource remove time - + Chunk(s) that started after SpinnakerVideoSource install time for SpinnakerVideoSource that are not yet removed - """ + """ + Only the combination of Chunk and SpinnakerVideoSource with overlapping time + + Chunk(s) that started after SpinnakerVideoSource install time and ended before SpinnakerVideoSource remove time + + Chunk(s) that started after SpinnakerVideoSource install time for SpinnakerVideoSource that are not yet removed + """ return ( acquisition.Chunk * SpinnakerVideoSource.join(SpinnakerVideoSource.RemovalTime, left=True) @@ -255,6 +256,7 @@ def key_source(self): ) def make(self, key): + """Load and insert Video data stream for a given chunk and SpinnakerVideoSource.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -307,11 +309,11 @@ class UndergroundFeederBeamBreak(dj.Imported): @property def key_source(self): - f""" - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + """ + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -320,6 +322,7 @@ def key_source(self): ) def make(self, key): + """Load and insert BeamBreak data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -372,11 +375,11 @@ class UndergroundFeederDeliverPellet(dj.Imported): @property def key_source(self): - f""" - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + """ + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -385,6 +388,7 @@ def key_source(self): ) def make(self, key): + """Load and insert DeliverPellet data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -439,11 +443,11 @@ class UndergroundFeederDepletionState(dj.Imported): @property def key_source(self): - f""" - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + """ + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -452,6 +456,7 @@ def key_source(self): ) def make(self, key): + """Load and insert DepletionState data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -505,11 +510,11 @@ class UndergroundFeederEncoder(dj.Imported): @property def key_source(self): - f""" - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + """ + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -518,6 +523,7 @@ def key_source(self): ) def make(self, key): + """Load and insert Encoder data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -570,11 +576,11 @@ class UndergroundFeederManualDelivery(dj.Imported): @property def key_source(self): - f""" - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + """ + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -583,6 +589,7 @@ def key_source(self): ) def make(self, key): + """Load and insert ManualDelivery data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -635,11 +642,11 @@ class UndergroundFeederMissedPellet(dj.Imported): @property def key_source(self): - f""" - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + """ + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -648,6 +655,7 @@ def key_source(self): ) def make(self, key): + """Load and insert MissedPellet data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -700,11 +708,11 @@ class UndergroundFeederRetriedDelivery(dj.Imported): @property def key_source(self): - f""" - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + """ + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -713,6 +721,7 @@ def key_source(self): ) def make(self, key): + """Load and insert RetriedDelivery data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -766,11 +775,11 @@ class WeightScaleWeightFiltered(dj.Imported): @property def key_source(self): - f""" - Only the combination of Chunk and WeightScale with overlapping time - + Chunk(s) that started after WeightScale install time and ended before WeightScale remove time - + Chunk(s) that started after WeightScale install time for WeightScale that are not yet removed - """ + """ + Only the combination of Chunk and WeightScale with overlapping time + + Chunk(s) that started after WeightScale install time and ended before WeightScale remove time + + Chunk(s) that started after WeightScale install time for WeightScale that are not yet removed + """ return ( acquisition.Chunk * WeightScale.join(WeightScale.RemovalTime, left=True) & "chunk_start >= weight_scale_install_time" @@ -778,6 +787,7 @@ def key_source(self): ) def make(self, key): + """Load and insert WeightFiltered data stream for a given chunk and WeightScale.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -831,11 +841,11 @@ class WeightScaleWeightRaw(dj.Imported): @property def key_source(self): - f""" - Only the combination of Chunk and WeightScale with overlapping time - + Chunk(s) that started after WeightScale install time and ended before WeightScale remove time - + Chunk(s) that started after WeightScale install time for WeightScale that are not yet removed - """ + """ + Only the combination of Chunk and WeightScale with overlapping time + + Chunk(s) that started after WeightScale install time and ended before WeightScale remove time + + Chunk(s) that started after WeightScale install time for WeightScale that are not yet removed + """ return ( acquisition.Chunk * WeightScale.join(WeightScale.RemovalTime, left=True) & "chunk_start >= weight_scale_install_time" @@ -843,6 +853,7 @@ def key_source(self): ) def make(self, key): + """Load and insert WeightRaw data stream for a given chunk and WeightScale.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index fceb3042..70812dcc 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -58,6 +58,7 @@ class SubjectDetail(dj.Imported): """ def make(self, key): + """Automatically import and update entries in the Subject table.""" eartag_or_id = key["subject"] # cage id, sex, line/strain, genetic background, dob, lab id params = { @@ -183,6 +184,7 @@ class SubjectReferenceWeight(dj.Manual): @classmethod def get_reference_weight(cls, subject_name): + """Get the reference weight for the subject.""" subj_key = {"subject": subject_name} food_restrict_query = ( @@ -247,6 +249,7 @@ class PyratIngestion(dj.Imported): schedule_interval = 12 # schedule interval in number of hours def _auto_schedule(self): + """Automatically schedule the next task.""" utc_now = datetime.now(timezone.utc) next_task_schedule_time = utc_now + timedelta(hours=self.schedule_interval) @@ -330,6 +333,7 @@ class PyratCommentWeightProcedure(dj.Imported): key_source = (PyratIngestion * SubjectDetail) & "available = 1" def make(self, key): + """Automatically import or update entries in the PyratCommentWeightProcedure table.""" execution_time = datetime.now(timezone.utc) logger.info("Extracting weights/comments/procedures") @@ -465,6 +469,7 @@ def make(self, key): def get_pyrat_data(endpoint: str, params: dict = None, **kwargs): + """Get data from PyRat API.""" base_url = "https://swc.pyrat.cloud/api/v3/" pyrat_system_token = os.getenv("PYRAT_SYSTEM_TOKEN") pyrat_user_token = os.getenv("PYRAT_USER_TOKEN") diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index f53a6350..b021d61d 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -81,6 +81,7 @@ def insert_new_params( params: dict, tracking_paramset_id: int = None, ): + """Insert a new set of parameters for a given tracking method.""" if tracking_paramset_id is None: tracking_paramset_id = ( dj.U().aggr(cls, n="max(tracking_paramset_id)").fetch1("n") or 0 @@ -152,6 +153,7 @@ class Part(dj.Part): @property def key_source(self): + """Return the keys to be processed.""" return ( acquisition.Chunk * ( @@ -166,6 +168,7 @@ def key_source(self): ) # SLEAP & CameraTop def make(self, key): + """Ingest SLEAP tracking data for a given chunk.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) From e7c05f46dc0c97ce6fa0727d024769cb03169e2b Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 16:36:31 +0000 Subject: [PATCH 025/143] fix: resolve D103 error add docstrings for public functions --- aeon/dj_pipeline/create_experiments/create_experiment_01.py | 4 ++++ aeon/dj_pipeline/create_experiments/create_experiment_02.py | 2 ++ aeon/dj_pipeline/create_experiments/create_octagon_1.py | 2 ++ aeon/dj_pipeline/create_experiments/create_presocial.py | 3 +++ .../dj_pipeline/create_experiments/create_socialexperiment.py | 1 + .../create_experiments/create_socialexperiment_0.py | 3 +++ aeon/dj_pipeline/populate/worker.py | 1 + aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py | 2 ++ aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py | 2 ++ aeon/dj_pipeline/scripts/update_timestamps_longblob.py | 2 ++ aeon/dj_pipeline/tracking.py | 3 +++ aeon/dj_pipeline/utils/load_metadata.py | 1 + aeon/dj_pipeline/utils/plotting.py | 1 + aeon/dj_pipeline/utils/streams_maker.py | 2 ++ 14 files changed, 29 insertions(+) diff --git a/aeon/dj_pipeline/create_experiments/create_experiment_01.py b/aeon/dj_pipeline/create_experiments/create_experiment_01.py index aa1f8675..5f765b3b 100644 --- a/aeon/dj_pipeline/create_experiments/create_experiment_01.py +++ b/aeon/dj_pipeline/create_experiments/create_experiment_01.py @@ -11,6 +11,7 @@ def ingest_exp01_metadata(metadata_yml_filepath, experiment_name): + """Ingest metadata from a yml file into the database for experiment 0.1.""" with open(metadata_yml_filepath) as f: arena_setup = yaml.full_load(f) @@ -181,6 +182,7 @@ def ingest_exp01_metadata(metadata_yml_filepath, experiment_name): def create_new_experiment(): + """Create a new experiment and add subjects to it.""" # ---------------- Subject ----------------- subject.Subject.insert( [ @@ -259,6 +261,7 @@ def create_new_experiment(): def add_arena_setup(): + """Add arena setup.""" # Arena Setup - Experiment Devices this_file = pathlib.Path(__file__).expanduser().absolute().resolve() metadata_yml_filepath = this_file.parent / "setup_yml" / "Experiment0.1.yml" @@ -286,6 +289,7 @@ def add_arena_setup(): def main(): + """Main function to create a new experiment and set up the arena.""" create_new_experiment() add_arena_setup() diff --git a/aeon/dj_pipeline/create_experiments/create_experiment_02.py b/aeon/dj_pipeline/create_experiments/create_experiment_02.py index ef3611ca..e2965e91 100644 --- a/aeon/dj_pipeline/create_experiments/create_experiment_02.py +++ b/aeon/dj_pipeline/create_experiments/create_experiment_02.py @@ -8,6 +8,7 @@ def create_new_experiment(): + """Create new experiment for experiment0.2""" # ---------------- Subject ----------------- subject_list = [ {"subject": "BAA-1100699", "sex": "U", "subject_birth_date": "2021-01-01"}, @@ -81,6 +82,7 @@ def create_new_experiment(): def main(): + """Main function to create a new experiment.""" create_new_experiment() diff --git a/aeon/dj_pipeline/create_experiments/create_octagon_1.py b/aeon/dj_pipeline/create_experiments/create_octagon_1.py index ae3e831d..56a24613 100644 --- a/aeon/dj_pipeline/create_experiments/create_octagon_1.py +++ b/aeon/dj_pipeline/create_experiments/create_octagon_1.py @@ -8,6 +8,7 @@ def create_new_experiment(): + """Create new experiment for octagon1.0""" # ---------------- Subject ----------------- # This will get replaced by content from colony.csv subject_list = [ @@ -62,6 +63,7 @@ def create_new_experiment(): def main(): + """Main function to create a new experiment.""" create_new_experiment() diff --git a/aeon/dj_pipeline/create_experiments/create_presocial.py b/aeon/dj_pipeline/create_experiments/create_presocial.py index 855f4d00..0676dd7f 100644 --- a/aeon/dj_pipeline/create_experiments/create_presocial.py +++ b/aeon/dj_pipeline/create_experiments/create_presocial.py @@ -9,6 +9,7 @@ def create_new_experiment(): + """Create new experiments for presocial0.1""" lab.Location.insert1({"lab": "SWC", "location": location}, skip_duplicates=True) acquisition.ExperimentType.insert1( @@ -55,6 +56,8 @@ def create_new_experiment(): def main(): + """Main function to create a new experiment.""" + create_new_experiment() diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py index 501f6893..6ece34b6 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py @@ -18,6 +18,7 @@ def create_new_social_experiment(experiment_name): + """Create new social experiment""" exp_name, machine_name = experiment_name.split("-") raw_dir = ceph_data_dir / "raw" / machine_name.upper() / exp_name if not raw_dir.exists(): diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py index 79c4cba0..0b4b021f 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py @@ -12,6 +12,7 @@ def create_new_experiment(): + """Create new experiments for social0-r1""" # ---------------- Subject ----------------- subject_list = [ {"subject": "BAA-1100704", "sex": "U", "subject_birth_date": "2021-01-01"}, @@ -86,6 +87,7 @@ def create_new_experiment(): def add_arena_setup(): + """Add arena setup.""" # Arena Setup - Experiment Devices this_file = pathlib.Path(__file__).expanduser().absolute().resolve() metadata_yml_filepath = this_file.parent / "setup_yml" / "SocialExperiment0.yml" @@ -113,6 +115,7 @@ def add_arena_setup(): def main(): + """Main function to create a new experiment and set up the arena.""" create_new_experiment() add_arena_setup() diff --git a/aeon/dj_pipeline/populate/worker.py b/aeon/dj_pipeline/populate/worker.py index 309eea36..35e93da6 100644 --- a/aeon/dj_pipeline/populate/worker.py +++ b/aeon/dj_pipeline/populate/worker.py @@ -114,6 +114,7 @@ def ingest_epochs_chunks(): def get_workflow_operation_overview(): + """Get the workflow operation overview for the worker schema.""" from datajoint_utilities.dj_worker.utils import get_workflow_operation_overview return get_workflow_operation_overview( diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py index 32a4ba2a..a85042b4 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py @@ -34,6 +34,7 @@ def clone_pipeline(): + """Clone the pipeline for experiment 0.1""" diagram = None for orig_schema_name in schema_name_mapper: virtual_module = dj.create_virtual_module(orig_schema_name, orig_schema_name) @@ -47,6 +48,7 @@ def clone_pipeline(): def data_copy(restriction, table_block_list, batch_size=None): + """Migrate schema.""" for orig_schema_name, cloned_schema_name in schema_name_mapper.items(): orig_schema = dj.create_virtual_module(orig_schema_name, orig_schema_name) cloned_schema = dj.create_virtual_module(cloned_schema_name, cloned_schema_name) diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py index 3e6d842a..4521b75d 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py @@ -38,6 +38,7 @@ def clone_pipeline(): + """Clone the pipeline for experiment 0.2""" diagram = None for orig_schema_name in schema_name_mapper: virtual_module = dj.create_virtual_module(orig_schema_name, orig_schema_name) @@ -51,6 +52,7 @@ def clone_pipeline(): def data_copy(restriction, table_block_list, batch_size=None): + """Migrate schema.""" for orig_schema_name, cloned_schema_name in schema_name_mapper.items(): orig_schema = dj.create_virtual_module(orig_schema_name, orig_schema_name) cloned_schema = dj.create_virtual_module(cloned_schema_name, cloned_schema_name) diff --git a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py index ba8b7fa5..18a89bc7 100644 --- a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py +++ b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py @@ -38,6 +38,7 @@ class TimestampFix(dj.Manual): def main(): + """Update all timestamps longblob fields in the specified schemas.""" for schema_name in schema_names: vm = dj.create_virtual_module(schema_name, schema_name) table_names = [ @@ -81,6 +82,7 @@ def main(): def get_table(schema_object, table_object_name): + """Get the table object from the schema object.""" if "." in table_object_name: master_name, part_name = table_object_name.split(".") return getattr(getattr(schema_object, master_name), part_name) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index b021d61d..e8333ad2 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -254,6 +254,7 @@ def make(self, key): def compute_distance(position_df, target, xcol="x", ycol="y"): + """Compute the distance of the position data from a target point.""" if len(target) != 2: raise ValueError("Target must be a list of tuple of length 2.") return np.sqrt(np.square(position_df[[xcol, ycol]] - target).sum(axis=1)) @@ -262,6 +263,7 @@ def compute_distance(position_df, target, xcol="x", ycol="y"): def is_position_in_patch( position_df, patch_position, wheel_distance_travelled, patch_radius=0.2 ) -> pd.Series: + """The function returns a boolean array indicating whether the position is inside the patch.""" distance_from_patch = compute_distance(position_df, patch_position) in_patch = distance_from_patch < patch_radius exit_patch = in_patch.astype(np.int8).diff() < 0 @@ -296,6 +298,7 @@ def _get_position( attrs_to_scale: list, scale_factor=1.0, ): + """Get the position data for a given object between the specified time range.""" obj_restriction = {object_attr: object_name} start_restriction = f'"{start}" BETWEEN {start_attr} AND {end_attr}' diff --git a/aeon/dj_pipeline/utils/load_metadata.py b/aeon/dj_pipeline/utils/load_metadata.py index bd8a2f14..7c21ad8b 100644 --- a/aeon/dj_pipeline/utils/load_metadata.py +++ b/aeon/dj_pipeline/utils/load_metadata.py @@ -390,6 +390,7 @@ def get_device_info(devices_schema: DotMap) -> dict[dict]: """ def _get_class_path(obj): + """Returns the class path of the object.""" return f"{obj.__class__.__module__}.{obj.__class__.__name__}" schema_json = json.dumps(devices_schema, default=lambda x: x.__dict__, indent=4) diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index 564dcca2..28be5b10 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -133,6 +133,7 @@ def plot_wheel_travelled_distance(session_keys): def plot_average_time_distribution(session_keys): + """Plotting the average time spent in different regions.""" subject_list, arena_location_list, avg_time_spent_list = [], [], [] # Time spent in arena and corridor diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index 734ad88d..c653d918 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -154,6 +154,7 @@ def key_source(self): ) def make(self, key): + """Load and insert the data for the DeviceDataStream table.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -204,6 +205,7 @@ def make(self, key): def main(create_tables=True): + """Main function to create and update stream-related tables in the analysis schema.""" if not _STREAMS_MODULE_FILE.exists(): with open(_STREAMS_MODULE_FILE, "w") as f: imports_str = ( From 10bb40fd20915812c6b0ddad777c84e4ec143da3 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 17:17:05 +0000 Subject: [PATCH 026/143] fix: resolve E501 error + black formatting with max length 105 --- aeon/analysis/block_plotting.py | 16 +- aeon/analysis/movies.py | 4 +- aeon/analysis/plotting.py | 8 +- aeon/analysis/utils.py | 26 +- aeon/dj_pipeline/__init__.py | 8 +- aeon/dj_pipeline/acquisition.py | 84 +-- aeon/dj_pipeline/analysis/block_analysis.py | 526 +++++------------- aeon/dj_pipeline/analysis/visit.py | 45 +- aeon/dj_pipeline/analysis/visit_analysis.py | 134 ++--- .../create_experiment_01.py | 40 +- .../create_experiment_02.py | 5 +- .../create_experiments/create_octagon_1.py | 5 +- .../create_experiments/create_presocial.py | 4 +- .../create_socialexperiment.py | 8 +- .../create_socialexperiment_0.py | 21 +- aeon/dj_pipeline/populate/process.py | 4 +- aeon/dj_pipeline/populate/worker.py | 4 +- aeon/dj_pipeline/qc.py | 25 +- aeon/dj_pipeline/report.py | 63 +-- .../scripts/clone_and_freeze_exp02.py | 3 +- .../scripts/update_timestamps_longblob.py | 12 +- aeon/dj_pipeline/streams.py | 145 ++--- aeon/dj_pipeline/subject.py | 50 +- aeon/dj_pipeline/tracking.py | 51 +- aeon/dj_pipeline/utils/load_metadata.py | 114 +--- aeon/dj_pipeline/utils/paths.py | 3 +- aeon/dj_pipeline/utils/plotting.py | 133 ++--- aeon/dj_pipeline/utils/streams_maker.py | 51 +- aeon/dj_pipeline/utils/video.py | 4 +- aeon/io/api.py | 24 +- aeon/io/reader.py | 82 +-- aeon/io/video.py | 8 +- aeon/schema/foraging.py | 12 +- aeon/schema/octagon.py | 24 +- aeon/schema/social_02.py | 16 +- aeon/schema/social_03.py | 4 +- tests/dj_pipeline/conftest.py | 8 +- tests/dj_pipeline/test_acquisition.py | 29 +- .../test_pipeline_instantiation.py | 27 +- tests/dj_pipeline/test_tracking.py | 12 +- tests/io/test_api.py | 20 +- 41 files changed, 542 insertions(+), 1320 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index ac002814..61875d2e 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -29,14 +29,10 @@ def gen_hex_grad(hex_col, vals, min_l=0.3): """Generates an array of hex color values based on a gradient defined by unit-normalized values.""" # Convert hex to rgb to hls - h, l, s = rgb_to_hls( - *[int(hex_col.lstrip("#")[i : i + 2], 16) / 255 for i in (0, 2, 4)] - ) # noqa: E741 + h, l, s = rgb_to_hls(*[int(hex_col.lstrip("#")[i : i + 2], 16) / 255 for i in (0, 2, 4)]) # noqa: E741 grad = np.empty(shape=(len(vals),), dtype=" 1).reindex( - in_patch.index, method="pad" - ) + in_wheel = (wheel.diff().rolling("1s").sum() > 1).reindex(in_patch.index, method="pad") epochs = exit_patch.cumsum() return in_wheel.groupby(epochs).apply(lambda x: x.cumsum()) > 0 diff --git a/aeon/dj_pipeline/__init__.py b/aeon/dj_pipeline/__init__.py index 44e0d498..27bff7c8 100644 --- a/aeon/dj_pipeline/__init__.py +++ b/aeon/dj_pipeline/__init__.py @@ -15,9 +15,7 @@ db_prefix = dj.config["custom"].get("database.prefix", _default_database_prefix) -repository_config = dj.config["custom"].get( - "repository_config", _default_repository_config -) +repository_config = dj.config["custom"].get("repository_config", _default_repository_config) def get_schema_name(name) -> str: @@ -42,9 +40,7 @@ def fetch_stream(query, drop_pk=True): """ df = (query & "sample_count > 0").fetch(format="frame").reset_index() cols2explode = [ - c - for c in query.heading.secondary_attributes - if query.heading.attributes[c].type == "longblob" + c for c in query.heading.secondary_attributes if query.heading.attributes[c].type == "longblob" ] df = df.explode(column=cols2explode) cols2drop = ["sample_count"] + (query.primary_key if drop_pk else []) diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py index 6211bfd8..cd427928 100644 --- a/aeon/dj_pipeline/acquisition.py +++ b/aeon/dj_pipeline/acquisition.py @@ -167,10 +167,7 @@ def get_data_directories(cls, experiment_key, directory_types=None, as_posix=Fal return [ d for dir_type in directory_types - if ( - d := cls.get_data_directory(experiment_key, dir_type, as_posix=as_posix) - ) - is not None + if (d := cls.get_data_directory(experiment_key, dir_type, as_posix=as_posix)) is not None ] @@ -198,9 +195,7 @@ def ingest_epochs(cls, experiment_name): for i, (_, chunk) in enumerate(all_chunks.iterrows()): chunk_rep_file = pathlib.Path(chunk.path) epoch_dir = pathlib.Path(chunk_rep_file.as_posix().split(device_name)[0]) - epoch_start = datetime.datetime.strptime( - epoch_dir.name, "%Y-%m-%dT%H-%M-%S" - ) + epoch_start = datetime.datetime.strptime(epoch_dir.name, "%Y-%m-%dT%H-%M-%S") # --- insert to Epoch --- epoch_key = {"experiment_name": experiment_name, "epoch_start": epoch_start} @@ -219,15 +214,11 @@ def ingest_epochs(cls, experiment_name): if i > 0: previous_chunk = all_chunks.iloc[i - 1] previous_chunk_path = pathlib.Path(previous_chunk.path) - previous_epoch_dir = pathlib.Path( - previous_chunk_path.as_posix().split(device_name)[0] - ) + previous_epoch_dir = pathlib.Path(previous_chunk_path.as_posix().split(device_name)[0]) previous_epoch_start = datetime.datetime.strptime( previous_epoch_dir.name, "%Y-%m-%dT%H-%M-%S" ) - previous_chunk_end = previous_chunk.name + datetime.timedelta( - hours=io_api.CHUNK_DURATION - ) + previous_chunk_end = previous_chunk.name + datetime.timedelta(hours=io_api.CHUNK_DURATION) previous_epoch_end = min(previous_chunk_end, epoch_start) previous_epoch_key = { "experiment_name": experiment_name, @@ -256,9 +247,7 @@ def ingest_epochs(cls, experiment_name): { **previous_epoch_key, "epoch_end": previous_epoch_end, - "epoch_duration": ( - previous_epoch_end - previous_epoch_start - ).total_seconds() + "epoch_duration": (previous_epoch_end - previous_epoch_start).total_seconds() / 3600, } ) @@ -331,23 +320,17 @@ def make(self, key): experiment_name = key["experiment_name"] devices_schema = getattr( aeon_schemas, - (Experiment.DevicesSchema & {"experiment_name": experiment_name}).fetch1( - "devices_schema_name" - ), + (Experiment.DevicesSchema & {"experiment_name": experiment_name}).fetch1("devices_schema_name"), ) dir_type, epoch_dir = (Epoch & key).fetch1("directory_type", "epoch_dir") data_dir = Experiment.get_data_directory(key, dir_type) metadata_yml_filepath = data_dir / epoch_dir / "Metadata.yml" - epoch_config = extract_epoch_config( - experiment_name, devices_schema, metadata_yml_filepath - ) + epoch_config = extract_epoch_config(experiment_name, devices_schema, metadata_yml_filepath) epoch_config = { **epoch_config, - "metadata_file_path": metadata_yml_filepath.relative_to( - data_dir - ).as_posix(), + "metadata_file_path": metadata_yml_filepath.relative_to(data_dir).as_posix(), } # Insert new entries for streams.DeviceType, streams.Device. @@ -358,20 +341,15 @@ def make(self, key): # Define and instantiate new devices/stream tables under `streams` schema streams_maker.main() # Insert devices' installation/removal/settings - epoch_device_types = ingest_epoch_metadata( - experiment_name, devices_schema, metadata_yml_filepath - ) + epoch_device_types = ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath) self.insert1(key) self.Meta.insert1(epoch_config) - self.DeviceType.insert( - key | {"device_type": n} for n in epoch_device_types or {} - ) + self.DeviceType.insert(key | {"device_type": n} for n in epoch_device_types or {}) with metadata_yml_filepath.open("r") as f: metadata = json.load(f) self.ActiveRegion.insert( - {**key, "region_name": k, "region_data": v} - for k, v in metadata["ActiveRegion"].items() + {**key, "region_name": k, "region_data": v} for k, v in metadata["ActiveRegion"].items() ) @@ -410,9 +388,7 @@ def ingest_chunks(cls, experiment_name): for _, chunk in all_chunks.iterrows(): chunk_rep_file = pathlib.Path(chunk.path) epoch_dir = pathlib.Path(chunk_rep_file.as_posix().split(device_name)[0]) - epoch_start = datetime.datetime.strptime( - epoch_dir.name, "%Y-%m-%dT%H-%M-%S" - ) + epoch_start = datetime.datetime.strptime(epoch_dir.name, "%Y-%m-%dT%H-%M-%S") epoch_key = {"experiment_name": experiment_name, "epoch_start": epoch_start} if not (Epoch & epoch_key): @@ -420,9 +396,7 @@ def ingest_chunks(cls, experiment_name): continue chunk_start = chunk.name - chunk_start = max( - chunk_start, epoch_start - ) # first chunk of the epoch starts at epoch_start + chunk_start = max(chunk_start, epoch_start) # first chunk of the epoch starts at epoch_start chunk_end = chunk_start + datetime.timedelta(hours=io_api.CHUNK_DURATION) if EpochEnd & epoch_key: @@ -442,12 +416,8 @@ def ingest_chunks(cls, experiment_name): ) chunk_starts.append(chunk_key["chunk_start"]) - chunk_list.append( - {**chunk_key, **directory, "chunk_end": chunk_end, **epoch_key} - ) - file_name_list.append( - chunk_rep_file.name - ) # handle duplicated files in different folders + chunk_list.append({**chunk_key, **directory, "chunk_end": chunk_end, **epoch_key}) + file_name_list.append(chunk_rep_file.name) # handle duplicated files in different folders # -- files -- file_datetime_str = chunk_rep_file.stem.replace(f"{device_name}_", "") @@ -564,9 +534,9 @@ def make(self, key): data_dirs = Experiment.get_data_directories(key) devices_schema = getattr( aeon_schemas, - ( - Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) device = devices_schema.Environment @@ -626,14 +596,12 @@ def make(self, key): data_dirs = Experiment.get_data_directories(key) devices_schema = getattr( aeon_schemas, - ( - Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) device = devices_schema.Environment - stream_reader = ( - device.EnvironmentActiveConfiguration - ) # expecting columns: time, name, value + stream_reader = device.EnvironmentActiveConfiguration # expecting columns: time, name, value stream_data = io_api.load( root=data_dirs, reader=stream_reader, @@ -666,9 +634,7 @@ def _get_all_chunks(experiment_name, device_name): raw_data_dirs = {k: v for k, v in raw_data_dirs.items() if v} if not raw_data_dirs: - raise ValueError( - f"No raw data directory found for experiment: {experiment_name}" - ) + raise ValueError(f"No raw data directory found for experiment: {experiment_name}") chunkdata = io_api.load( root=list(raw_data_dirs.values()), @@ -690,9 +656,7 @@ def _match_experiment_directory(experiment_name, path, directories): repo_path = paths.get_repository_path(directory.pop("repository_name")) break else: - raise FileNotFoundError( - f"Unable to identify the directory" f" where this chunk is from: {path}" - ) + raise FileNotFoundError(f"Unable to identify the directory" f" where this chunk is from: {path}") return raw_data_dir, directory, repo_path diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 8a7a05fb..ed752644 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -67,18 +67,14 @@ def make(self, key): # find the 0s in `pellet_ct` (these are times when the pellet count reset - i.e. new block) # that would mark the start of a new block - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") exp_key = {"experiment_name": key["experiment_name"]} chunk_restriction = acquisition.create_chunk_restriction( key["experiment_name"], chunk_start, chunk_end ) - block_state_query = ( - acquisition.Environment.BlockState & exp_key & chunk_restriction - ) + block_state_query = acquisition.Environment.BlockState & exp_key & chunk_restriction block_state_df = fetch_stream(block_state_query) if block_state_df.empty: self.insert1(key) @@ -101,12 +97,8 @@ def make(self, key): block_entries = [] if not blocks_df.empty: # calculate block end_times (use due_time) and durations - blocks_df["end_time"] = blocks_df["due_time"].apply( - lambda x: io_api.aeon(x) - ) - blocks_df["duration"] = ( - blocks_df["end_time"] - blocks_df.index - ).dt.total_seconds() / 3600 + blocks_df["end_time"] = blocks_df["due_time"].apply(lambda x: io_api.aeon(x)) + blocks_df["duration"] = (blocks_df["end_time"] - blocks_df.index).dt.total_seconds() / 3600 for _, row in blocks_df.iterrows(): block_entries.append( @@ -196,9 +188,7 @@ def make(self, key): tracking.SLEAPTracking, ) for streams_table in streams_tables: - if len(streams_table & chunk_keys) < len( - streams_table.key_source & chunk_keys - ): + if len(streams_table & chunk_keys) < len(streams_table.key_source & chunk_keys): raise ValueError( f"BlockAnalysis Not Ready - {streams_table.__name__} not yet fully ingested for block: {key}. Skipping (to retry later)..." ) @@ -207,14 +197,10 @@ def make(self, key): # For wheel data, downsample to 10Hz final_encoder_fs = 10 - maintenance_period = get_maintenance_periods( - key["experiment_name"], block_start, block_end - ) + maintenance_period = get_maintenance_periods(key["experiment_name"], block_start, block_end) patch_query = ( - streams.UndergroundFeeder.join( - streams.UndergroundFeeder.RemovalTime, left=True - ) + streams.UndergroundFeeder.join(streams.UndergroundFeeder.RemovalTime, left=True) & key & f'"{block_start}" >= underground_feeder_install_time' & f'"{block_end}" < IFNULL(underground_feeder_removal_time, "2200-01-01")' @@ -228,14 +214,12 @@ def make(self, key): streams.UndergroundFeederDepletionState & patch_key & chunk_restriction )[block_start:block_end] - pellet_ts_threshold_df = get_threshold_associated_pellets( - patch_key, block_start, block_end - ) + pellet_ts_threshold_df = get_threshold_associated_pellets(patch_key, block_start, block_end) # wheel encoder data - encoder_df = fetch_stream( - streams.UndergroundFeederEncoder & patch_key & chunk_restriction - )[block_start:block_end] + encoder_df = fetch_stream(streams.UndergroundFeederEncoder & patch_key & chunk_restriction)[ + block_start:block_end + ] # filter out maintenance period based on logs pellet_ts_threshold_df = filter_out_maintenance_periods( pellet_ts_threshold_df, @@ -254,13 +238,9 @@ def make(self, key): ) if depletion_state_df.empty: - raise ValueError( - f"No depletion state data found for block {key} - patch: {patch_name}" - ) + raise ValueError(f"No depletion state data found for block {key} - patch: {patch_name}") - encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled( - encoder_df.angle - ) + encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled(encoder_df.angle) if len(depletion_state_df.rate.unique()) > 1: # multiple patch rates per block is unexpected, log a note and pick the first rate to move forward @@ -291,9 +271,7 @@ def make(self, key): "wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values[ ::wheel_downsampling_factor ], - "wheel_timestamps": encoder_df.index.values[ - ::wheel_downsampling_factor - ], + "wheel_timestamps": encoder_df.index.values[::wheel_downsampling_factor], "patch_threshold": pellet_ts_threshold_df.threshold.values, "patch_threshold_timestamps": pellet_ts_threshold_df.index.values, "patch_rate": patch_rate, @@ -325,9 +303,7 @@ def make(self, key): # positions - query for CameraTop, identity_name matches subject_name, pos_query = ( streams.SpinnakerVideoSource - * tracking.SLEAPTracking.PoseIdentity.proj( - "identity_name", part_name="anchor_part" - ) + * tracking.SLEAPTracking.PoseIdentity.proj("identity_name", part_name="anchor_part") * tracking.SLEAPTracking.Part & key & { @@ -337,23 +313,18 @@ def make(self, key): & chunk_restriction ) pos_df = fetch_stream(pos_query)[block_start:block_end] - pos_df = filter_out_maintenance_periods( - pos_df, maintenance_period, block_end - ) + pos_df = filter_out_maintenance_periods(pos_df, maintenance_period, block_end) if pos_df.empty: continue position_diff = np.sqrt( - np.square(np.diff(pos_df.x.astype(float))) - + np.square(np.diff(pos_df.y.astype(float))) + np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float))) ) cumsum_distance_travelled = np.concatenate([[0], np.cumsum(position_diff)]) # weights - weight_query = ( - acquisition.Environment.SubjectWeight & key & chunk_restriction - ) + weight_query = acquisition.Environment.SubjectWeight & key & chunk_restriction weight_df = fetch_stream(weight_query)[block_start:block_end] weight_df.query(f"subject_id == '{subject_name}'", inplace=True) @@ -441,10 +412,7 @@ def make(self, key): subjects_positions_df = pd.concat( [ pd.DataFrame( - { - "subject_name": [s["subject_name"]] - * len(s["position_timestamps"]) - } + {"subject_name": [s["subject_name"]] * len(s["position_timestamps"])} | { k: s[k] for k in ( @@ -472,8 +440,7 @@ def make(self, key): "cum_pref_time", ] all_subj_patch_pref_dict = { - p: {s: {a: pd.Series() for a in pref_attrs} for s in subject_names} - for p in patch_names + p: {s: {a: pd.Series() for a in pref_attrs} for s in subject_names} for p in patch_names } for patch in block_patches: @@ -496,15 +463,11 @@ def make(self, key): ).fetch1("attribute_value") patch_center = (int(patch_center["X"]), int(patch_center["Y"])) subjects_xy = subjects_positions_df[["position_x", "position_y"]].values - dist_to_patch = np.sqrt( - np.sum((subjects_xy - patch_center) ** 2, axis=1).astype(float) - ) + dist_to_patch = np.sqrt(np.sum((subjects_xy - patch_center) ** 2, axis=1).astype(float)) dist_to_patch_df = subjects_positions_df[["subject_name"]].copy() dist_to_patch_df["dist_to_patch"] = dist_to_patch - dist_to_patch_wheel_ts_id_df = pd.DataFrame( - index=cum_wheel_dist.index, columns=subject_names - ) + dist_to_patch_wheel_ts_id_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subject_names) dist_to_patch_pel_ts_id_df = pd.DataFrame( index=patch["pellet_timestamps"], columns=subject_names ) @@ -512,12 +475,10 @@ def make(self, key): # Find closest match between pose_df indices and wheel indices if not dist_to_patch_wheel_ts_id_df.empty: dist_to_patch_wheel_ts_subj = pd.merge_asof( - left=pd.DataFrame( - dist_to_patch_wheel_ts_id_df[subject_name].copy() - ).reset_index(names="time"), - right=dist_to_patch_df[ - dist_to_patch_df["subject_name"] == subject_name - ] + left=pd.DataFrame(dist_to_patch_wheel_ts_id_df[subject_name].copy()).reset_index( + names="time" + ), + right=dist_to_patch_df[dist_to_patch_df["subject_name"] == subject_name] .copy() .reset_index(names="time"), on="time", @@ -526,18 +487,16 @@ def make(self, key): direction="nearest", tolerance=pd.Timedelta("100ms"), ) - dist_to_patch_wheel_ts_id_df[subject_name] = ( - dist_to_patch_wheel_ts_subj["dist_to_patch"].values - ) + dist_to_patch_wheel_ts_id_df[subject_name] = dist_to_patch_wheel_ts_subj[ + "dist_to_patch" + ].values # Find closest match between pose_df indices and pel indices if not dist_to_patch_pel_ts_id_df.empty: dist_to_patch_pel_ts_subj = pd.merge_asof( - left=pd.DataFrame( - dist_to_patch_pel_ts_id_df[subject_name].copy() - ).reset_index(names="time"), - right=dist_to_patch_df[ - dist_to_patch_df["subject_name"] == subject_name - ] + left=pd.DataFrame(dist_to_patch_pel_ts_id_df[subject_name].copy()).reset_index( + names="time" + ), + right=dist_to_patch_df[dist_to_patch_df["subject_name"] == subject_name] .copy() .reset_index(names="time"), on="time", @@ -546,9 +505,9 @@ def make(self, key): direction="nearest", tolerance=pd.Timedelta("200ms"), ) - dist_to_patch_pel_ts_id_df[subject_name] = ( - dist_to_patch_pel_ts_subj["dist_to_patch"].values - ) + dist_to_patch_pel_ts_id_df[subject_name] = dist_to_patch_pel_ts_subj[ + "dist_to_patch" + ].values # Get closest subject to patch at each pellet timestep closest_subjects_pellet_ts = dist_to_patch_pel_ts_id_df.idxmin(axis=1) @@ -560,12 +519,8 @@ def make(self, key): wheel_dist = cum_wheel_dist.diff().fillna(cum_wheel_dist.iloc[0]) # Assign wheel dist to closest subject for each wheel timestep for subject_name in subject_names: - subj_idxs = cum_wheel_dist_subj_df[ - closest_subjects_wheel_ts == subject_name - ].index - cum_wheel_dist_subj_df.loc[subj_idxs, subject_name] = wheel_dist[ - subj_idxs - ] + subj_idxs = cum_wheel_dist_subj_df[closest_subjects_wheel_ts == subject_name].index + cum_wheel_dist_subj_df.loc[subj_idxs, subject_name] = wheel_dist[subj_idxs] cum_wheel_dist_subj_df = cum_wheel_dist_subj_df.cumsum(axis=0) # In patch time @@ -573,9 +528,9 @@ def make(self, key): dt = np.median(np.diff(cum_wheel_dist.index)).astype(int) / 1e9 # s # Fill in `all_subj_patch_pref` for subject_name in subject_names: - all_subj_patch_pref_dict[patch["patch_name"]][subject_name][ - "cum_dist" - ] = cum_wheel_dist_subj_df[subject_name].values + all_subj_patch_pref_dict[patch["patch_name"]][subject_name]["cum_dist"] = ( + cum_wheel_dist_subj_df[subject_name].values + ) subject_in_patch = in_patch[subject_name] subject_in_patch_cum_time = subject_in_patch.cumsum().values * dt all_subj_patch_pref_dict[patch["patch_name"]][subject_name][ @@ -596,9 +551,7 @@ def make(self, key): "pellet_count": len(subj_pellets), "pellet_timestamps": subj_pellets.index.values, "patch_threshold": subj_patch_thresh, - "wheel_cumsum_distance_travelled": cum_wheel_dist_subj_df[ - subject_name - ].values, + "wheel_cumsum_distance_travelled": cum_wheel_dist_subj_df[subject_name].values, } ) @@ -607,72 +560,46 @@ def make(self, key): for subject_name in subject_names: # Get sum of subj cum wheel dists and cum in patch time all_cum_dist = np.sum( - [ - all_subj_patch_pref_dict[p][subject_name]["cum_dist"][-1] - for p in patch_names - ] + [all_subj_patch_pref_dict[p][subject_name]["cum_dist"][-1] for p in patch_names] ) all_cum_time = np.sum( - [ - all_subj_patch_pref_dict[p][subject_name]["cum_time"][-1] - for p in patch_names - ] + [all_subj_patch_pref_dict[p][subject_name]["cum_time"][-1] for p in patch_names] ) for patch_name in patch_names: cum_pref_dist = ( - all_subj_patch_pref_dict[patch_name][subject_name]["cum_dist"] - / all_cum_dist + all_subj_patch_pref_dict[patch_name][subject_name]["cum_dist"] / all_cum_dist ) cum_pref_dist = np.where(cum_pref_dist < 1e-3, 0, cum_pref_dist) - all_subj_patch_pref_dict[patch_name][subject_name][ - "cum_pref_dist" - ] = cum_pref_dist + all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_dist"] = cum_pref_dist cum_pref_time = ( - all_subj_patch_pref_dict[patch_name][subject_name]["cum_time"] - / all_cum_time + all_subj_patch_pref_dict[patch_name][subject_name]["cum_time"] / all_cum_time ) - all_subj_patch_pref_dict[patch_name][subject_name][ - "cum_pref_time" - ] = cum_pref_time + all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_time"] = cum_pref_time # sum pref at each ts across patches for each subject total_dist_pref = np.sum( np.vstack( - [ - all_subj_patch_pref_dict[p][subject_name]["cum_pref_dist"] - for p in patch_names - ] + [all_subj_patch_pref_dict[p][subject_name]["cum_pref_dist"] for p in patch_names] ), axis=0, ) total_time_pref = np.sum( np.vstack( - [ - all_subj_patch_pref_dict[p][subject_name]["cum_pref_time"] - for p in patch_names - ] + [all_subj_patch_pref_dict[p][subject_name]["cum_pref_time"] for p in patch_names] ), axis=0, ) for patch_name in patch_names: - cum_pref_dist = all_subj_patch_pref_dict[patch_name][subject_name][ - "cum_pref_dist" - ] - all_subj_patch_pref_dict[patch_name][subject_name][ - "running_dist_pref" - ] = np.divide( + cum_pref_dist = all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_dist"] + all_subj_patch_pref_dict[patch_name][subject_name]["running_dist_pref"] = np.divide( cum_pref_dist, total_dist_pref, out=np.zeros_like(cum_pref_dist), where=total_dist_pref != 0, ) - cum_pref_time = all_subj_patch_pref_dict[patch_name][subject_name][ - "cum_pref_time" - ] - all_subj_patch_pref_dict[patch_name][subject_name][ - "running_time_pref" - ] = np.divide( + cum_pref_time = all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_time"] + all_subj_patch_pref_dict[patch_name][subject_name]["running_time_pref"] = np.divide( cum_pref_time, total_time_pref, out=np.zeros_like(cum_pref_time), @@ -684,24 +611,12 @@ def make(self, key): | { "patch_name": p, "subject_name": s, - "cumulative_preference_by_time": all_subj_patch_pref_dict[p][s][ - "cum_pref_time" - ], - "cumulative_preference_by_wheel": all_subj_patch_pref_dict[p][s][ - "cum_pref_dist" - ], - "running_preference_by_time": all_subj_patch_pref_dict[p][s][ - "running_time_pref" - ], - "running_preference_by_wheel": all_subj_patch_pref_dict[p][s][ - "running_dist_pref" - ], - "final_preference_by_time": all_subj_patch_pref_dict[p][s][ - "cum_pref_time" - ][-1], - "final_preference_by_wheel": all_subj_patch_pref_dict[p][s][ - "cum_pref_dist" - ][-1], + "cumulative_preference_by_time": all_subj_patch_pref_dict[p][s]["cum_pref_time"], + "cumulative_preference_by_wheel": all_subj_patch_pref_dict[p][s]["cum_pref_dist"], + "running_preference_by_time": all_subj_patch_pref_dict[p][s]["running_time_pref"], + "running_preference_by_wheel": all_subj_patch_pref_dict[p][s]["running_dist_pref"], + "final_preference_by_time": all_subj_patch_pref_dict[p][s]["cum_pref_time"][-1], + "final_preference_by_wheel": all_subj_patch_pref_dict[p][s]["cum_pref_dist"][-1], } for p, s in itertools.product(patch_names, subject_names) ) @@ -725,9 +640,7 @@ class BlockPatchPlots(dj.Computed): def make(self, key): """Compute and plot various block-level statistics and visualizations.""" # Define subject colors and patch styling for plotting - exp_subject_names = (acquisition.Experiment.Subject & key).fetch( - "subject", order_by="subject" - ) + exp_subject_names = (acquisition.Experiment.Subject & key).fetch("subject", order_by="subject") if not len(exp_subject_names): raise ValueError( "No subjects found in the `acquisition.Experiment.Subject`, missing a manual insert step?." @@ -746,10 +659,7 @@ def make(self, key): # Figure 1 - Patch stats: patch means and pellet threshold boxplots # --- subj_patch_info = ( - ( - BlockSubjectAnalysis.Patch.proj("pellet_timestamps", "patch_threshold") - & key - ) + (BlockSubjectAnalysis.Patch.proj("pellet_timestamps", "patch_threshold") & key) .fetch(format="frame") .reset_index() ) @@ -763,46 +673,28 @@ def make(self, key): ["patch_name", "subject_name", "pellet_timestamps", "patch_threshold"] ] min_subj_patch_info = ( - min_subj_patch_info.explode( - ["pellet_timestamps", "patch_threshold"], ignore_index=True - ) + min_subj_patch_info.explode(["pellet_timestamps", "patch_threshold"], ignore_index=True) .dropna() .reset_index(drop=True) ) # Rename and reindex columns min_subj_patch_info.columns = ["patch", "subject", "time", "threshold"] - min_subj_patch_info = min_subj_patch_info.reindex( - columns=["time", "patch", "threshold", "subject"] - ) + min_subj_patch_info = min_subj_patch_info.reindex(columns=["time", "patch", "threshold", "subject"]) # Add patch mean values and block-normalized delivery times to pellet info n_patches = len(patch_info) - patch_mean_info = pd.DataFrame( - index=np.arange(n_patches), columns=min_subj_patch_info.columns - ) + patch_mean_info = pd.DataFrame(index=np.arange(n_patches), columns=min_subj_patch_info.columns) patch_mean_info["subject"] = "mean" patch_mean_info["patch"] = [d["patch_name"] for d in patch_info] - patch_mean_info["threshold"] = [ - ((1 / d["patch_rate"]) + d["patch_offset"]) for d in patch_info - ] + patch_mean_info["threshold"] = [((1 / d["patch_rate"]) + d["patch_offset"]) for d in patch_info] patch_mean_info["time"] = subj_patch_info["block_start"][0] - min_subj_patch_info_plus = pd.concat( - (patch_mean_info, min_subj_patch_info) - ).reset_index(drop=True) + min_subj_patch_info_plus = pd.concat((patch_mean_info, min_subj_patch_info)).reset_index(drop=True) min_subj_patch_info_plus["norm_time"] = ( - ( - min_subj_patch_info_plus["time"] - - min_subj_patch_info_plus["time"].iloc[0] - ) - / ( - min_subj_patch_info_plus["time"].iloc[-1] - - min_subj_patch_info_plus["time"].iloc[0] - ) + (min_subj_patch_info_plus["time"] - min_subj_patch_info_plus["time"].iloc[0]) + / (min_subj_patch_info_plus["time"].iloc[-1] - min_subj_patch_info_plus["time"].iloc[0]) ).round(3) # Plot it - box_colors = ["#0A0A0A"] + list( - subject_colors_dict.values() - ) # subject colors + mean color + box_colors = ["#0A0A0A"] + list(subject_colors_dict.values()) # subject colors + mean color patch_stats_fig = px.box( min_subj_patch_info_plus.sort_values("patch"), x="patch", @@ -832,9 +724,7 @@ def make(self, key): .dropna() .reset_index(drop=True) ) - weights_block.drop( - columns=["experiment_name", "block_start"], inplace=True, errors="ignore" - ) + weights_block.drop(columns=["experiment_name", "block_start"], inplace=True, errors="ignore") weights_block.rename(columns={"weight_timestamps": "time"}, inplace=True) weights_block.set_index("time", inplace=True) weights_block.sort_index(inplace=True) @@ -858,17 +748,13 @@ def make(self, key): # Figure 3 - Cumulative pellet count: over time, per subject, markered by patch # --- # Create dataframe with cumulative pellet count per subject - cum_pel_ct = ( - min_subj_patch_info_plus.sort_values("time").copy().reset_index(drop=True) - ) + cum_pel_ct = min_subj_patch_info_plus.sort_values("time").copy().reset_index(drop=True) patch_means = cum_pel_ct.loc[0:3][["patch", "threshold"]].rename( columns={"threshold": "mean_thresh"} ) patch_means["mean_thresh"] = patch_means["mean_thresh"].astype(float).round(1) cum_pel_ct = cum_pel_ct.merge(patch_means, on="patch", how="left") - cum_pel_ct = cum_pel_ct[ - ~cum_pel_ct["subject"].str.contains("mean") - ].reset_index(drop=True) + cum_pel_ct = cum_pel_ct[~cum_pel_ct["subject"].str.contains("mean")].reset_index(drop=True) cum_pel_ct = ( cum_pel_ct.groupby("subject", group_keys=False) .apply(lambda group: group.assign(counter=np.arange(len(group)) + 1)) @@ -878,9 +764,7 @@ def make(self, key): make_float_cols = ["threshold", "mean_thresh", "norm_time"] cum_pel_ct[make_float_cols] = cum_pel_ct[make_float_cols].astype(float) cum_pel_ct["patch_label"] = ( - cum_pel_ct["patch"] - + " μ: " - + cum_pel_ct["mean_thresh"].astype(float).round(1).astype(str) + cum_pel_ct["patch"] + " μ: " + cum_pel_ct["mean_thresh"].astype(float).round(1).astype(str) ) cum_pel_ct["norm_thresh_val"] = ( (cum_pel_ct["threshold"] - cum_pel_ct["threshold"].min()) @@ -910,9 +794,7 @@ def make(self, key): mode="markers", marker={ "symbol": patch_markers_dict[patch_grp["patch"].iloc[0]], - "color": gen_hex_grad( - pel_mrkr_col, patch_grp["norm_thresh_val"] - ), + "color": gen_hex_grad(pel_mrkr_col, patch_grp["norm_thresh_val"]), "size": 8, }, name=patch_val, @@ -932,9 +814,7 @@ def make(self, key): cum_pel_per_subject_fig = go.Figure() for id_val, id_grp in cum_pel_ct.groupby("subject"): for patch_val, patch_grp in id_grp.groupby("patch"): - cur_p_mean = patch_means[patch_means["patch"] == patch_val][ - "mean_thresh" - ].values[0] + cur_p_mean = patch_means[patch_means["patch"] == patch_val]["mean_thresh"].values[0] cur_p = patch_val.replace("Patch", "P") cum_pel_per_subject_fig.add_trace( go.Scatter( @@ -949,9 +829,7 @@ def make(self, key): # line=dict(width=2, color=subject_colors_dict[id_val]), marker={ "symbol": patch_markers_dict[patch_val], - "color": gen_hex_grad( - pel_mrkr_col, patch_grp["norm_thresh_val"] - ), + "color": gen_hex_grad(pel_mrkr_col, patch_grp["norm_thresh_val"]), "size": 8, }, name=f"{id_val} - {cur_p} - μ: {cur_p_mean}", @@ -968,9 +846,7 @@ def make(self, key): # Figure 5 - Cumulative wheel distance: over time, per subject-patch # --- # Get wheel timestamps for each patch - wheel_ts = (BlockAnalysis.Patch & key).fetch( - "patch_name", "wheel_timestamps", as_dict=True - ) + wheel_ts = (BlockAnalysis.Patch & key).fetch("patch_name", "wheel_timestamps", as_dict=True) wheel_ts = {d["patch_name"]: d["wheel_timestamps"] for d in wheel_ts} # Get subject patch data subj_wheel_cumsum_dist = (BlockSubjectAnalysis.Patch & key).fetch( @@ -990,9 +866,7 @@ def make(self, key): for subj in subject_names: for patch_name in patch_names: cur_cum_wheel_dist = subj_wheel_cumsum_dist[(subj, patch_name)] - cur_p_mean = patch_means[patch_means["patch"] == patch_name][ - "mean_thresh" - ].values[0] + cur_p_mean = patch_means[patch_means["patch"] == patch_name]["mean_thresh"].values[0] cur_p = patch_name.replace("Patch", "P") cum_wheel_dist_fig.add_trace( go.Scatter( @@ -1009,10 +883,7 @@ def make(self, key): ) # Add markers for each pellet cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[ - (cum_pel_ct["subject"] == subj) - & (cum_pel_ct["patch"] == patch_name) - ], + cum_pel_ct[(cum_pel_ct["subject"] == subj) & (cum_pel_ct["patch"] == patch_name)], pd.DataFrame( { "time": wheel_ts[patch_name], @@ -1031,15 +902,11 @@ def make(self, key): mode="markers", marker={ "symbol": patch_markers_dict[patch_name], - "color": gen_hex_grad( - pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"] - ), + "color": gen_hex_grad(pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"]), "size": 8, }, name=f"{subj} - {cur_p} pellets", - customdata=np.stack( - (cur_cum_pel_ct["threshold"],), axis=-1 - ), + customdata=np.stack((cur_cum_pel_ct["threshold"],), axis=-1), hovertemplate="Threshold: %{customdata[0]:.2f} cm", ) ) @@ -1053,14 +920,10 @@ def make(self, key): # --- # Get and format a dataframe with preference data patch_pref = (BlockSubjectAnalysis.Preference & key).fetch(format="frame") - patch_pref.reset_index( - level=["experiment_name", "block_start"], drop=True, inplace=True - ) + patch_pref.reset_index(level=["experiment_name", "block_start"], drop=True, inplace=True) # Replace small vals with 0 small_pref_thresh = 1e-3 - patch_pref["cumulative_preference_by_wheel"] = patch_pref[ - "cumulative_preference_by_wheel" - ].apply( + patch_pref["cumulative_preference_by_wheel"] = patch_pref["cumulative_preference_by_wheel"].apply( lambda arr: np.where(np.array(arr) < small_pref_thresh, 0, np.array(arr)) ) @@ -1068,9 +931,7 @@ def calculate_running_preference(group, pref_col, out_col): # Sum pref at each ts total_pref = np.sum(np.vstack(group[pref_col].values), axis=0) # Calculate running pref - group[out_col] = group[pref_col].apply( - lambda x: np.nan_to_num(x / total_pref, 0.0) - ) + group[out_col] = group[pref_col].apply(lambda x: np.nan_to_num(x / total_pref, 0.0)) return group patch_pref = ( @@ -1099,12 +960,8 @@ def calculate_running_preference(group, pref_col, out_col): # Add trace for each subject-patch combo for subj in subject_names: for patch_name in patch_names: - cur_run_wheel_pref = patch_pref.loc[patch_name].loc[subj][ - "running_preference_by_wheel" - ] - cur_p_mean = patch_means[patch_means["patch"] == patch_name][ - "mean_thresh" - ].values[0] + cur_run_wheel_pref = patch_pref.loc[patch_name].loc[subj]["running_preference_by_wheel"] + cur_p_mean = patch_means[patch_means["patch"] == patch_name]["mean_thresh"].values[0] cur_p = patch_name.replace("Patch", "P") running_pref_by_wheel_plot.add_trace( go.Scatter( @@ -1121,10 +978,7 @@ def calculate_running_preference(group, pref_col, out_col): ) # Add markers for each pellet cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[ - (cum_pel_ct["subject"] == subj) - & (cum_pel_ct["patch"] == patch_name) - ], + cum_pel_ct[(cum_pel_ct["subject"] == subj) & (cum_pel_ct["patch"] == patch_name)], pd.DataFrame( { "time": wheel_ts[patch_name], @@ -1143,15 +997,11 @@ def calculate_running_preference(group, pref_col, out_col): mode="markers", marker={ "symbol": patch_markers_dict[patch_name], - "color": gen_hex_grad( - pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"] - ), + "color": gen_hex_grad(pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"]), "size": 8, }, name=f"{subj} - {cur_p} pellets", - customdata=np.stack( - (cur_cum_pel_ct["threshold"],), axis=-1 - ), + customdata=np.stack((cur_cum_pel_ct["threshold"],), axis=-1), hovertemplate="Threshold: %{customdata[0]:.2f} cm", ) ) @@ -1167,12 +1017,8 @@ def calculate_running_preference(group, pref_col, out_col): # Add trace for each subject-patch combo for subj in subject_names: for patch_name in patch_names: - cur_run_time_pref = patch_pref.loc[patch_name].loc[subj][ - "running_preference_by_time" - ] - cur_p_mean = patch_means[patch_means["patch"] == patch_name][ - "mean_thresh" - ].values[0] + cur_run_time_pref = patch_pref.loc[patch_name].loc[subj]["running_preference_by_time"] + cur_p_mean = patch_means[patch_means["patch"] == patch_name]["mean_thresh"].values[0] cur_p = patch_name.replace("Patch", "P") running_pref_by_patch_fig.add_trace( go.Scatter( @@ -1189,10 +1035,7 @@ def calculate_running_preference(group, pref_col, out_col): ) # Add markers for each pellet cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[ - (cum_pel_ct["subject"] == subj) - & (cum_pel_ct["patch"] == patch_name) - ], + cum_pel_ct[(cum_pel_ct["subject"] == subj) & (cum_pel_ct["patch"] == patch_name)], pd.DataFrame( { "time": wheel_ts[patch_name], @@ -1211,15 +1054,11 @@ def calculate_running_preference(group, pref_col, out_col): mode="markers", marker={ "symbol": patch_markers_dict[patch_name], - "color": gen_hex_grad( - pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"] - ), + "color": gen_hex_grad(pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"]), "size": 8, }, name=f"{subj} - {cur_p} pellets", - customdata=np.stack( - (cur_cum_pel_ct["threshold"],), axis=-1 - ), + customdata=np.stack((cur_cum_pel_ct["threshold"],), axis=-1), hovertemplate="Threshold: %{customdata[0]:.2f} cm", ) ) @@ -1233,9 +1072,7 @@ def calculate_running_preference(group, pref_col, out_col): # Figure 8 - Weighted patch preference: weighted by 'wheel_dist_spun : pel_ct' ratio # --- # Create multi-indexed dataframe with weighted distance for each subject-patch pair - pel_patches = [ - p for p in patch_names if "dummy" not in p.lower() - ] # exclude dummy patches + pel_patches = [p for p in patch_names if "dummy" not in p.lower()] # exclude dummy patches data = [] for patch in pel_patches: for subject in subject_names: @@ -1248,16 +1085,12 @@ def calculate_running_preference(group, pref_col, out_col): } ) subj_wheel_pel_weighted_dist = pd.DataFrame(data) - subj_wheel_pel_weighted_dist.set_index( - ["patch_name", "subject_name"], inplace=True - ) + subj_wheel_pel_weighted_dist.set_index(["patch_name", "subject_name"], inplace=True) subj_wheel_pel_weighted_dist["weighted_dist"] = np.nan # Calculate weighted distance subject_patch_data = (BlockSubjectAnalysis.Patch() & key).fetch(format="frame") - subject_patch_data.reset_index( - level=["experiment_name", "block_start"], drop=True, inplace=True - ) + subject_patch_data.reset_index(level=["experiment_name", "block_start"], drop=True, inplace=True) subj_wheel_pel_weighted_dist = defaultdict(lambda: defaultdict(dict)) for s in subject_names: for p in pel_patches: @@ -1265,14 +1098,11 @@ def calculate_running_preference(group, pref_col, out_col): cur_wheel_cum_dist_df = pd.DataFrame(columns=["time", "cum_wheel_dist"]) cur_wheel_cum_dist_df["time"] = wheel_ts[p] cur_wheel_cum_dist_df["cum_wheel_dist"] = ( - subject_patch_data.loc[p].loc[s]["wheel_cumsum_distance_travelled"] - + 1 + subject_patch_data.loc[p].loc[s]["wheel_cumsum_distance_travelled"] + 1 ) # Get cumulative pellet count cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[ - (cum_pel_ct["subject"] == s) & (cum_pel_ct["patch"] == p) - ], + cum_pel_ct[(cum_pel_ct["subject"] == s) & (cum_pel_ct["patch"] == p)], cur_wheel_cum_dist_df.sort_values("time"), on="time", direction="forward", @@ -1291,9 +1121,7 @@ def calculate_running_preference(group, pref_col, out_col): on="time", direction="forward", ) - max_weight = ( - cur_cum_pel_ct.iloc[-1]["counter"] + 1 - ) # for values after last pellet + max_weight = cur_cum_pel_ct.iloc[-1]["counter"] + 1 # for values after last pellet merged_df["counter"] = merged_df["counter"].fillna(max_weight) merged_df["weighted_cum_wheel_dist"] = ( merged_df.groupby("counter") @@ -1304,9 +1132,7 @@ def calculate_running_preference(group, pref_col, out_col): else: weighted_dist = cur_wheel_cum_dist_df["cum_wheel_dist"].values # Assign to dict - subj_wheel_pel_weighted_dist[p][s]["time"] = cur_wheel_cum_dist_df[ - "time" - ].values + subj_wheel_pel_weighted_dist[p][s]["time"] = cur_wheel_cum_dist_df["time"].values subj_wheel_pel_weighted_dist[p][s]["weighted_dist"] = weighted_dist # Convert back to dataframe data = [] @@ -1317,15 +1143,11 @@ def calculate_running_preference(group, pref_col, out_col): "patch_name": p, "subject_name": s, "time": subj_wheel_pel_weighted_dist[p][s]["time"], - "weighted_dist": subj_wheel_pel_weighted_dist[p][s][ - "weighted_dist" - ], + "weighted_dist": subj_wheel_pel_weighted_dist[p][s]["weighted_dist"], } ) subj_wheel_pel_weighted_dist = pd.DataFrame(data) - subj_wheel_pel_weighted_dist.set_index( - ["patch_name", "subject_name"], inplace=True - ) + subj_wheel_pel_weighted_dist.set_index(["patch_name", "subject_name"], inplace=True) # Calculate normalized weighted value def norm_inv_norm(group): @@ -1334,28 +1156,20 @@ def norm_inv_norm(group): inv_norm_dist = 1 / norm_dist inv_norm_dist = inv_norm_dist / (np.sum(inv_norm_dist, axis=0)) # Map each inv_norm_dist back to patch name. - return pd.Series( - inv_norm_dist.tolist(), index=group.index, name="norm_value" - ) + return pd.Series(inv_norm_dist.tolist(), index=group.index, name="norm_value") subj_wheel_pel_weighted_dist["norm_value"] = ( subj_wheel_pel_weighted_dist.groupby("subject_name") .apply(norm_inv_norm) .reset_index(level=0, drop=True) ) - subj_wheel_pel_weighted_dist["wheel_pref"] = patch_pref[ - "running_preference_by_wheel" - ] + subj_wheel_pel_weighted_dist["wheel_pref"] = patch_pref["running_preference_by_wheel"] # Plot it weighted_patch_pref_fig = make_subplots( rows=len(pel_patches), cols=len(subject_names), - subplot_titles=[ - f"{patch} - {subject}" - for patch in pel_patches - for subject in subject_names - ], + subplot_titles=[f"{patch} - {subject}" for patch in pel_patches for subject in subject_names], specs=[[{"secondary_y": True}] * len(subject_names)] * len(pel_patches), shared_xaxes=True, vertical_spacing=0.1, @@ -1538,9 +1352,7 @@ def make(self, key): for id_val, id_grp in centroid_df.groupby("identity_name"): # Add counts of x,y points to a grid that will be used for heatmap img_grid = np.zeros((max_x + 1, max_y + 1)) - points, counts = np.unique( - id_grp[["x", "y"]].values, return_counts=True, axis=0 - ) + points, counts = np.unique(id_grp[["x", "y"]].values, return_counts=True, axis=0) for point, count in zip(points, counts, strict=True): img_grid[point[0], point[1]] = count img_grid /= img_grid.max() # normalize @@ -1549,9 +1361,7 @@ def make(self, key): # so 45 cm/frame ~= 9 px/frame win_sz = 9 # in pixels (ensure odd for centering) kernel = np.ones((win_sz, win_sz)) / win_sz**2 # moving avg kernel - img_grid_p = np.pad( - img_grid, win_sz // 2, mode="edge" - ) # pad for full output from convolution + img_grid_p = np.pad(img_grid, win_sz // 2, mode="edge") # pad for full output from convolution img_grid_smooth = conv2d(img_grid_p, kernel) heatmaps.append((id_val, img_grid_smooth)) @@ -1580,17 +1390,11 @@ def make(self, key): # Figure 3 - Position ethogram # --- # Get Active Region (ROI) locations - epoch_query = acquisition.Epoch & ( - acquisition.Chunk & key & chunk_restriction - ).proj("epoch_start") + epoch_query = acquisition.Epoch & (acquisition.Chunk & key & chunk_restriction).proj("epoch_start") active_region_query = acquisition.EpochConfig.ActiveRegion & epoch_query - roi_locs = dict( - zip(*active_region_query.fetch("region_name", "region_data"), strict=True) - ) + roi_locs = dict(zip(*active_region_query.fetch("region_name", "region_data"), strict=True)) # get RFID reader locations - recent_rfid_query = ( - acquisition.Experiment.proj() * streams.Device.proj() & key - ).aggr( + recent_rfid_query = (acquisition.Experiment.proj() * streams.Device.proj() & key).aggr( streams.RfidReader & f"rfid_reader_install_time <= '{block_start}'", rfid_reader_install_time="max(rfid_reader_install_time)", ) @@ -1628,30 +1432,18 @@ def make(self, key): # For each ROI, compute if within ROI for roi in rois: - if ( - roi == "Corridor" - ): # special case for corridor, based on between inner and outer radius + if roi == "Corridor": # special case for corridor, based on between inner and outer radius dist = np.linalg.norm( (np.vstack((centroid_df["x"], centroid_df["y"])).T) - arena_center, axis=1, ) - pos_eth_df[roi] = (dist >= arena_inner_radius) & ( - dist <= arena_outer_radius - ) + pos_eth_df[roi] = (dist >= arena_inner_radius) & (dist <= arena_outer_radius) elif roi == "Nest": # special case for nest, based on 4 corners nest_corners = roi_locs["NestRegion"]["ArrayOfPoint"] - nest_br_x, nest_br_y = int(nest_corners[0]["X"]), int( - nest_corners[0]["Y"] - ) - nest_bl_x, nest_bl_y = int(nest_corners[1]["X"]), int( - nest_corners[1]["Y"] - ) - nest_tl_x, nest_tl_y = int(nest_corners[2]["X"]), int( - nest_corners[2]["Y"] - ) - nest_tr_x, nest_tr_y = int(nest_corners[3]["X"]), int( - nest_corners[3]["Y"] - ) + nest_br_x, nest_br_y = int(nest_corners[0]["X"]), int(nest_corners[0]["Y"]) + nest_bl_x, nest_bl_y = int(nest_corners[1]["X"]), int(nest_corners[1]["Y"]) + nest_tl_x, nest_tl_y = int(nest_corners[2]["X"]), int(nest_corners[2]["Y"]) + nest_tr_x, nest_tr_y = int(nest_corners[3]["X"]), int(nest_corners[3]["Y"]) pos_eth_df[roi] = ( (centroid_df["x"] <= nest_br_x) & (centroid_df["y"] >= nest_br_y) @@ -1665,13 +1457,10 @@ def make(self, key): else: roi_radius = gate_radius if roi == "Gate" else patch_radius # Get ROI coords - roi_x, roi_y = int(rfid_locs[roi + "Rfid"]["X"]), int( - rfid_locs[roi + "Rfid"]["Y"] - ) + roi_x, roi_y = int(rfid_locs[roi + "Rfid"]["X"]), int(rfid_locs[roi + "Rfid"]["Y"]) # Check if in ROI dist = np.linalg.norm( - (np.vstack((centroid_df["x"], centroid_df["y"])).T) - - (roi_x, roi_y), + (np.vstack((centroid_df["x"], centroid_df["y"])).T) - (roi_x, roi_y), axis=1, ) pos_eth_df[roi] = dist < roi_radius @@ -1743,7 +1532,7 @@ class Bout(dj.Part): """ def make(self, key): - """ Compute and store foraging bouts for each subject in the block. """ + """Compute and store foraging bouts for each subject in the block.""" foraging_bout_df = get_foraging_bouts(key) foraging_bout_df.rename( columns={ @@ -1800,9 +1589,7 @@ def get_threshold_associated_pellets(patch_key, start, end): - offset - rate """ - chunk_restriction = acquisition.create_chunk_restriction( - patch_key["experiment_name"], start, end - ) + chunk_restriction = acquisition.create_chunk_restriction(patch_key["experiment_name"], start, end) # Step 1 - fetch data # pellet delivery trigger @@ -1810,9 +1597,9 @@ def get_threshold_associated_pellets(patch_key, start, end): streams.UndergroundFeederDeliverPellet & patch_key & chunk_restriction )[start:end] # beambreak - beambreak_df = fetch_stream( - streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction - )[start:end] + beambreak_df = fetch_stream(streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction)[ + start:end + ] # patch threshold depletion_state_df = fetch_stream( streams.UndergroundFeederDepletionState & patch_key & chunk_restriction @@ -1864,18 +1651,14 @@ def get_threshold_associated_pellets(patch_key, start, end): .set_index("time") .dropna(subset=["beam_break_timestamp"]) ) - pellet_beam_break_df.drop_duplicates( - subset="beam_break_timestamp", keep="last", inplace=True - ) + pellet_beam_break_df.drop_duplicates(subset="beam_break_timestamp", keep="last", inplace=True) # Find pellet delivery triggers that approximately coincide with each threshold update # i.e. nearest pellet delivery within 100ms before or after threshold update pellet_ts_threshold_df = ( pd.merge_asof( depletion_state_df.reset_index(), - pellet_beam_break_df.reset_index().rename( - columns={"time": "pellet_timestamp"} - ), + pellet_beam_break_df.reset_index().rename(columns={"time": "pellet_timestamp"}), left_on="time", right_on="pellet_timestamp", tolerance=pd.Timedelta("100ms"), @@ -1888,12 +1671,8 @@ def get_threshold_associated_pellets(patch_key, start, end): # Clean up the df pellet_ts_threshold_df = pellet_ts_threshold_df.drop(columns=["event_x", "event_y"]) # Shift back the pellet_timestamp values by 1 to match with the previous threshold update - pellet_ts_threshold_df.pellet_timestamp = ( - pellet_ts_threshold_df.pellet_timestamp.shift(-1) - ) - pellet_ts_threshold_df.beam_break_timestamp = ( - pellet_ts_threshold_df.beam_break_timestamp.shift(-1) - ) + pellet_ts_threshold_df.pellet_timestamp = pellet_ts_threshold_df.pellet_timestamp.shift(-1) + pellet_ts_threshold_df.beam_break_timestamp = pellet_ts_threshold_df.beam_break_timestamp.shift(-1) pellet_ts_threshold_df = pellet_ts_threshold_df.dropna( subset=["pellet_timestamp", "beam_break_timestamp"] ) @@ -1920,12 +1699,8 @@ def get_foraging_bouts( Returns: DataFrame containing foraging bouts. Columns: duration, n_pellets, cum_wheel_dist, subject. """ - max_inactive_time = ( - pd.Timedelta(seconds=60) if max_inactive_time is None else max_inactive_time - ) - bout_data = pd.DataFrame( - columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"] - ) + max_inactive_time = pd.Timedelta(seconds=60) if max_inactive_time is None else max_inactive_time + bout_data = pd.DataFrame(columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"]) subject_patch_data = (BlockSubjectAnalysis.Patch() & key).fetch(format="frame") if subject_patch_data.empty: return bout_data @@ -1969,52 +1744,34 @@ def get_foraging_bouts( wheel_s_r = pd.Timedelta(wheel_ts[1] - wheel_ts[0], unit="ns") max_inactive_win_len = int(max_inactive_time / wheel_s_r) # Find times when foraging - max_windowed_wheel_vals = ( - patch_spun_df["cum_wheel_dist"].shift(-(max_inactive_win_len - 1)).ffill() - ) - foraging_mask = max_windowed_wheel_vals > ( - patch_spun_df["cum_wheel_dist"] + min_wheel_movement - ) + max_windowed_wheel_vals = patch_spun_df["cum_wheel_dist"].shift(-(max_inactive_win_len - 1)).ffill() + foraging_mask = max_windowed_wheel_vals > (patch_spun_df["cum_wheel_dist"] + min_wheel_movement) # Discretize into foraging bouts - bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + ( - max_inactive_win_len - 1 - ) + bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + (max_inactive_win_len - 1) n_samples_in_1s = int(1 / wheel_s_r.total_seconds()) bout_end_indxs = ( np.where(np.diff(foraging_mask, prepend=0) == -1)[0] + (max_inactive_win_len - 1) + n_samples_in_1s ) - bout_end_indxs[-1] = min( - bout_end_indxs[-1], len(wheel_ts) - 1 - ) # ensure last bout ends in block + bout_end_indxs[-1] = min(bout_end_indxs[-1], len(wheel_ts) - 1) # ensure last bout ends in block # Remove bout that starts at block end if bout_start_indxs[-1] >= len(wheel_ts): bout_start_indxs = bout_start_indxs[:-1] bout_end_indxs = bout_end_indxs[:-1] if len(bout_start_indxs) != len(bout_end_indxs): - raise ValueError( - "Mismatch between the lengths of bout_start_indxs and bout_end_indxs." - ) - bout_durations = ( - wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs] - ).astype( # in seconds + raise ValueError("Mismatch between the lengths of bout_start_indxs and bout_end_indxs.") + bout_durations = (wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs]).astype( # in seconds "timedelta64[ns]" - ).astype( - float - ) / 1e9 + ).astype(float) / 1e9 bout_starts_ends = np.array( [ (wheel_ts[start_idx], wheel_ts[end_idx]) - for start_idx, end_idx in zip( - bout_start_indxs, bout_end_indxs, strict=True - ) + for start_idx, end_idx in zip(bout_start_indxs, bout_end_indxs, strict=True) ] ) all_pel_ts = np.sort( - np.concatenate( - [arr for arr in cur_subject_data["pellet_timestamps"] if len(arr) > 0] - ) + np.concatenate([arr for arr in cur_subject_data["pellet_timestamps"] if len(arr) > 0]) ) bout_pellets = np.array( [ @@ -2028,8 +1785,7 @@ def get_foraging_bouts( bout_pellets = bout_pellets[bout_pellets >= min_pellets] bout_cum_wheel_dist = np.array( [ - patch_spun_df.loc[end, "cum_wheel_dist"] - - patch_spun_df.loc[start, "cum_wheel_dist"] + patch_spun_df.loc[end, "cum_wheel_dist"] - patch_spun_df.loc[start, "cum_wheel_dist"] for start, end in bout_starts_ends ] ) diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index b8440be9..7c6e6077 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -70,15 +70,15 @@ class Visit(dj.Part): @property def key_source(self): """Key source for OverlapVisit.""" - return dj.U("experiment_name", "place", "overlap_start") & ( - Visit & VisitEnd - ).proj(overlap_start="visit_start") + return dj.U("experiment_name", "place", "overlap_start") & (Visit & VisitEnd).proj( + overlap_start="visit_start" + ) def make(self, key): """Populate OverlapVisit table with overlapping visits.""" - visit_starts, visit_ends = ( - Visit * VisitEnd & key & {"visit_start": key["overlap_start"]} - ).fetch("visit_start", "visit_end") + visit_starts, visit_ends = (Visit * VisitEnd & key & {"visit_start": key["overlap_start"]}).fetch( + "visit_start", "visit_end" + ) visit_start = min(visit_starts) visit_end = max(visit_ends) @@ -92,9 +92,7 @@ def make(self, key): if len(overlap_query) <= 1: break overlap_visits.extend( - overlap_query.proj(overlap_start=f'"{key["overlap_start"]}"').fetch( - as_dict=True - ) + overlap_query.proj(overlap_start=f'"{key["overlap_start"]}"').fetch(as_dict=True) ) visit_starts, visit_ends = overlap_query.fetch("visit_start", "visit_end") if visit_start == max(visit_starts) and visit_end == max(visit_ends): @@ -108,10 +106,7 @@ def make(self, key): { **key, "overlap_end": visit_end, - "overlap_duration": ( - visit_end - key["overlap_start"] - ).total_seconds() - / 3600, + "overlap_duration": (visit_end - key["overlap_start"]).total_seconds() / 3600, "subject_count": len({v["subject"] for v in overlap_visits}), } ) @@ -198,22 +193,16 @@ def ingest_environment_visits(experiment_names: list | None = None): def get_maintenance_periods(experiment_name, start, end): """Get maintenance periods for the specified experiment and time range.""" # get states from acquisition.Environment.EnvironmentState - chunk_restriction = acquisition.create_chunk_restriction( - experiment_name, start, end - ) + chunk_restriction = acquisition.create_chunk_restriction(experiment_name, start, end) state_query = ( - acquisition.Environment.EnvironmentState - & {"experiment_name": experiment_name} - & chunk_restriction + acquisition.Environment.EnvironmentState & {"experiment_name": experiment_name} & chunk_restriction ) env_state_df = fetch_stream(state_query)[start:end] if env_state_df.empty: return deque([]) env_state_df.reset_index(inplace=True) - env_state_df = env_state_df[ - env_state_df["state"].shift() != env_state_df["state"] - ].reset_index( + env_state_df = env_state_df[env_state_df["state"].shift() != env_state_df["state"]].reset_index( drop=True ) # remove duplicates and keep the first one # An experiment starts with visit start (anything before the first maintenance is experiment) @@ -229,12 +218,8 @@ def get_maintenance_periods(experiment_name, start, end): env_state_df = pd.concat([env_state_df, log_df_end]) env_state_df.reset_index(drop=True, inplace=True) - maintenance_starts = env_state_df.loc[ - env_state_df["state"] == "Maintenance", "time" - ].values - maintenance_ends = env_state_df.loc[ - env_state_df["state"] != "Maintenance", "time" - ].values + maintenance_starts = env_state_df.loc[env_state_df["state"] == "Maintenance", "time"].values + maintenance_ends = env_state_df.loc[env_state_df["state"] != "Maintenance", "time"].values return deque( [ @@ -251,9 +236,7 @@ def filter_out_maintenance_periods(data_df, maintenance_period, end_time, dropna (maintenance_start, maintenance_end) = maint_period[0] if end_time < maintenance_start: # no more maintenance for this date break - maintenance_filter = (data_df.index >= maintenance_start) & ( - data_df.index <= maintenance_end - ) + maintenance_filter = (data_df.index >= maintenance_start) & (data_df.index <= maintenance_end) data_df[maintenance_filter] = np.nan if end_time >= maintenance_end: # remove this range maint_period.popleft() diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index c36cccbe..88922f05 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -91,8 +91,7 @@ def key_source(self): + chunk starts after visit_start and ends before visit_end (or NOW() - i.e. ongoing visits). """ return ( - Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") - * acquisition.Chunk + Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") * acquisition.Chunk & acquisition.SubjectEnterExit & [ "visit_start BETWEEN chunk_start AND chunk_end", @@ -104,9 +103,7 @@ def key_source(self): def make(self, key): """Populate VisitSubjectPosition for each visit""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") # -- Determine the time to start time_slicing in this chunk if chunk_start < key["visit_start"] < chunk_end: @@ -175,12 +172,8 @@ def make(self, key): end_time = np.array(end_time, dtype="datetime64[ns]") while time_slice_start < end_time: - time_slice_end = time_slice_start + min( - self._time_slice_duration, end_time - time_slice_start - ) - in_time_slice = np.logical_and( - timestamps >= time_slice_start, timestamps < time_slice_end - ) + time_slice_end = time_slice_start + min(self._time_slice_duration, end_time - time_slice_start) + in_time_slice = np.logical_and(timestamps >= time_slice_start, timestamps < time_slice_end) chunk_time_slices.append( { **key, @@ -203,14 +196,9 @@ def get_position(cls, visit_key=None, subject=None, start=None, end=None): """Given a key to a single Visit, return a Pandas DataFrame for the position data of the subject for the specified Visit time period.""" if visit_key is not None: if len(Visit & visit_key) != 1: - raise ValueError( - "The `visit_key` must correspond to exactly one Visit." - ) + raise ValueError("The `visit_key` must correspond to exactly one Visit.") start, end = ( - Visit.join(VisitEnd, left=True).proj( - visit_end="IFNULL(visit_end, NOW())" - ) - & visit_key + Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") & visit_key ).fetch1("visit_start", "visit_end") subject = visit_key["subject"] elif all((subject, start, end)): @@ -271,9 +259,7 @@ class FoodPatch(dj.Part): """ # Work on finished visits - key_source = Visit & ( - VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end" - ) + key_source = Visit & (VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end") def make(self, key): """Populate VisitTimeDistribution for each visit""" @@ -281,9 +267,7 @@ def make(self, key): visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) ) - maintenance_period = get_maintenance_periods( - key["experiment_name"], visit_start, visit_end - ) + maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) for visit_date in visit_dates: day_start = datetime.datetime.combine(visit_date.date(), time.min) @@ -303,16 +287,12 @@ def make(self, key): subject=key["subject"], start=day_start, end=day_end ) # filter out maintenance period based on logs - position = filter_out_maintenance_periods( - position, maintenance_period, day_end - ) + position = filter_out_maintenance_periods(position, maintenance_period, day_end) # filter for objects of the correct size valid_position = (position.area > 0) & (position.area < 1000) position[~valid_position] = np.nan - position.rename( - columns={"position_x": "x", "position_y": "y"}, inplace=True - ) + position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) # in corridor distance_from_center = tracking.compute_distance( position[["x", "y"]], @@ -356,9 +336,9 @@ def make(self, key): in_food_patch_times = [] for food_patch_key in food_patch_keys: # wheel data - food_patch_description = ( - acquisition.ExperimentFoodPatch & food_patch_key - ).fetch1("food_patch_description") + food_patch_description = (acquisition.ExperimentFoodPatch & food_patch_key).fetch1( + "food_patch_description" + ) wheel_data = acquisition.FoodPatchWheel.get_wheel_data( experiment_name=key["experiment_name"], start=pd.Timestamp(day_start), @@ -367,12 +347,10 @@ def make(self, key): using_aeon_io=True, ) # filter out maintenance period based on logs - wheel_data = filter_out_maintenance_periods( - wheel_data, maintenance_period, day_end + wheel_data = filter_out_maintenance_periods(wheel_data, maintenance_period, day_end) + patch_position = (acquisition.ExperimentFoodPatch.Position & food_patch_key).fetch1( + "food_patch_position_x", "food_patch_position_y" ) - patch_position = ( - acquisition.ExperimentFoodPatch.Position & food_patch_key - ).fetch1("food_patch_position_x", "food_patch_position_y") in_patch = tracking.is_position_in_patch( position, patch_position, @@ -427,9 +405,7 @@ class FoodPatch(dj.Part): """ # Work on finished visits - key_source = Visit & ( - VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end" - ) + key_source = Visit & (VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end") def make(self, key): """Populate VisitSummary for each visit""" @@ -437,9 +413,7 @@ def make(self, key): visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) ) - maintenance_period = get_maintenance_periods( - key["experiment_name"], visit_start, visit_end - ) + maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) for visit_date in visit_dates: day_start = datetime.datetime.combine(visit_date.date(), time.min) @@ -460,18 +434,12 @@ def make(self, key): subject=key["subject"], start=day_start, end=day_end ) # filter out maintenance period based on logs - position = filter_out_maintenance_periods( - position, maintenance_period, day_end - ) + position = filter_out_maintenance_periods(position, maintenance_period, day_end) # filter for objects of the correct size valid_position = (position.area > 0) & (position.area < 1000) position[~valid_position] = np.nan - position.rename( - columns={"position_x": "x", "position_y": "y"}, inplace=True - ) - position_diff = np.sqrt( - np.square(np.diff(position.x)) + np.square(np.diff(position.y)) - ) + position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) + position_diff = np.sqrt(np.square(np.diff(position.x)) + np.square(np.diff(position.y))) total_distance_travelled = np.nansum(position_diff) # in food patches - loop through all in-use patches during this visit @@ -507,9 +475,9 @@ def make(self, key): dropna=True, ).index.values # wheel data - food_patch_description = ( - acquisition.ExperimentFoodPatch & food_patch_key - ).fetch1("food_patch_description") + food_patch_description = (acquisition.ExperimentFoodPatch & food_patch_key).fetch1( + "food_patch_description" + ) wheel_data = acquisition.FoodPatchWheel.get_wheel_data( experiment_name=key["experiment_name"], start=pd.Timestamp(day_start), @@ -518,9 +486,7 @@ def make(self, key): using_aeon_io=True, ) # filter out maintenance period based on logs - wheel_data = filter_out_maintenance_periods( - wheel_data, maintenance_period, day_end - ) + wheel_data = filter_out_maintenance_periods(wheel_data, maintenance_period, day_end) food_patch_statistics.append( { @@ -528,15 +494,11 @@ def make(self, key): **food_patch_key, "visit_date": visit_date.date(), "pellet_count": len(pellet_events), - "wheel_distance_travelled": wheel_data.distance_travelled.values[ - -1 - ], + "wheel_distance_travelled": wheel_data.distance_travelled.values[-1], } ) - total_pellet_count = np.sum( - [p["pellet_count"] for p in food_patch_statistics] - ) + total_pellet_count = np.sum([p["pellet_count"] for p in food_patch_statistics]) total_wheel_distance_travelled = np.sum( [p["wheel_distance_travelled"] for p in food_patch_statistics] ) @@ -570,10 +532,7 @@ class VisitForagingBout(dj.Computed): # Work on 24/7 experiments key_source = ( - Visit - & VisitSummary - & (VisitEnd & "visit_duration > 24") - & "experiment_name= 'exp0.2-r0'" + Visit & VisitSummary & (VisitEnd & "visit_duration > 24") & "experiment_name= 'exp0.2-r0'" ) * acquisition.ExperimentFoodPatch def make(self, key): @@ -581,17 +540,13 @@ def make(self, key): visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") # get in_patch timestamps - food_patch_description = (acquisition.ExperimentFoodPatch & key).fetch1( - "food_patch_description" - ) + food_patch_description = (acquisition.ExperimentFoodPatch & key).fetch1("food_patch_description") in_patch_times = np.concatenate( - ( - VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & key - ).fetch("in_patch", order_by="visit_date") - ) - maintenance_period = get_maintenance_periods( - key["experiment_name"], visit_start, visit_end + (VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & key).fetch( + "in_patch", order_by="visit_date" + ) ) + maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) in_patch_times = filter_out_maintenance_periods( pd.DataFrame( [[food_patch_description]] * len(in_patch_times), @@ -619,12 +574,8 @@ def make(self, key): .set_index("event_time") ) # TODO: handle multiple retries of pellet delivery - maintenance_period = get_maintenance_periods( - key["experiment_name"], visit_start, visit_end - ) - patch = filter_out_maintenance_periods( - patch, maintenance_period, visit_end, True - ) + maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) + patch = filter_out_maintenance_periods(patch, maintenance_period, visit_end, True) if len(in_patch_times): change_ind = ( @@ -640,9 +591,7 @@ def make(self, key): ts_array = in_patch_times[change_ind[i - 1] : change_ind[i]] wheel_start, wheel_end = ts_array[0], ts_array[-1] - if ( - wheel_start >= wheel_end - ): # skip if timestamps were misaligned or a single timestamp + if wheel_start >= wheel_end: # skip if timestamps were misaligned or a single timestamp continue wheel_data = acquisition.FoodPatchWheel.get_wheel_data( @@ -652,19 +601,14 @@ def make(self, key): patch_name=food_patch_description, using_aeon_io=True, ) - maintenance_period = get_maintenance_periods( - key["experiment_name"], visit_start, visit_end - ) - wheel_data = filter_out_maintenance_periods( - wheel_data, maintenance_period, visit_end, True - ) + maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) + wheel_data = filter_out_maintenance_periods(wheel_data, maintenance_period, visit_end, True) self.insert1( { **key, "bout_start": ts_array[0], "bout_end": ts_array[-1], - "bout_duration": (ts_array[-1] - ts_array[0]) - / np.timedelta64(1, "s"), + "bout_duration": (ts_array[-1] - ts_array[0]) / np.timedelta64(1, "s"), "wheel_distance_travelled": wheel_data.distance_travelled[-1], "pellet_count": len(patch.loc[wheel_start:wheel_end]), } diff --git a/aeon/dj_pipeline/create_experiments/create_experiment_01.py b/aeon/dj_pipeline/create_experiments/create_experiment_01.py index 5f765b3b..18edb4c3 100644 --- a/aeon/dj_pipeline/create_experiments/create_experiment_01.py +++ b/aeon/dj_pipeline/create_experiments/create_experiment_01.py @@ -35,10 +35,7 @@ def ingest_exp01_metadata(metadata_yml_filepath, experiment_name): & camera_key ) if current_camera_query: # If the same camera is currently installed - if ( - current_camera_query.fetch1("camera_install_time") - == arena_setup["start-time"] - ): + if current_camera_query.fetch1("camera_install_time") == arena_setup["start-time"]: # If it is installed at the same time as that read from this yml file # then it is the same ExperimentCamera instance, no need to do anything continue @@ -58,9 +55,7 @@ def ingest_exp01_metadata(metadata_yml_filepath, experiment_name): "experiment_name": experiment_name, "camera_install_time": arena_setup["start-time"], "camera_description": camera["description"], - "camera_sampling_rate": device_frequency_mapper[ - camera["trigger-source"].lower() - ], + "camera_sampling_rate": device_frequency_mapper[camera["trigger-source"].lower()], } ) acquisition.ExperimentCamera.Position.insert1( @@ -76,23 +71,17 @@ def ingest_exp01_metadata(metadata_yml_filepath, experiment_name): # ---- Load food patches ---- for patch in arena_setup["patches"]: # ---- Check if this is a new food patch, add to lab.FoodPatch if needed - patch_key = { - "food_patch_serial_number": patch["serial-number"] or patch["port-name"] - } + patch_key = {"food_patch_serial_number": patch["serial-number"] or patch["port-name"]} if patch_key not in lab.FoodPatch(): lab.FoodPatch.insert1(patch_key) # ---- Check if this food patch is currently installed - if so, remove it current_patch_query = ( - acquisition.ExperimentFoodPatch - - acquisition.ExperimentFoodPatch.RemovalTime + acquisition.ExperimentFoodPatch - acquisition.ExperimentFoodPatch.RemovalTime & {"experiment_name": experiment_name} & patch_key ) if current_patch_query: # If the same food-patch is currently installed - if ( - current_patch_query.fetch1("food_patch_install_time") - == arena_setup["start-time"] - ): + if current_patch_query.fetch1("food_patch_install_time") == arena_setup["start-time"]: # If it is installed at the same time as that read from this yml file # then it is the same ExperimentFoodPatch instance, no need to do anything continue @@ -127,21 +116,16 @@ def ingest_exp01_metadata(metadata_yml_filepath, experiment_name): ) # ---- Load weight scales ---- for weight_scale in arena_setup["weight-scales"]: - weight_scale_key = { - "weight_scale_serial_number": weight_scale["serial-number"] - } + weight_scale_key = {"weight_scale_serial_number": weight_scale["serial-number"]} if weight_scale_key not in lab.WeightScale(): lab.WeightScale.insert1(weight_scale_key) # ---- Check if this weight scale is currently installed - if so, remove it current_weight_scale_query = ( - acquisition.ExperimentWeightScale - - acquisition.ExperimentWeightScale.RemovalTime + acquisition.ExperimentWeightScale - acquisition.ExperimentWeightScale.RemovalTime & {"experiment_name": experiment_name} & weight_scale_key ) - if ( - current_weight_scale_query - ): # If the same weight scale is currently installed + if current_weight_scale_query: # If the same weight scale is currently installed if ( current_weight_scale_query.fetch1("weight_scale_install_time") == arena_setup["start-time"] @@ -271,12 +255,8 @@ def add_arena_setup(): # manually update coordinates of foodpatch and nest patch_coordinates = {"Patch1": (1.13, 1.59, 0), "Patch2": (1.19, 0.50, 0)} - for patch_key in ( - acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name} - ).fetch("KEY"): - patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1( - "food_patch_description" - ) + for patch_key in (acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name}).fetch("KEY"): + patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1("food_patch_description") x, y, z = patch_coordinates[patch] acquisition.ExperimentFoodPatch.Position.update1( { diff --git a/aeon/dj_pipeline/create_experiments/create_experiment_02.py b/aeon/dj_pipeline/create_experiments/create_experiment_02.py index e2965e91..c5aead5b 100644 --- a/aeon/dj_pipeline/create_experiments/create_experiment_02.py +++ b/aeon/dj_pipeline/create_experiments/create_experiment_02.py @@ -33,10 +33,7 @@ def create_new_experiment(): skip_duplicates=True, ) acquisition.Experiment.Subject.insert( - [ - {"experiment_name": experiment_name, "subject": s["subject"]} - for s in subject_list - ], + [{"experiment_name": experiment_name, "subject": s["subject"]} for s in subject_list], skip_duplicates=True, ) diff --git a/aeon/dj_pipeline/create_experiments/create_octagon_1.py b/aeon/dj_pipeline/create_experiments/create_octagon_1.py index 56a24613..1d95e1d5 100644 --- a/aeon/dj_pipeline/create_experiments/create_octagon_1.py +++ b/aeon/dj_pipeline/create_experiments/create_octagon_1.py @@ -36,10 +36,7 @@ def create_new_experiment(): skip_duplicates=True, ) acquisition.Experiment.Subject.insert( - [ - {"experiment_name": experiment_name, "subject": s["subject"]} - for s in subject_list - ], + [{"experiment_name": experiment_name, "subject": s["subject"]} for s in subject_list], skip_duplicates=True, ) diff --git a/aeon/dj_pipeline/create_experiments/create_presocial.py b/aeon/dj_pipeline/create_experiments/create_presocial.py index 0676dd7f..c66d7725 100644 --- a/aeon/dj_pipeline/create_experiments/create_presocial.py +++ b/aeon/dj_pipeline/create_experiments/create_presocial.py @@ -12,9 +12,7 @@ def create_new_experiment(): """Create new experiments for presocial0.1""" lab.Location.insert1({"lab": "SWC", "location": location}, skip_duplicates=True) - acquisition.ExperimentType.insert1( - {"experiment_type": experiment_type}, skip_duplicates=True - ) + acquisition.ExperimentType.insert1({"experiment_type": experiment_type}, skip_duplicates=True) acquisition.Experiment.insert( [ diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py index 6ece34b6..68925d29 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py @@ -40,9 +40,7 @@ def create_new_social_experiment(experiment_name): "experiment_name": experiment_name, "repository_name": "ceph_aeon", "directory_type": dir_type, - "directory_path": ( - ceph_data_dir / dir_type / machine_name.upper() / exp_name - ) + "directory_path": (ceph_data_dir / dir_type / machine_name.upper() / exp_name) .relative_to(ceph_dir) .as_posix(), "load_order": load_order, @@ -55,9 +53,7 @@ def create_new_social_experiment(experiment_name): new_experiment_entry, skip_duplicates=True, ) - acquisition.Experiment.Directory.insert( - experiment_directories, skip_duplicates=True - ) + acquisition.Experiment.Directory.insert(experiment_directories, skip_duplicates=True) acquisition.Experiment.DevicesSchema.insert1( { "experiment_name": experiment_name, diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py index 0b4b021f..497cc9e9 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py @@ -38,10 +38,7 @@ def create_new_experiment(): skip_duplicates=True, ) acquisition.Experiment.Subject.insert( - [ - {"experiment_name": experiment_name, "subject": s["subject"]} - for s in subject_list - ], + [{"experiment_name": experiment_name, "subject": s["subject"]} for s in subject_list], skip_duplicates=True, ) @@ -97,12 +94,8 @@ def add_arena_setup(): # manually update coordinates of foodpatch and nest patch_coordinates = {"Patch1": (1.13, 1.59, 0), "Patch2": (1.19, 0.50, 0)} - for patch_key in ( - acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name} - ).fetch("KEY"): - patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1( - "food_patch_description" - ) + for patch_key in (acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name}).fetch("KEY"): + patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1("food_patch_description") x, y, z = patch_coordinates[patch] acquisition.ExperimentFoodPatch.Position.update1( { @@ -161,11 +154,15 @@ def fixID(subjid, valid_ids=None, valid_id_file=None): # The subjid is a combo subjid. if ";" in subjid: subjidA, subjidB = subjid.split(";") - return f"{fixID(subjidA.strip(), valid_ids=valid_ids)};{fixID(subjidB.strip(), valid_ids=valid_ids)}" + return ( + f"{fixID(subjidA.strip(), valid_ids=valid_ids)};{fixID(subjidB.strip(), valid_ids=valid_ids)}" + ) if "vs" in subjid: subjidA, tmp, subjidB = subjid.split(" ")[1:] - return f"{fixID(subjidA.strip(), valid_ids=valid_ids)};{fixID(subjidB.strip(), valid_ids=valid_ids)}" + return ( + f"{fixID(subjidA.strip(), valid_ids=valid_ids)};{fixID(subjidB.strip(), valid_ids=valid_ids)}" + ) try: ld = [jl.levenshtein_distance(subjid, x[-len(subjid) :]) for x in valid_ids] diff --git a/aeon/dj_pipeline/populate/process.py b/aeon/dj_pipeline/populate/process.py index 049a3233..5c2e4d15 100644 --- a/aeon/dj_pipeline/populate/process.py +++ b/aeon/dj_pipeline/populate/process.py @@ -76,9 +76,7 @@ def run(**kwargs): try: worker.run() except Exception: - logger.exception( - "action '{}' encountered an exception:".format(kwargs["worker_name"]) - ) + logger.exception("action '{}' encountered an exception:".format(kwargs["worker_name"])) logger.info("Ingestion process ended.") diff --git a/aeon/dj_pipeline/populate/worker.py b/aeon/dj_pipeline/populate/worker.py index 35e93da6..c93a9bce 100644 --- a/aeon/dj_pipeline/populate/worker.py +++ b/aeon/dj_pipeline/populate/worker.py @@ -117,6 +117,4 @@ def get_workflow_operation_overview(): """Get the workflow operation overview for the worker schema.""" from datajoint_utilities.dj_worker.utils import get_workflow_operation_overview - return get_workflow_operation_overview( - worker_schema_name=worker_schema_name, db_prefixes=[db_prefix] - ) + return get_workflow_operation_overview(worker_schema_name=worker_schema_name, db_prefixes=[db_prefix]) diff --git a/aeon/dj_pipeline/qc.py b/aeon/dj_pipeline/qc.py index 9cc6bc63..5fa101ef 100644 --- a/aeon/dj_pipeline/qc.py +++ b/aeon/dj_pipeline/qc.py @@ -62,9 +62,7 @@ def key_source(self): return ( acquisition.Chunk * ( - streams.SpinnakerVideoSource.join( - streams.SpinnakerVideoSource.RemovalTime, left=True - ) + streams.SpinnakerVideoSource.join(streams.SpinnakerVideoSource.RemovalTime, left=True) & "spinnaker_video_source_name='CameraTop'" ) & "chunk_start >= spinnaker_video_source_install_time" @@ -73,21 +71,16 @@ def key_source(self): def make(self, key): """Perform quality control checks on the CameraTop stream""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - device_name = (streams.SpinnakerVideoSource & key).fetch1( - "spinnaker_video_source_name" - ) + device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name") data_dirs = acquisition.Experiment.get_data_directories(key) devices_schema = getattr( acquisition.aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(devices_schema, device_name).Video @@ -114,11 +107,9 @@ def make(self, key): **key, "drop_count": deltas.frame_offset.iloc[-1], "max_harp_delta": deltas.time_delta.max().total_seconds(), - "max_camera_delta": deltas.hw_timestamp_delta.max() - / 1e9, # convert to seconds + "max_camera_delta": deltas.hw_timestamp_delta.max() / 1e9, # convert to seconds "timestamps": videodata.index.values, - "time_delta": deltas.time_delta.values - / np.timedelta64(1, "s"), # convert to seconds + "time_delta": deltas.time_delta.values / np.timedelta64(1, "s"), # convert to seconds "frame_delta": deltas.frame_delta.values, "hw_counter_delta": deltas.hw_counter_delta.values, "hw_timestamp_delta": deltas.hw_timestamp_delta.values, diff --git a/aeon/dj_pipeline/report.py b/aeon/dj_pipeline/report.py index 03f8f89c..d09bfae2 100644 --- a/aeon/dj_pipeline/report.py +++ b/aeon/dj_pipeline/report.py @@ -32,9 +32,7 @@ class InArenaSummaryPlot(dj.Computed): summary_plot_png: attach """ - key_source = ( - analysis.InArena & analysis.InArenaTimeDistribution & analysis.InArenaSummary - ) + key_source = analysis.InArena & analysis.InArenaTimeDistribution & analysis.InArenaSummary color_code = { "Patch1": "b", @@ -46,17 +44,15 @@ class InArenaSummaryPlot(dj.Computed): def make(self, key): """Make method for InArenaSummaryPlot table""" - in_arena_start, in_arena_end = ( - analysis.InArena * analysis.InArenaEnd & key - ).fetch1("in_arena_start", "in_arena_end") + in_arena_start, in_arena_end = (analysis.InArena * analysis.InArenaEnd & key).fetch1( + "in_arena_start", "in_arena_end" + ) # subject's position data in the time_slices position = analysis.InArenaSubjectPosition.get_position(key) position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) - position_minutes_elapsed = ( - position.index - in_arena_start - ).total_seconds() / 60 + position_minutes_elapsed = (position.index - in_arena_start).total_seconds() / 60 # figure fig = plt.figure(figsize=(20, 9)) @@ -71,16 +67,12 @@ def make(self, key): # position plot non_nan = np.logical_and(~np.isnan(position.x), ~np.isnan(position.y)) - analysis_plotting.heatmap( - position[non_nan], 50, ax=position_ax, bins=500, alpha=0.5 - ) + analysis_plotting.heatmap(position[non_nan], 50, ax=position_ax, bins=500, alpha=0.5) # event rate plots in_arena_food_patches = ( analysis.InArena - * acquisition.ExperimentFoodPatch.join( - acquisition.ExperimentFoodPatch.RemovalTime, left=True - ) + * acquisition.ExperimentFoodPatch.join(acquisition.ExperimentFoodPatch.RemovalTime, left=True) & key & "in_arena_start >= food_patch_install_time" & 'in_arena_start < IFNULL(food_patch_remove_time, "2200-01-01")' @@ -147,9 +139,7 @@ def make(self, key): color=self.color_code[food_patch_key["food_patch_description"]], alpha=0.3, ) - threshold_change_ind = np.where( - wheel_threshold[:-1] != wheel_threshold[1:] - )[0] + threshold_change_ind = np.where(wheel_threshold[:-1] != wheel_threshold[1:])[0] threshold_ax.vlines( wheel_time[threshold_change_ind + 1], ymin=wheel_threshold[threshold_change_ind], @@ -161,20 +151,17 @@ def make(self, key): ) # ethogram - in_arena, in_corridor, arena_time, corridor_time = ( - analysis.InArenaTimeDistribution & key - ).fetch1( + in_arena, in_corridor, arena_time, corridor_time = (analysis.InArenaTimeDistribution & key).fetch1( "in_arena", "in_corridor", "time_fraction_in_arena", "time_fraction_in_corridor", ) - nest_keys, in_nests, nests_times = ( - analysis.InArenaTimeDistribution.Nest & key - ).fetch("KEY", "in_nest", "time_fraction_in_nest") + nest_keys, in_nests, nests_times = (analysis.InArenaTimeDistribution.Nest & key).fetch( + "KEY", "in_nest", "time_fraction_in_nest" + ) patch_names, in_patches, patches_times = ( - analysis.InArenaTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch - & key + analysis.InArenaTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & key ).fetch("food_patch_description", "in_patch", "time_fraction_in_patch") ethogram_ax.plot( @@ -205,9 +192,7 @@ def make(self, key): alpha=0.6, label="nest", ) - for patch_idx, (patch_name, in_patch) in enumerate( - zip(patch_names, in_patches) - ): + for patch_idx, (patch_name, in_patch) in enumerate(zip(patch_names, in_patches)): ethogram_ax.plot( position_minutes_elapsed[in_patch], np.full_like(position_minutes_elapsed[in_patch], (patch_idx + 3)), @@ -248,9 +233,7 @@ def make(self, key): rate_ax.set_title("foraging rate (bin size = 10 min)") distance_ax.set_ylabel("distance travelled (m)") threshold_ax.set_ylabel("threshold") - threshold_ax.set_ylim( - [threshold_ax.get_ylim()[0] - 100, threshold_ax.get_ylim()[1] + 100] - ) + threshold_ax.set_ylim([threshold_ax.get_ylim()[0] - 100, threshold_ax.get_ylim()[1] + 100]) ethogram_ax.set_xlabel("time (min)") analysis_plotting.set_ymargin(distance_ax, 0.2, 0.1) for ax in (rate_ax, distance_ax, pellet_ax, time_dist_ax, threshold_ax): @@ -279,9 +262,7 @@ def make(self, key): # ---- Save fig and insert ---- save_dir = _make_path(key) - fig_dict = _save_figs( - (fig,), ("summary_plot_png",), save_dir=save_dir, prefix=save_dir.name - ) + fig_dict = _save_figs((fig,), ("summary_plot_png",), save_dir=save_dir, prefix=save_dir.name) self.insert1({**key, **fig_dict}) @@ -450,10 +431,7 @@ class VisitDailySummaryPlot(dj.Computed): """ key_source = ( - Visit - & analysis.VisitSummary - & (VisitEnd & "visit_duration > 24") - & "experiment_name= 'exp0.2-r0'" + Visit & analysis.VisitSummary & (VisitEnd & "visit_duration > 24") & "experiment_name= 'exp0.2-r0'" ) def make(self, key): @@ -562,12 +540,7 @@ def _make_path(in_arena_key): experiment_name, subject, in_arena_start = (analysis.InArena & in_arena_key).fetch1( "experiment_name", "subject", "in_arena_start" ) - output_dir = ( - store_stage - / experiment_name - / subject - / in_arena_start.strftime("%y%m%d_%H%M%S_%f") - ) + output_dir = store_stage / experiment_name / subject / in_arena_start.strftime("%y%m%d_%H%M%S_%f") output_dir.mkdir(parents=True, exist_ok=True) return output_dir diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py index 4521b75d..024d0900 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py @@ -108,8 +108,7 @@ def validate(): target_entry_count = len(target_tbl()) missing_entries[orig_schema_name][source_tbl.__name__] = { "entry_count_diff": source_entry_count - target_entry_count, - "db_size_diff": source_tbl().size_on_disk - - target_tbl().size_on_disk, + "db_size_diff": source_tbl().size_on_disk - target_tbl().size_on_disk, } return { diff --git a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py index 18a89bc7..9cd845ef 100644 --- a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py +++ b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py @@ -12,9 +12,7 @@ from tqdm import tqdm if dj.__version__ < "0.13.7": - raise ImportError( - f"DataJoint version must be at least 0.13.7, but found {dj.__version__}." - ) + raise ImportError(f"DataJoint version must be at least 0.13.7, but found {dj.__version__}.") schema = dj.schema("u_thinh_aeonfix") @@ -42,13 +40,7 @@ def main(): for schema_name in schema_names: vm = dj.create_virtual_module(schema_name, schema_name) table_names = [ - ".".join( - [ - dj.utils.to_camel_case(s) - for s in tbl_name.strip("`").split("__") - if s - ] - ) + ".".join([dj.utils.to_camel_case(s) for s in tbl_name.strip("`").split("__") if s]) for tbl_name in vm.schema.list_tables() ] for table_name in table_names: diff --git a/aeon/dj_pipeline/streams.py b/aeon/dj_pipeline/streams.py index 225e7198..97cbc10d 100644 --- a/aeon/dj_pipeline/streams.py +++ b/aeon/dj_pipeline/streams.py @@ -190,9 +190,7 @@ def key_source(self): def make(self, key): """Load and insert RfidEvents data stream for a given chunk and RfidReader.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -200,10 +198,9 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(getattr(devices_schema, device_name), "RfidEvents") @@ -249,17 +246,14 @@ def key_source(self): + Chunk(s) that started after SpinnakerVideoSource install time for SpinnakerVideoSource that are not yet removed """ return ( - acquisition.Chunk - * SpinnakerVideoSource.join(SpinnakerVideoSource.RemovalTime, left=True) + acquisition.Chunk * SpinnakerVideoSource.join(SpinnakerVideoSource.RemovalTime, left=True) & "chunk_start >= spinnaker_video_source_install_time" & 'chunk_start < IFNULL(spinnaker_video_source_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert Video data stream for a given chunk and SpinnakerVideoSource.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -267,10 +261,9 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(getattr(devices_schema, device_name), "Video") @@ -315,17 +308,14 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert BeamBreak data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -333,10 +323,9 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(getattr(devices_schema, device_name), "BeamBreak") @@ -381,17 +370,14 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert DeliverPellet data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -399,10 +385,9 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(getattr(devices_schema, device_name), "DeliverPellet") @@ -449,17 +434,14 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert DepletionState data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -467,10 +449,9 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(getattr(devices_schema, device_name), "DepletionState") @@ -516,17 +497,14 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert Encoder data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -534,10 +512,9 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(getattr(devices_schema, device_name), "Encoder") @@ -582,17 +559,14 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert ManualDelivery data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -600,10 +574,9 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(getattr(devices_schema, device_name), "ManualDelivery") @@ -648,17 +621,14 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert MissedPellet data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -666,10 +636,9 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(getattr(devices_schema, device_name), "MissedPellet") @@ -714,17 +683,14 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert RetriedDelivery data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -732,10 +698,9 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(getattr(devices_schema, device_name), "RetriedDelivery") @@ -788,9 +753,7 @@ def key_source(self): def make(self, key): """Load and insert WeightFiltered data stream for a given chunk and WeightScale.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -798,10 +761,9 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(getattr(devices_schema, device_name), "WeightFiltered") @@ -854,9 +816,7 @@ def key_source(self): def make(self, key): """Load and insert WeightRaw data stream for a given chunk and WeightScale.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) @@ -864,10 +824,9 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(getattr(devices_schema, device_name), "WeightRaw") diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index 70812dcc..f6896c2d 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -86,9 +86,7 @@ def make(self, key): ) return elif len(animal_resp) > 1: - raise ValueError( - f"Found {len(animal_resp)} with eartag {eartag_or_id}, expect one" - ) + raise ValueError(f"Found {len(animal_resp)} with eartag {eartag_or_id}, expect one") else: animal_resp = animal_resp[0] @@ -187,21 +185,17 @@ def get_reference_weight(cls, subject_name): """Get the reference weight for the subject.""" subj_key = {"subject": subject_name} - food_restrict_query = ( - SubjectProcedure & subj_key & "procedure_name = 'R02 - food restriction'" - ) + food_restrict_query = SubjectProcedure & subj_key & "procedure_name = 'R02 - food restriction'" if food_restrict_query: - ref_date = food_restrict_query.fetch( - "procedure_date", order_by="procedure_date DESC", limit=1 - )[0] + ref_date = food_restrict_query.fetch("procedure_date", order_by="procedure_date DESC", limit=1)[ + 0 + ] else: ref_date = datetime.now(timezone.utc).date() weight_query = SubjectWeight & subj_key & f"weight_time < '{ref_date}'" ref_weight = ( - weight_query.fetch("weight", order_by="weight_time DESC", limit=1)[0] - if weight_query - else -1 + weight_query.fetch("weight", order_by="weight_time DESC", limit=1)[0] if weight_query else -1 ) entry = { @@ -259,9 +253,7 @@ def _auto_schedule(self): ): return - PyratIngestionTask.insert1( - {"pyrat_task_scheduled_time": next_task_schedule_time} - ) + PyratIngestionTask.insert1({"pyrat_task_scheduled_time": next_task_schedule_time}) def make(self, key): """Automatically import or update entries in the Subject table.""" @@ -269,15 +261,11 @@ def make(self, key): new_eartags = [] for responsible_id in lab.User.fetch("responsible_id"): # 1 - retrieve all animals from this user - animal_resp = get_pyrat_data( - endpoint="animals", params={"responsible_id": responsible_id} - ) + animal_resp = get_pyrat_data(endpoint="animals", params={"responsible_id": responsible_id}) for animal_entry in animal_resp: # 2 - find animal with comment - Project Aeon eartag_or_id = animal_entry["eartag_or_id"] - comment_resp = get_pyrat_data( - endpoint=f"animals/{eartag_or_id}/comments" - ) + comment_resp = get_pyrat_data(endpoint=f"animals/{eartag_or_id}/comments") for comment in comment_resp: if comment["attributes"]: first_attr = comment["attributes"][0] @@ -306,9 +294,7 @@ def make(self, key): { **key, "execution_time": execution_time, - "execution_duration": ( - completion_time - execution_time - ).total_seconds(), + "execution_duration": (completion_time - execution_time).total_seconds(), "new_pyrat_entry_count": new_entry_count, } ) @@ -354,9 +340,7 @@ def make(self, key): for cmt in comment_resp: cmt["subject"] = eartag_or_id cmt["attributes"] = json.dumps(cmt["attributes"], default=str) - SubjectComment.insert( - comment_resp, skip_duplicates=True, allow_direct_insert=True - ) + SubjectComment.insert(comment_resp, skip_duplicates=True, allow_direct_insert=True) weight_resp = get_pyrat_data(endpoint=f"animals/{eartag_or_id}/weights") SubjectWeight.insert( @@ -365,9 +349,7 @@ def make(self, key): allow_direct_insert=True, ) - procedure_resp = get_pyrat_data( - endpoint=f"animals/{eartag_or_id}/procedures" - ) + procedure_resp = get_pyrat_data(endpoint=f"animals/{eartag_or_id}/procedures") SubjectProcedure.insert( [{**v, "subject": eartag_or_id} for v in procedure_resp], skip_duplicates=True, @@ -382,9 +364,7 @@ def make(self, key): { **key, "execution_time": execution_time, - "execution_duration": ( - completion_time - execution_time - ).total_seconds(), + "execution_duration": (completion_time - execution_time).total_seconds(), } ) @@ -397,9 +377,7 @@ class CreatePyratIngestionTask(dj.Computed): def make(self, key): """Create one new PyratIngestionTask for every newly added users.""" - PyratIngestionTask.insert1( - {"pyrat_task_scheduled_time": datetime.now(timezone.utc)} - ) + PyratIngestionTask.insert1({"pyrat_task_scheduled_time": datetime.now(timezone.utc)}) time.sleep(1) self.insert1(key) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index e8333ad2..4969cf76 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -83,18 +83,14 @@ def insert_new_params( ): """Insert a new set of parameters for a given tracking method.""" if tracking_paramset_id is None: - tracking_paramset_id = ( - dj.U().aggr(cls, n="max(tracking_paramset_id)").fetch1("n") or 0 - ) + 1 + tracking_paramset_id = (dj.U().aggr(cls, n="max(tracking_paramset_id)").fetch1("n") or 0) + 1 param_dict = { "tracking_method": tracking_method, "tracking_paramset_id": tracking_paramset_id, "paramset_description": paramset_description, "params": params, - "param_set_hash": dict_to_uuid( - {**params, "tracking_method": tracking_method} - ), + "param_set_hash": dict_to_uuid({**params, "tracking_method": tracking_method}), } param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} @@ -157,9 +153,7 @@ def key_source(self): return ( acquisition.Chunk * ( - streams.SpinnakerVideoSource.join( - streams.SpinnakerVideoSource.RemovalTime, left=True - ) + streams.SpinnakerVideoSource.join(streams.SpinnakerVideoSource.RemovalTime, left=True) & "spinnaker_video_source_name='CameraTop'" ) * (TrackingParamSet & "tracking_paramset_id = 1") @@ -169,22 +163,17 @@ def key_source(self): def make(self, key): """Ingest SLEAP tracking data for a given chunk.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) - device_name = (streams.SpinnakerVideoSource & key).fetch1( - "spinnaker_video_source_name" - ) + device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name") devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(devices_schema, device_name).Pose @@ -196,9 +185,7 @@ def make(self, key): ) if not len(pose_data): - raise ValueError( - f"No SLEAP data found for {key['experiment_name']} - {device_name}" - ) + raise ValueError(f"No SLEAP data found for {key['experiment_name']} - {device_name}") # get identity names class_names = np.unique(pose_data.identity) @@ -231,9 +218,7 @@ def make(self, key): if part == anchor_part: identity_likelihood = part_position.identity_likelihood.values if isinstance(identity_likelihood[0], dict): - identity_likelihood = np.array( - [v[identity] for v in identity_likelihood] - ) + identity_likelihood = np.array([v[identity] for v in identity_likelihood]) pose_identity_entries.append( { @@ -278,9 +263,7 @@ def is_position_in_nest(position_df, nest_key, xcol="x", ycol="y") -> pd.Series: """Given the session key and the position data - arrays of x and y return an array of boolean indicating whether or not a position is inside the nest. """ - nest_vertices = list( - zip(*(lab.ArenaNest.Vertex & nest_key).fetch("vertex_x", "vertex_y")) - ) + nest_vertices = list(zip(*(lab.ArenaNest.Vertex & nest_key).fetch("vertex_x", "vertex_y"))) nest_path = matplotlib.path.Path(nest_vertices) position_df["in_nest"] = nest_path.contains_points(position_df[[xcol, ycol]]) return position_df["in_nest"] @@ -307,9 +290,7 @@ def _get_position( start_query = table & obj_restriction & start_restriction end_query = table & obj_restriction & end_restriction if not (start_query and end_query): - raise ValueError( - f"No position data found for {object_name} between {start} and {end}" - ) + raise ValueError(f"No position data found for {object_name} between {start} and {end}") time_restriction = ( f'{start_attr} >= "{min(start_query.fetch(start_attr))}"' @@ -317,14 +298,10 @@ def _get_position( ) # subject's position data in the time slice - fetched_data = (table & obj_restriction & time_restriction).fetch( - *fetch_attrs, order_by=start_attr - ) + fetched_data = (table & obj_restriction & time_restriction).fetch(*fetch_attrs, order_by=start_attr) if not len(fetched_data[0]): - raise ValueError( - f"No position data found for {object_name} between {start} and {end}" - ) + raise ValueError(f"No position data found for {object_name} between {start} and {end}") timestamp_attr = next(attr for attr in fetch_attrs if "timestamps" in attr) diff --git a/aeon/dj_pipeline/utils/load_metadata.py b/aeon/dj_pipeline/utils/load_metadata.py index 7c21ad8b..4c3b7315 100644 --- a/aeon/dj_pipeline/utils/load_metadata.py +++ b/aeon/dj_pipeline/utils/load_metadata.py @@ -42,9 +42,7 @@ def insert_stream_types(): else: # If the existed stream type does not have the same name: # human error, trying to add the same content with different name - raise dj.DataJointError( - f"The specified stream type already exists - name: {pname}" - ) + raise dj.DataJointError(f"The specified stream type already exists - name: {pname}") else: streams.StreamType.insert1(entry) @@ -57,9 +55,7 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): streams = dj.VirtualModule("streams", streams_maker.schema_name) device_info: dict[dict] = get_device_info(devices_schema) - device_type_mapper, device_sn = get_device_mapper( - devices_schema, metadata_yml_filepath - ) + device_type_mapper, device_sn = get_device_mapper(devices_schema, metadata_yml_filepath) # Add device type to device_info. Only add if device types that are defined in Metadata.yml device_info = { @@ -96,8 +92,7 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): {"device_type": device_type, "stream_type": stream_type} for device_type, stream_list in device_stream_map.items() for stream_type in stream_list - if not streams.DeviceType.Stream - & {"device_type": device_type, "stream_type": stream_type} + if not streams.DeviceType.Stream & {"device_type": device_type, "stream_type": stream_type} ] new_devices = [ @@ -106,8 +101,7 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): "device_type": device_config["device_type"], } for device_name, device_config in device_info.items() - if device_sn[device_name] - and not streams.Device & {"device_serial_number": device_sn[device_name]} + if device_sn[device_name] and not streams.Device & {"device_serial_number": device_sn[device_name]} ] # Insert new entries. @@ -125,9 +119,7 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): streams.Device.insert(new_devices) -def extract_epoch_config( - experiment_name: str, devices_schema: DotMap, metadata_yml_filepath: str -) -> dict: +def extract_epoch_config(experiment_name: str, devices_schema: DotMap, metadata_yml_filepath: str) -> dict: """Parse experiment metadata YAML file and extract epoch configuration. Args: @@ -139,9 +131,7 @@ def extract_epoch_config( dict: epoch_config [dict] """ metadata_yml_filepath = pathlib.Path(metadata_yml_filepath) - epoch_start = datetime.datetime.strptime( - metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S" - ) + epoch_start = datetime.datetime.strptime(metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S") epoch_config: dict = ( io_api.load( metadata_yml_filepath.parent.as_posix(), @@ -156,21 +146,15 @@ def extract_epoch_config( commit = epoch_config["metadata"]["Revision"] if not commit: - raise ValueError( - f'Neither "Commit" nor "Revision" found in {metadata_yml_filepath}' - ) + raise ValueError(f'Neither "Commit" nor "Revision" found in {metadata_yml_filepath}') devices: list[dict] = json.loads( - json.dumps( - epoch_config["metadata"]["Devices"], default=lambda x: x.__dict__, indent=4 - ) + json.dumps(epoch_config["metadata"]["Devices"], default=lambda x: x.__dict__, indent=4) ) # Maintain backward compatibility - In exp02, it is a list of dict. From presocial onward, it's a dict of dict. if isinstance(devices, list): - devices: dict = { - d.pop("Name"): d for d in devices - } # {deivce_name: device_config} + devices: dict = {d.pop("Name"): d for d in devices} # {deivce_name: device_config} return { "experiment_name": experiment_name, @@ -194,17 +178,15 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath experiment_key = {"experiment_name": experiment_name} metadata_yml_filepath = pathlib.Path(metadata_yml_filepath) - epoch_config = extract_epoch_config( - experiment_name, devices_schema, metadata_yml_filepath - ) + epoch_config = extract_epoch_config(experiment_name, devices_schema, metadata_yml_filepath) previous_epoch = (acquisition.Experiment & experiment_key).aggr( acquisition.Epoch & f'epoch_start < "{epoch_config["epoch_start"]}"', epoch_start="MAX(epoch_start)", ) - if len(acquisition.EpochConfig.Meta & previous_epoch) and epoch_config[ - "commit" - ] == (acquisition.EpochConfig.Meta & previous_epoch).fetch1("commit"): + if len(acquisition.EpochConfig.Meta & previous_epoch) and epoch_config["commit"] == ( + acquisition.EpochConfig.Meta & previous_epoch + ).fetch1("commit"): # if identical commit -> no changes return @@ -236,9 +218,7 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath table_entry = { "experiment_name": experiment_name, **device_key, - f"{dj.utils.from_camel_case(table.__name__)}_install_time": epoch_config[ - "epoch_start" - ], + f"{dj.utils.from_camel_case(table.__name__)}_install_time": epoch_config["epoch_start"], f"{dj.utils.from_camel_case(table.__name__)}_name": device_name, } @@ -255,21 +235,15 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath { **table_entry, "attribute_name": "SamplingFrequency", - "attribute_value": video_controller[ - device_config["TriggerFrequency"] - ], + "attribute_value": video_controller[device_config["TriggerFrequency"]], } ) """Check if this device is currently installed. If the same device serial number is currently installed check for any changes in configuration. If not, skip this""" - current_device_query = ( - table - table.RemovalTime & experiment_key & device_key - ) + current_device_query = table - table.RemovalTime & experiment_key & device_key if current_device_query: - current_device_config: list[dict] = ( - table.Attribute & current_device_query - ).fetch( + current_device_config: list[dict] = (table.Attribute & current_device_query).fetch( "experiment_name", "device_serial_number", "attribute_name", @@ -277,11 +251,7 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath as_dict=True, ) new_device_config: list[dict] = [ - { - k: v - for k, v in entry.items() - if dj.utils.from_camel_case(table.__name__) not in k - } + {k: v for k, v in entry.items() if dj.utils.from_camel_case(table.__name__) not in k} for entry in table_attribute_entry ] @@ -291,10 +261,7 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath for config in current_device_config } ) == dict_to_uuid( - { - config["attribute_name"]: config["attribute_value"] - for config in new_device_config - } + {config["attribute_name"]: config["attribute_value"] for config in new_device_config} ): # Skip if none of the configuration has changed. continue @@ -412,14 +379,10 @@ def _get_class_path(obj): "aeon.schema.social", ]: device_info[device_name]["stream_type"].append(stream_type) - device_info[device_name]["stream_reader"].append( - _get_class_path(stream_obj) - ) + device_info[device_name]["stream_reader"].append(_get_class_path(stream_obj)) required_args = [ - k - for k in inspect.signature(stream_obj.__init__).parameters - if k != "self" + k for k in inspect.signature(stream_obj.__init__).parameters if k != "self" ] pattern = schema_dict[device_name][stream_type].get("pattern") schema_dict[device_name][stream_type]["pattern"] = pattern.replace( @@ -427,35 +390,23 @@ def _get_class_path(obj): ) kwargs = { - k: v - for k, v in schema_dict[device_name][stream_type].items() - if k in required_args + k: v for k, v in schema_dict[device_name][stream_type].items() if k in required_args } device_info[device_name]["stream_reader_kwargs"].append(kwargs) # Add hash device_info[device_name]["stream_hash"].append( - dict_to_uuid( - {**kwargs, "stream_reader": _get_class_path(stream_obj)} - ) + dict_to_uuid({**kwargs, "stream_reader": _get_class_path(stream_obj)}) ) else: stream_type = device.__class__.__name__ device_info[device_name]["stream_type"].append(stream_type) device_info[device_name]["stream_reader"].append(_get_class_path(device)) - required_args = { - k: None - for k in inspect.signature(device.__init__).parameters - if k != "self" - } + required_args = {k: None for k in inspect.signature(device.__init__).parameters if k != "self"} pattern = schema_dict[device_name].get("pattern") - schema_dict[device_name]["pattern"] = pattern.replace( - device_name, "{pattern}" - ) + schema_dict[device_name]["pattern"] = pattern.replace(device_name, "{pattern}") - kwargs = { - k: v for k, v in schema_dict[device_name].items() if k in required_args - } + kwargs = {k: v for k, v in schema_dict[device_name].items() if k in required_args} device_info[device_name]["stream_reader_kwargs"].append(kwargs) # Add hash device_info[device_name]["stream_hash"].append( @@ -545,9 +496,7 @@ def ingest_epoch_metadata_octagon(experiment_name, metadata_yml_filepath): ("Wall8", "Wall"), ] - epoch_start = datetime.datetime.strptime( - metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S" - ) + epoch_start = datetime.datetime.strptime(metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S") for device_idx, (device_name, device_type) in enumerate(oct01_devices): device_sn = f"oct01_{device_idx}" @@ -556,13 +505,8 @@ def ingest_epoch_metadata_octagon(experiment_name, metadata_yml_filepath): skip_duplicates=True, ) experiment_table = getattr(streams, f"Experiment{device_type}") - if not ( - experiment_table - & {"experiment_name": experiment_name, "device_serial_number": device_sn} - ): - experiment_table.insert1( - (experiment_name, device_sn, epoch_start, device_name) - ) + if not (experiment_table & {"experiment_name": experiment_name, "device_serial_number": device_sn}): + experiment_table.insert1((experiment_name, device_sn, epoch_start, device_name)) # endregion diff --git a/aeon/dj_pipeline/utils/paths.py b/aeon/dj_pipeline/utils/paths.py index 1bdc7b7f..1df21e64 100644 --- a/aeon/dj_pipeline/utils/paths.py +++ b/aeon/dj_pipeline/utils/paths.py @@ -67,6 +67,5 @@ def find_root_directory( except StopIteration: raise FileNotFoundError( - f"No valid root directory found (from {root_directories})" - f" for {full_path}" + f"No valid root directory found (from {root_directories})" f" for {full_path}" ) diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index 28be5b10..a8826455 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -36,17 +36,13 @@ def plot_reward_rate_differences(subject_keys): """ subj_names, sess_starts, rate_timestamps, rate_diffs = ( analysis.InArenaRewardRate & subject_keys - ).fetch( - "subject", "in_arena_start", "pellet_rate_timestamps", "patch2_patch1_rate_diff" - ) + ).fetch("subject", "in_arena_start", "pellet_rate_timestamps", "patch2_patch1_rate_diff") nSessions = len(sess_starts) longest_rateDiff = np.max([len(t) for t in rate_timestamps]) max_session_idx = np.argmax([len(t) for t in rate_timestamps]) - max_session_elapsed_times = ( - rate_timestamps[max_session_idx] - rate_timestamps[max_session_idx][0] - ) + max_session_elapsed_times = rate_timestamps[max_session_idx] - rate_timestamps[max_session_idx][0] x_labels = [t.total_seconds() / 60 for t in max_session_elapsed_times] y_labels = [ @@ -91,15 +87,12 @@ def plot_wheel_travelled_distance(session_keys): ``` """ distance_travelled_query = ( - analysis.InArenaSummary.FoodPatch - * acquisition.ExperimentFoodPatch.proj("food_patch_description") + analysis.InArenaSummary.FoodPatch * acquisition.ExperimentFoodPatch.proj("food_patch_description") & session_keys ) distance_travelled_df = ( - distance_travelled_query.proj( - "food_patch_description", "wheel_distance_travelled" - ) + distance_travelled_query.proj("food_patch_description", "wheel_distance_travelled") .fetch(format="frame") .reset_index() ) @@ -161,8 +154,7 @@ def plot_average_time_distribution(session_keys): & session_keys ) .aggr( - analysis.InArenaTimeDistribution.FoodPatch - * acquisition.ExperimentFoodPatch, + analysis.InArenaTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch, avg_in_patch="AVG(time_fraction_in_patch)", ) .fetch("subject", "food_patch_description", "avg_in_patch") @@ -240,15 +232,11 @@ def plot_visit_daily_summary( .reset_index() ) else: - visit_per_day_df = ( - (VisitSummary & visit_key).fetch(format="frame").reset_index() - ) + visit_per_day_df = (VisitSummary & visit_key).fetch(format="frame").reset_index() if not attr.startswith("total"): attr = "total_" + attr - visit_per_day_df["day"] = ( - visit_per_day_df["visit_date"] - visit_per_day_df["visit_date"].min() - ) + visit_per_day_df["day"] = visit_per_day_df["visit_date"] - visit_per_day_df["visit_date"].min() visit_per_day_df["day"] = visit_per_day_df["day"].dt.days fig = px.bar( @@ -339,14 +327,10 @@ def plot_foraging_bouts_count( else [foraging_bouts["bout_start"].dt.floor("D")] ) - foraging_bouts_count = ( - foraging_bouts.groupby(group_by_attrs).size().reset_index(name="count") - ) + foraging_bouts_count = foraging_bouts.groupby(group_by_attrs).size().reset_index(name="count") visit_start = (VisitEnd & visit_key).fetch1("visit_start") - foraging_bouts_count["day"] = ( - foraging_bouts_count["bout_start"].dt.date - visit_start.date() - ).dt.days + foraging_bouts_count["day"] = (foraging_bouts_count["bout_start"].dt.date - visit_start.date()).dt.days fig = px.bar( foraging_bouts_count, @@ -360,10 +344,7 @@ def plot_foraging_bouts_count( width=700, height=400, template="simple_white", - title=visit_key["subject"] - + "
Foraging bouts: count (freq='" - + freq - + "')", + title=visit_key["subject"] + "
Foraging bouts: count (freq='" + freq + "')", ) fig.update_layout( @@ -435,9 +416,7 @@ def plot_foraging_bouts_distribution( fig = go.Figure() if per_food_patch: - patch_names = (acquisition.ExperimentFoodPatch & visit_key).fetch( - "food_patch_description" - ) + patch_names = (acquisition.ExperimentFoodPatch & visit_key).fetch("food_patch_description") for patch in patch_names: bouts = foraging_bouts[foraging_bouts["food_patch_description"] == patch] fig.add_trace( @@ -464,9 +443,7 @@ def plot_foraging_bouts_distribution( ) fig.update_layout( - title_text=visit_key["subject"] - + "
Foraging bouts: " - + attr.replace("_", " "), + title_text=visit_key["subject"] + "
Foraging bouts: " + attr.replace("_", " "), xaxis_title="date", yaxis_title=attr.replace("_", " "), violingap=0, @@ -504,17 +481,11 @@ def plot_visit_time_distribution(visit_key, freq="D"): region = _get_region_data(visit_key) # Compute time spent per region - time_spent = ( - region.groupby([region.index.floor(freq), "region"]) - .size() - .reset_index(name="count") + time_spent = region.groupby([region.index.floor(freq), "region"]).size().reset_index(name="count") + time_spent["time_fraction"] = time_spent["count"] / time_spent.groupby("timestamps")["count"].transform( + "sum" ) - time_spent["time_fraction"] = time_spent["count"] / time_spent.groupby( - "timestamps" - )["count"].transform("sum") - time_spent["day"] = ( - time_spent["timestamps"] - time_spent["timestamps"].min() - ).dt.days + time_spent["day"] = (time_spent["timestamps"] - time_spent["timestamps"].min()).dt.days fig = px.bar( time_spent, @@ -526,10 +497,7 @@ def plot_visit_time_distribution(visit_key, freq="D"): "time_fraction": "time fraction", "timestamps": "date" if freq == "D" else "time", }, - title=visit_key["subject"] - + "
Fraction of time spent in each region (freq='" - + freq - + "')", + title=visit_key["subject"] + "
Fraction of time spent in each region (freq='" + freq + "')", width=700, height=400, template="simple_white", @@ -573,9 +541,7 @@ def _get_region_data(visit_key, attrs=None): for attr in attrs: if attr == "in_nest": # Nest in_nest = np.concatenate( - (VisitTimeDistribution.Nest & visit_key).fetch( - attr, order_by="visit_date" - ) + (VisitTimeDistribution.Nest & visit_key).fetch(attr, order_by="visit_date") ) region = pd.concat( [ @@ -590,16 +556,14 @@ def _get_region_data(visit_key, attrs=None): elif attr == "in_patch": # Food patch # Find all patches patches = np.unique( - ( - VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch - & visit_key - ).fetch("food_patch_description") + (VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & visit_key).fetch( + "food_patch_description" + ) ) for patch in patches: in_patch = np.concatenate( ( - VisitTimeDistribution.FoodPatch - * acquisition.ExperimentFoodPatch + VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & visit_key & f"food_patch_description = '{patch}'" ).fetch("in_patch", order_by="visit_date") @@ -631,19 +595,13 @@ def _get_region_data(visit_key, attrs=None): region = region.sort_index().rename_axis("timestamps") # Exclude data during maintenance - maintenance_period = get_maintenance_periods( - visit_key["experiment_name"], visit_start, visit_end - ) - region = filter_out_maintenance_periods( - region, maintenance_period, visit_end, dropna=True - ) + maintenance_period = get_maintenance_periods(visit_key["experiment_name"], visit_start, visit_end) + region = filter_out_maintenance_periods(region, maintenance_period, visit_end, dropna=True) return region -def plot_weight_patch_data( - visit_key, freq="H", smooth_weight=True, min_weight=0, max_weight=35 -): +def plot_weight_patch_data(visit_key, freq="H", smooth_weight=True, min_weight=0, max_weight=35): """Plot subject weight and patch data (pellet trigger count) per visit. Args: @@ -660,9 +618,7 @@ def plot_weight_patch_data( >>> fig = plot_weight_patch_data(visit_key, freq="H", smooth_weight=True) >>> fig = plot_weight_patch_data(visit_key, freq="D") """ - subject_weight = _get_filtered_subject_weight( - visit_key, smooth_weight, min_weight, max_weight - ) + subject_weight = _get_filtered_subject_weight(visit_key, smooth_weight, min_weight, max_weight) # Count pellet trigger per patch per day/hour/... patch = _get_patch_data(visit_key) @@ -690,12 +646,8 @@ def plot_weight_patch_data( for p in patch_names: fig.add_trace( go.Bar( - x=patch_summary[patch_summary["food_patch_description"] == p][ - "event_time" - ], - y=patch_summary[patch_summary["food_patch_description"] == p][ - "event_type" - ], + x=patch_summary[patch_summary["food_patch_description"] == p]["event_time"], + y=patch_summary[patch_summary["food_patch_description"] == p]["event_type"], name=p, ), secondary_y=False, @@ -720,10 +672,7 @@ def plot_weight_patch_data( fig.update_layout( barmode="stack", hovermode="x", - title_text=visit_key["subject"] - + "
Weight and pellet count (freq='" - + freq - + "')", + title_text=visit_key["subject"] + "
Weight and pellet count (freq='" + freq + "')", xaxis_title="date" if freq == "D" else "time", yaxis={"title": "pellet count"}, yaxis2={"title": "weight"}, @@ -744,9 +693,7 @@ def plot_weight_patch_data( return fig -def _get_filtered_subject_weight( - visit_key, smooth_weight=True, min_weight=0, max_weight=35 -): +def _get_filtered_subject_weight(visit_key, smooth_weight=True, min_weight=0, max_weight=35): """Retrieve subject weight from WeightMeasurementFiltered table. Args: @@ -785,9 +732,7 @@ def _get_filtered_subject_weight( subject_weight = subject_weight.loc[visit_start:visit_end] # Exclude data during maintenance - maintenance_period = get_maintenance_periods( - visit_key["experiment_name"], visit_start, visit_end - ) + maintenance_period = get_maintenance_periods(visit_key["experiment_name"], visit_start, visit_end) subject_weight = filter_out_maintenance_periods( subject_weight, maintenance_period, visit_end, dropna=True ) @@ -804,9 +749,7 @@ def _get_filtered_subject_weight( subject_weight = subject_weight.resample("1T").mean().dropna() if smooth_weight: - subject_weight["weight_subject"] = savgol_filter( - subject_weight["weight_subject"], 10, 3 - ) + subject_weight["weight_subject"] = savgol_filter(subject_weight["weight_subject"], 10, 3) return subject_weight @@ -827,9 +770,7 @@ def _get_patch_data(visit_key): ( dj.U("event_time", "event_type", "food_patch_description") & ( - acquisition.FoodPatchEvent - * acquisition.EventType - * acquisition.ExperimentFoodPatch + acquisition.FoodPatchEvent * acquisition.EventType * acquisition.ExperimentFoodPatch & f'event_time BETWEEN "{visit_start}" AND "{visit_end}"' & 'event_type = "TriggerPellet"' ) @@ -842,11 +783,7 @@ def _get_patch_data(visit_key): # TODO: handle repeat attempts (pellet delivery trigger and beam break) # Exclude data during maintenance - maintenance_period = get_maintenance_periods( - visit_key["experiment_name"], visit_start, visit_end - ) - patch = filter_out_maintenance_periods( - patch, maintenance_period, visit_end, dropna=True - ) + maintenance_period = get_maintenance_periods(visit_key["experiment_name"], visit_start, visit_end) + patch = filter_out_maintenance_periods(patch, maintenance_period, visit_end, dropna=True) return patch diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index c653d918..e3d3259f 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -23,12 +23,17 @@ class StreamType(dj.Lookup): - """Catalog of all steam types for the different device types used across Project Aeon. One StreamType corresponds to one reader class in `aeon.io.reader`. The combination of `stream_reader` and `stream_reader_kwargs` should fully specify the data loading routine for a particular device, using the `aeon.io.utils`.""" + """Catalog of all steam types for the different device types used across + Project Aeon. One StreamType corresponds to one reader class in `aeon.io.reader`. + The combination of `stream_reader` and `stream_reader_kwargs` should fully + specify the data loading routine for a particular device, using the `aeon.io.utils`. + """ definition = """ # Catalog of all stream types used across Project Aeon stream_type : varchar(20) --- - stream_reader : varchar(256) # name of the reader class found in `aeon_mecha` package (e.g. aeon.io.reader.Video) + stream_reader : varchar(256) # name of the reader class found in `aeon_mecha` + # package (e.g. aeon.io.reader.Video) stream_reader_kwargs : longblob # keyword arguments to instantiate the reader class stream_description='': varchar(256) stream_hash : uuid # hash of dict(stream_reader_kwargs, stream_reader=stream_reader) @@ -70,16 +75,20 @@ def get_device_template(device_type: str): class ExperimentDevice(dj.Manual): definition = f""" - # {device_title} placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-{aeon.__version__}) + # {device_title} placement and operation for a particular time period, + # at a certain location, for a given experiment (auto-generated with + # aeon_mecha-{aeon.__version__}) -> acquisition.Experiment -> Device - {device_type}_install_time : datetime(6) # time of the {device_type} placed and started operation at this position + {device_type}_install_time : datetime(6) # time of the {device_type} placed + # and started operation at this position --- {device_type}_name : varchar(36) """ class Attribute(dj.Part): - definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device + definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) + # associated with this experimental device -> master attribute_name : varchar(32) --- @@ -122,7 +131,9 @@ def get_device_stream_template(device_type: str, stream_type: str, streams_modul stream = reader(**stream_detail["stream_reader_kwargs"]) - table_definition = f""" # Raw per-chunk {stream_type} data stream from {device_type} (auto-generated with aeon_mecha-{aeon.__version__}) + table_definition = f""" + # Raw per-chunk {stream_type} data stream from {device_type} + # (auto-generated with aeon_mecha-{aeon.__version__}) -> {device_type} -> acquisition.Chunk --- @@ -146,13 +157,17 @@ def key_source(self): + Chunk(s) that started after device_type install time and ended before device_type remove time + Chunk(s) that started after device_type install time for device_type that are not yet removed """ - return ( + + key_source_query = ( acquisition.Chunk * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) & f"chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time" - & f'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, "2200-01-01")' + & f'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time,\ + "2200-01-01")' ) + return key_source_query + def make(self, key): """Load and insert the data for the DeviceDataStream table.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( @@ -226,6 +241,10 @@ def main(create_tables=True): device_table_def = inspect.getsource(table_class).lstrip() full_def = "@schema \n" + device_table_def + "\n\n" f.write(full_def) + else: + raise FileExistsError( + f"File {_STREAMS_MODULE_FILE} already exists. Please remove it and try again." + ) streams = importlib.import_module("aeon.dj_pipeline.streams") @@ -279,9 +298,19 @@ def main(create_tables=True): replacements = { "DeviceDataStream": f"{device_type}{stream_type}", "ExperimentDevice": device_type, - 'f"chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time"': f"'chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time'", - """f'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, "2200-01-01")'""": f"""'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, "2200-01-01")'""", - 'f"{dj.utils.from_camel_case(device_type)}_name"': f"'{dj.utils.from_camel_case(device_type)}_name'", + 'f"chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time"': ( + f"'chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time'" + ), + """f'chunk_start < + IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, + "2200-01-01")'""": ( + f"""'chunk_start < + IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, + "2200-01-01")'""" + ), + 'f"{dj.utils.from_camel_case(device_type)}_name"': ( + f"'{dj.utils.from_camel_case(device_type)}_name'" + ), "{device_type}": device_type, "{stream_type}": stream_type, "{aeon.__version__}": aeon.__version__, diff --git a/aeon/dj_pipeline/utils/video.py b/aeon/dj_pipeline/utils/video.py index 2e728171..b9bf40e8 100644 --- a/aeon/dj_pipeline/utils/video.py +++ b/aeon/dj_pipeline/utils/video.py @@ -25,9 +25,7 @@ def retrieve_video_frames( """Retrive video trames from the raw data directory.""" raw_data_dir = Path(raw_data_dir) if not raw_data_dir.exists(): - raise FileNotFoundError( - f"The specified raw data directory does not exist: {raw_data_dir}" - ) + raise FileNotFoundError(f"The specified raw data directory does not exist: {raw_data_dir}") # Load video data videodata = io_api.load( root=raw_data_dir.as_posix(), diff --git a/aeon/io/api.py b/aeon/io/api.py index 5fe532c3..2c1814e2 100644 --- a/aeon/io/api.py +++ b/aeon/io/api.py @@ -27,9 +27,7 @@ def chunk(time): return pd.to_datetime(time.dt.date) + pd.to_timedelta(hour, "h") else: hour = CHUNK_DURATION * (time.hour // CHUNK_DURATION) - return pd.to_datetime( - datetime.datetime.combine(time.date(), datetime.time(hour=hour)) - ) + return pd.to_datetime(datetime.datetime.combine(time.date(), datetime.time(hour=hour))) def chunk_range(start, end): @@ -39,9 +37,7 @@ def chunk_range(start, end): :param datetime end: The right bound of the time range. :return: A DatetimeIndex representing the acquisition chunk range. """ - return pd.date_range( - chunk(start), chunk(end), freq=pd.DateOffset(hours=CHUNK_DURATION) - ) + return pd.date_range(chunk(start), chunk(end), freq=pd.DateOffset(hours=CHUNK_DURATION)) def chunk_key(file): @@ -53,9 +49,7 @@ def chunk_key(file): except ValueError: epoch = file.parts[-2] date_str, time_str = epoch.split("T") - return epoch, datetime.datetime.fromisoformat( - date_str + "T" + time_str.replace("-", ":") - ) + return epoch, datetime.datetime.fromisoformat(date_str + "T" + time_str.replace("-", ":")) def _set_index(data): @@ -68,9 +62,7 @@ def _empty(columns): return pd.DataFrame(columns=columns, index=pd.DatetimeIndex([], name="time")) -def load( - root, reader, start=None, end=None, time=None, tolerance=None, epoch=None, **kwargs -): +def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=None, **kwargs): """Extracts chunk data from the root path of an Aeon dataset. Reads all chunk data using the specified data stream reader. A subset of the data can be loaded @@ -97,9 +89,7 @@ def load( fileset = { chunk_key(fname): fname for path in root - for fname in Path(path).glob( - f"{epoch_pattern}/**/{reader.pattern}.{reader.extension}" - ) + for fname in Path(path).glob(f"{epoch_pattern}/**/{reader.pattern}.{reader.extension}") } files = sorted(fileset.items()) @@ -144,9 +134,7 @@ def load( if start is not None or end is not None: chunk_start = chunk(start) if start is not None else pd.Timestamp.min chunk_end = chunk(end) if end is not None else pd.Timestamp.max - files = list( - filter(lambda item: chunk_start <= chunk(item[0][1]) <= chunk_end, files) - ) + files = list(filter(lambda item: chunk_start <= chunk(item[0][1]) <= chunk_end, files)) if len(files) == 0: return _empty(reader.columns) diff --git a/aeon/io/reader.py b/aeon/io/reader.py index 09a7a9e0..144f6d3c 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -70,12 +70,8 @@ def read(self, file): payloadtype = _payloadtypes[data[4] & ~0x10] elementsize = payloadtype.itemsize payloadshape = (length, payloadsize // elementsize) - seconds = np.ndarray( - length, dtype=np.uint32, buffer=data, offset=5, strides=stride - ) - ticks = np.ndarray( - length, dtype=np.uint16, buffer=data, offset=9, strides=stride - ) + seconds = np.ndarray(length, dtype=np.uint32, buffer=data, offset=5, strides=stride) + ticks = np.ndarray(length, dtype=np.uint16, buffer=data, offset=9, strides=stride) seconds = ticks * _SECONDS_PER_TICK + seconds payload = np.ndarray( payloadshape, @@ -86,9 +82,7 @@ def read(self, file): ) if self.columns is not None and payloadshape[1] < len(self.columns): - data = pd.DataFrame( - payload, index=seconds, columns=self.columns[: payloadshape[1]] - ) + data = pd.DataFrame(payload, index=seconds, columns=self.columns[: payloadshape[1]]) data[self.columns[payloadshape[1] :]] = math.nan return data else: @@ -117,17 +111,13 @@ class Metadata(Reader): def __init__(self, pattern="Metadata"): """Initialize the object with the specified pattern.""" - super().__init__( - pattern, columns=["workflow", "commit", "metadata"], extension="yml" - ) + super().__init__(pattern, columns=["workflow", "commit", "metadata"], extension="yml") def read(self, file): """Returns metadata for the specified epoch.""" epoch_str = file.parts[-2] date_str, time_str = epoch_str.split("T") - time = datetime.datetime.fromisoformat( - date_str + "T" + time_str.replace("-", ":") - ) + time = datetime.datetime.fromisoformat(date_str + "T" + time_str.replace("-", ":")) with open(file) as fp: metadata = json.load(fp) workflow = metadata.pop("Workflow") @@ -160,7 +150,8 @@ def read(self, file): class JsonList(Reader): """ - Extracts data from json list (.jsonl) files, where the key "seconds" stores the Aeon timestamp, in seconds. + Extracts data from json list (.jsonl) files, + where the key "seconds" stores the Aeon timestamp, in seconds. """ def __init__(self, pattern, columns=(), root_key="value", extension="jsonl"): @@ -269,9 +260,7 @@ class Position(Harp): def __init__(self, pattern): """Initialize the object with a specified pattern and columns.""" - super().__init__( - pattern, columns=["x", "y", "angle", "major", "minor", "area", "id"] - ) + super().__init__(pattern, columns=["x", "y", "angle", "major", "minor", "area", "id"]) class BitmaskEvent(Harp): @@ -330,9 +319,7 @@ class Video(Csv): def __init__(self, pattern): """Initialize the object with a specified pattern.""" - super().__init__( - pattern, columns=["hw_counter", "hw_timestamp", "_frame", "_path", "_epoch"] - ) + super().__init__(pattern, columns=["hw_counter", "hw_timestamp", "_frame", "_path", "_epoch"]) self._rawcolumns = ["time"] + self.columns[0:2] def read(self, file): @@ -346,7 +333,8 @@ def read(self, file): class Pose(Harp): - """Reader for Harp-binarized tracking data given a model that outputs id, parts, and likelihoods. + """Reader for Harp-binarized tracking data given a model + that outputs id, parts, and likelihoods. Columns: class (int): Int ID of a subject in the environment. @@ -357,9 +345,7 @@ class (int): Int ID of a subject in the environment. y (float): Y-coordinate of the bodypart. """ - def __init__( - self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/processed" - ): + def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/processed"): """Pose reader constructor.""" # `pattern` for this reader should typically be '_*' super().__init__(pattern, columns=None) @@ -398,16 +384,10 @@ def read(self, file: Path) -> pd.DataFrame: # Drop any repeat parts. unique_parts, unique_idxs = np.unique(parts, return_index=True) repeat_idxs = np.setdiff1d(np.arange(len(parts)), unique_idxs) - if ( - repeat_idxs - ): # drop x, y, and likelihood cols for repeat parts (skip first 5 cols) + if repeat_idxs: # drop x, y, and likelihood cols for repeat parts (skip first 5 cols) init_rep_part_col_idx = (repeat_idxs - 1) * 3 + 5 - rep_part_col_idxs = np.concatenate( - [np.arange(i, i + 3) for i in init_rep_part_col_idx] - ) - keep_part_col_idxs = np.setdiff1d( - np.arange(len(data.columns)), rep_part_col_idxs - ) + rep_part_col_idxs = np.concatenate([np.arange(i, i + 3) for i in init_rep_part_col_idx]) + keep_part_col_idxs = np.setdiff1d(np.arange(len(data.columns)), rep_part_col_idxs) data = data.iloc[:, keep_part_col_idxs] parts = unique_parts @@ -415,25 +395,18 @@ def read(self, file: Path) -> pd.DataFrame: data = self.class_int2str(data, config_file) n_parts = len(parts) part_data_list = [pd.DataFrame()] * n_parts - new_columns = pd.Series( - ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"] - ) + new_columns = pd.Series(["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]) new_data = pd.DataFrame(columns=new_columns) for i, part in enumerate(parts): part_columns = ( - columns[0 : (len(identities) + 1)] - if bonsai_sleap_v == BONSAI_SLEAP_V3 - else columns[0:2] + columns[0 : (len(identities) + 1)] if bonsai_sleap_v == BONSAI_SLEAP_V3 else columns[0:2] ) part_columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"]) part_data = pd.DataFrame(data[part_columns]) if bonsai_sleap_v == BONSAI_SLEAP_V3: # combine all identity_likelihood cols into a single col as dict part_data["identity_likelihood"] = part_data.apply( - lambda row: { - identity: row[f"{identity}_likelihood"] - for identity in identities - }, + lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, axis=1, ) part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True) @@ -498,14 +471,10 @@ def class_int2str(data: pd.DataFrame, config_file: Path) -> pd.DataFrame: return data @classmethod - def get_config_file( - cls, config_file_dir: Path, config_file_names: None | list[str] = None - ) -> Path: + def get_config_file(cls, config_file_dir: Path, config_file_names: None | list[str] = None) -> Path: """Returns the config file from a model's config directory.""" if config_file_names is None: - config_file_names = [ - "confmap_config.json" - ] # SLEAP (add for other trackers to this list) + config_file_names = ["confmap_config.json"] # SLEAP (add for other trackers to this list) config_file = None for f in config_file_names: if (config_file_dir / f).exists(): @@ -524,21 +493,14 @@ def from_dict(data, pattern=None): return globals()[reader_type](pattern=pattern, **kwargs) return DotMap( - { - k: from_dict(v, f"{pattern}_{k}" if pattern is not None else k) - for k, v in data.items() - } + {k: from_dict(v, f"{pattern}_{k}" if pattern is not None else k) for k, v in data.items()} ) def to_dict(dotmap): """Converts a DotMap object to a dictionary.""" if isinstance(dotmap, Reader): - kwargs = { - k: v - for k, v in vars(dotmap).items() - if k not in ["pattern"] and not k.startswith("_") - } + kwargs = {k: v for k, v in vars(dotmap).items() if k not in ["pattern"] and not k.startswith("_")} kwargs["type"] = type(dotmap).__name__ return kwargs return {k: to_dict(v) for k, v in dotmap.items()} diff --git a/aeon/io/video.py b/aeon/io/video.py index dbdc173b..658379e7 100644 --- a/aeon/io/video.py +++ b/aeon/io/video.py @@ -29,9 +29,7 @@ def frames(data): index = frameidx success, frame = capture.read() if not success: - raise ValueError( - f'Unable to read frame {frameidx} from video path "{path}".' - ) + raise ValueError(f'Unable to read frame {frameidx} from video path "{path}".') yield frame index = index + 1 finally: @@ -54,9 +52,7 @@ def export(frames, file, fps, fourcc=None): if writer is None: if fourcc is None: fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v") # type: ignore - writer = cv2.VideoWriter( - file, fourcc, fps, (frame.shape[1], frame.shape[0]) - ) + writer = cv2.VideoWriter(file, fourcc, fps, (frame.shape[1], frame.shape[0])) writer.write(frame) finally: if writer is not None: diff --git a/aeon/schema/foraging.py b/aeon/schema/foraging.py index 05ce480a..900af684 100644 --- a/aeon/schema/foraging.py +++ b/aeon/schema/foraging.py @@ -25,9 +25,7 @@ def __init__(self, pattern): def read(self, file): data = super().read(file) - categorical = pd.Categorical( - data.region, categories=range(len(Area._member_names_)) - ) + categorical = pd.Categorical(data.region, categories=range(len(Area._member_names_))) data["region"] = categorical.rename_categories(Area._member_names_) return data @@ -89,9 +87,7 @@ class BeamBreak(Stream): def __init__(self, pattern): """Initializes the BeamBreak stream.""" - super().__init__( - _reader.BitmaskEvent(f"{pattern}_32_*", 0x22, "PelletDetected") - ) + super().__init__(_reader.BitmaskEvent(f"{pattern}_32_*", 0x22, "PelletDetected")) class DeliverPellet(Stream): @@ -147,6 +143,4 @@ class SessionData(Stream): def __init__(self, pattern): """Initializes the SessionData stream.""" - super().__init__( - _reader.Csv(f"{pattern}_2*", columns=["id", "weight", "event"]) - ) + super().__init__(_reader.Csv(f"{pattern}_2*", columns=["id", "weight", "event"])) diff --git a/aeon/schema/octagon.py b/aeon/schema/octagon.py index f031d53b..351bef31 100644 --- a/aeon/schema/octagon.py +++ b/aeon/schema/octagon.py @@ -38,9 +38,7 @@ def __init__(self, pattern): class EndTrial(Stream): def __init__(self, pattern): """Initialises the EndTrial stream.""" - super().__init__( - _reader.Csv(f"{pattern}_endtrial_*", columns=["typetag", "value"]) - ) + super().__init__(_reader.Csv(f"{pattern}_endtrial_*", columns=["typetag", "value"])) class Slice(Stream): def __init__(self, pattern): @@ -119,9 +117,7 @@ def __init__(self, pattern): class StartNewSession(Stream): def __init__(self, pattern): """Initializes the StartNewSession class.""" - super().__init__( - _reader.Csv(f"{pattern}_startnewsession_*", columns=["typetag", "path"]) - ) + super().__init__(_reader.Csv(f"{pattern}_startnewsession_*", columns=["typetag", "path"])) class TaskLogic(StreamGroup): @@ -137,9 +133,7 @@ def __init__(self, pattern): class Response(Stream): def __init__(self, pattern): """Initializes the Response stream.""" - super().__init__( - _reader.Harp(f"{pattern}_2_*", columns=["wall_id", "poke_id"]) - ) + super().__init__(_reader.Harp(f"{pattern}_2_*", columns=["wall_id", "poke_id"])) class PreTrialState(Stream): def __init__(self, pattern): @@ -175,23 +169,17 @@ def __init__(self, path): class BeamBreak0(Stream): def __init__(self, pattern): """Initialises the BeamBreak0 stream.""" - super().__init__( - _reader.DigitalBitmask(f"{pattern}_32_*", 0x1, columns=["state"]) - ) + super().__init__(_reader.DigitalBitmask(f"{pattern}_32_*", 0x1, columns=["state"])) class BeamBreak1(Stream): def __init__(self, pattern): """Initialises the BeamBreak1 stream.""" - super().__init__( - _reader.DigitalBitmask(f"{pattern}_32_*", 0x2, columns=["state"]) - ) + super().__init__(_reader.DigitalBitmask(f"{pattern}_32_*", 0x2, columns=["state"])) class BeamBreak2(Stream): def __init__(self, pattern): """Initialises the BeamBreak2 stream.""" - super().__init__( - _reader.DigitalBitmask(f"{pattern}_32_*", 0x4, columns=["state"]) - ) + super().__init__(_reader.DigitalBitmask(f"{pattern}_32_*", 0x4, columns=["state"])) class SetLed0(Stream): def __init__(self, pattern): diff --git a/aeon/schema/social_02.py b/aeon/schema/social_02.py index 8a5183dd..0564599f 100644 --- a/aeon/schema/social_02.py +++ b/aeon/schema/social_02.py @@ -25,9 +25,7 @@ def __init__(self, path): class LightEvents(Stream): def __init__(self, path): """Initializes the LightEvents stream.""" - super().__init__( - _reader.Csv(f"{path}_LightEvents_*", columns=["channel", "value"]) - ) + super().__init__(_reader.Csv(f"{path}_LightEvents_*", columns=["channel", "value"])) MessageLog = core.MessageLog @@ -40,16 +38,12 @@ def __init__(self, path): class SubjectState(Stream): def __init__(self, path): """Initializes the SubjectState stream.""" - super().__init__( - _reader.Csv(f"{path}_SubjectState_*", columns=["id", "weight", "type"]) - ) + super().__init__(_reader.Csv(f"{path}_SubjectState_*", columns=["id", "weight", "type"])) class SubjectVisits(Stream): def __init__(self, path): """Initializes the SubjectVisits stream.""" - super().__init__( - _reader.Csv(f"{path}_SubjectVisits_*", columns=["id", "type", "region"]) - ) + super().__init__(_reader.Csv(f"{path}_SubjectVisits_*", columns=["id", "type", "region"])) class SubjectWeight(Stream): def __init__(self, path): @@ -88,9 +82,7 @@ def __init__(self, path): class DepletionState(Stream): def __init__(self, path): """Initializes the DepletionState stream.""" - super().__init__( - _reader.Csv(f"{path}_State_*", columns=["threshold", "offset", "rate"]) - ) + super().__init__(_reader.Csv(f"{path}_State_*", columns=["threshold", "offset", "rate"])) Encoder = core.Encoder diff --git a/aeon/schema/social_03.py b/aeon/schema/social_03.py index 99fe0d17..6206f0f9 100644 --- a/aeon/schema/social_03.py +++ b/aeon/schema/social_03.py @@ -14,6 +14,4 @@ class EnvironmentActiveConfiguration(Stream): def __init__(self, path): """Initializes the EnvironmentActiveConfiguration stream.""" - super().__init__( - _reader.JsonList(f"{path}_ActiveConfiguration_*", columns=["name"]) - ) + super().__init__(_reader.JsonList(f"{path}_ActiveConfiguration_*", columns=["name"])) diff --git a/tests/dj_pipeline/conftest.py b/tests/dj_pipeline/conftest.py index 1faf316e..1cff0de1 100644 --- a/tests/dj_pipeline/conftest.py +++ b/tests/dj_pipeline/conftest.py @@ -58,16 +58,12 @@ def dj_config(): """ dj_config_fp = pathlib.Path("dj_local_conf.json") if not dj_config_fp.exists(): - raise FileNotFoundError( - f"DataJoint configuration file not found: {dj_config_fp}" - ) + raise FileNotFoundError(f"DataJoint configuration file not found: {dj_config_fp}") dj.config.load(dj_config_fp) dj.config["safemode"] = False if "custom" not in dj.config: raise KeyError("'custom' not found in DataJoint configuration.") - dj.config["custom"][ - "database.prefix" - ] = f"u_{dj.config['database.user']}_testsuite_" + dj.config["custom"]["database.prefix"] = f"u_{dj.config['database.user']}_testsuite_" def load_pipeline(): diff --git a/tests/dj_pipeline/test_acquisition.py b/tests/dj_pipeline/test_acquisition.py index 37dc7f47..ce956b3e 100644 --- a/tests/dj_pipeline/test_acquisition.py +++ b/tests/dj_pipeline/test_acquisition.py @@ -9,32 +9,21 @@ @pytest.mark.ingestion def test_epoch_chunk_ingestion(test_params, pipeline, epoch_chunk_ingestion): acquisition = pipeline["acquisition"] - epoch_count = len( - acquisition.Epoch & {"experiment_name": test_params["experiment_name"]} - ) - chunk_count = len( - acquisition.Chunk & {"experiment_name": test_params["experiment_name"]} - ) + epoch_count = len(acquisition.Epoch & {"experiment_name": test_params["experiment_name"]}) + chunk_count = len(acquisition.Chunk & {"experiment_name": test_params["experiment_name"]}) if epoch_count != test_params["epoch_count"]: - raise AssertionError( - f"Expected {test_params['epoch_count']} epochs, but got {epoch_count}." - ) + raise AssertionError(f"Expected {test_params['epoch_count']} epochs, but got {epoch_count}.") if chunk_count != test_params["chunk_count"]: - raise AssertionError( - f"Expected {test_params['chunk_count']} chunks, but got {chunk_count}." - ) + raise AssertionError(f"Expected {test_params['chunk_count']} chunks, but got {chunk_count}.") @pytest.mark.ingestion -def test_experimentlog_ingestion( - test_params, pipeline, epoch_chunk_ingestion, experimentlog_ingestion -): +def test_experimentlog_ingestion(test_params, pipeline, epoch_chunk_ingestion, experimentlog_ingestion): acquisition = pipeline["acquisition"] exp_log_message_count = len( - acquisition.ExperimentLog.Message - & {"experiment_name": test_params["experiment_name"]} + acquisition.ExperimentLog.Message & {"experiment_name": test_params["experiment_name"]} ) if exp_log_message_count != test_params["experiment_log_message_count"]: raise AssertionError( @@ -43,8 +32,7 @@ def test_experimentlog_ingestion( ) subject_enter_exit_count = len( - acquisition.SubjectEnterExit.Time - & {"experiment_name": test_params["experiment_name"]} + acquisition.SubjectEnterExit.Time & {"experiment_name": test_params["experiment_name"]} ) if subject_enter_exit_count != test_params["subject_enter_exit_count"]: raise AssertionError( @@ -53,8 +41,7 @@ def test_experimentlog_ingestion( ) subject_weight_time_count = len( - acquisition.SubjectWeight.WeightTime - & {"experiment_name": test_params["experiment_name"]} + acquisition.SubjectWeight.WeightTime & {"experiment_name": test_params["experiment_name"]} ) if subject_weight_time_count != test_params["subject_weight_time_count"]: raise AssertionError( diff --git a/tests/dj_pipeline/test_pipeline_instantiation.py b/tests/dj_pipeline/test_pipeline_instantiation.py index c7321b09..f53bde20 100644 --- a/tests/dj_pipeline/test_pipeline_instantiation.py +++ b/tests/dj_pipeline/test_pipeline_instantiation.py @@ -10,9 +10,7 @@ @pytest.mark.instantiation def test_pipeline_instantiation(pipeline): if not hasattr(pipeline["acquisition"], "FoodPatchEvent"): - raise AssertionError( - "Pipeline acquisition does not have 'FoodPatchEvent' attribute." - ) + raise AssertionError("Pipeline acquisition does not have 'FoodPatchEvent' attribute.") if not hasattr(pipeline["lab"], "Arena"): raise AssertionError("Pipeline lab does not have 'Arena' attribute.") @@ -21,17 +19,13 @@ def test_pipeline_instantiation(pipeline): raise AssertionError("Pipeline qc does not have 'CameraQC' attribute.") if not hasattr(pipeline["report"], "InArenaSummaryPlot"): - raise AssertionError( - "Pipeline report does not have 'InArenaSummaryPlot' attribute." - ) + raise AssertionError("Pipeline report does not have 'InArenaSummaryPlot' attribute.") if not hasattr(pipeline["subject"], "Subject"): raise AssertionError("Pipeline subject does not have 'Subject' attribute.") if not hasattr(pipeline["tracking"], "CameraTracking"): - raise AssertionError( - "Pipeline tracking does not have 'CameraTracking' attribute." - ) + raise AssertionError("Pipeline tracking does not have 'CameraTracking' attribute.") @pytest.mark.instantiation @@ -46,23 +40,16 @@ def test_experiment_creation(test_params, pipeline, experiment_creation): ) raw_dir = ( - acquisition.Experiment.Directory - & {"experiment_name": experiment_name, "directory_type": "raw"} + acquisition.Experiment.Directory & {"experiment_name": experiment_name, "directory_type": "raw"} ).fetch1("directory_path") if raw_dir != test_params["raw_dir"]: - raise AssertionError( - f"Expected raw directory '{test_params['raw_dir']}', but got '{raw_dir}'." - ) + raise AssertionError(f"Expected raw directory '{test_params['raw_dir']}', but got '{raw_dir}'.") - exp_subjects = ( - acquisition.Experiment.Subject & {"experiment_name": experiment_name} - ).fetch("subject") + exp_subjects = (acquisition.Experiment.Subject & {"experiment_name": experiment_name}).fetch("subject") if len(exp_subjects) != test_params["subject_count"]: raise AssertionError( f"Expected subject count {test_params['subject_count']}, but got {len(exp_subjects)}." ) if "BAA-1100701" not in exp_subjects: - raise AssertionError( - "Expected subject 'BAA-1100701' not found in experiment subjects." - ) + raise AssertionError("Expected subject 'BAA-1100701' not found in experiment subjects.") diff --git a/tests/dj_pipeline/test_tracking.py b/tests/dj_pipeline/test_tracking.py index 5692c8cf..733fe5f6 100644 --- a/tests/dj_pipeline/test_tracking.py +++ b/tests/dj_pipeline/test_tracking.py @@ -24,11 +24,7 @@ def save_test_data(pipeline, test_params): file_name = ( "-".join( [ - ( - v.strftime("%Y%m%d%H%M%S") - if isinstance(v, datetime.datetime) - else str(v) - ) + (v.strftime("%Y%m%d%H%M%S") if isinstance(v, datetime.datetime) else str(v)) for v in key.values() ] ) @@ -58,11 +54,7 @@ def test_camera_tracking_ingestion(test_params, pipeline, camera_tracking_ingest file_name = ( "-".join( [ - ( - v.strftime("%Y%m%d%H%M%S") - if isinstance(v, datetime.datetime) - else str(v) - ) + (v.strftime("%Y%m%d%H%M%S") if isinstance(v, datetime.datetime) else str(v)) for v in key.values() ] ) diff --git a/tests/io/test_api.py b/tests/io/test_api.py index 14e7d10f..015b9a58 100644 --- a/tests/io/test_api.py +++ b/tests/io/test_api.py @@ -38,9 +38,7 @@ def test_load_end_only(): @pytest.mark.api def test_load_filter_nonchunked(): - data = aeon.load( - nonmonotonic_path, exp02.Metadata, start=pd.Timestamp("2022-06-06T09:00:00") - ) + data = aeon.load(nonmonotonic_path, exp02.Metadata, start=pd.Timestamp("2022-06-06T09:00:00")) if len(data) <= 0: raise AssertionError("Loaded data is empty. Expected non-empty data.") @@ -59,9 +57,7 @@ def test_load_monotonic(): def test_load_nonmonotonic(): data = aeon.load(nonmonotonic_path, exp02.Patch2.Encoder, downsample=None) if data.index.is_monotonic_increasing: - raise AssertionError( - "Data index is monotonic increasing, but it should not be." - ) + raise AssertionError("Data index is monotonic increasing, but it should not be.") @pytest.mark.api @@ -72,15 +68,11 @@ def test_load_encoder_with_downsampling(): # Check that the length of the downsampled data is less than the raw data if len(data) >= len(raw_data): - raise AssertionError( - "Downsampled data length should be less than raw data length." - ) + raise AssertionError("Downsampled data length should be less than raw data length.") # Check that the first timestamp of the downsampled data is within 20ms of the raw data if abs(data.index[0] - raw_data.index[0]).total_seconds() > DOWNSAMPLE_PERIOD: - raise AssertionError( - "The first timestamp of downsampled data is not within 20ms of raw data." - ) + raise AssertionError("The first timestamp of downsampled data is not within 20ms of raw data.") # Check that the last timestamp of the downsampled data is within 20ms of the raw data if abs(data.index[-1] - raw_data.index[-1]).total_seconds() > DOWNSAMPLE_PERIOD: @@ -98,9 +90,7 @@ def test_load_encoder_with_downsampling(): # Check that the timestamps in the downsampled data are strictly increasing if not data.index.is_monotonic_increasing: - raise AssertionError( - "Timestamps in downsampled data are not strictly increasing." - ) + raise AssertionError("Timestamps in downsampled data are not strictly increasing.") if __name__ == "__main__": From 4b9b0e786c855c6b98950e5e7699c1273af0a901 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 23:38:55 +0000 Subject: [PATCH 027/143] fix: resolve E501 issues --- aeon/analysis/block_plotting.py | 4 +- aeon/dj_pipeline/acquisition.py | 86 ++- aeon/dj_pipeline/analysis/block_analysis.py | 553 +++++++++++++----- aeon/dj_pipeline/analysis/visit.py | 53 +- aeon/dj_pipeline/analysis/visit_analysis.py | 156 +++-- .../create_experiment_01.py | 4 +- .../create_socialexperiment_0.py | 19 +- aeon/dj_pipeline/populate/process.py | 15 +- aeon/dj_pipeline/populate/worker.py | 4 +- aeon/dj_pipeline/qc.py | 27 +- aeon/dj_pipeline/report.py | 84 ++- aeon/dj_pipeline/subject.py | 53 +- aeon/dj_pipeline/tracking.py | 59 +- aeon/dj_pipeline/utils/load_metadata.py | 133 +++-- aeon/dj_pipeline/utils/paths.py | 6 +- aeon/dj_pipeline/utils/plotting.py | 175 ++++-- aeon/dj_pipeline/utils/streams_maker.py | 19 +- aeon/io/reader.py | 4 +- .../test_pipeline_instantiation.py | 4 +- 19 files changed, 1047 insertions(+), 411 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index 61875d2e..5b04977d 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -29,7 +29,9 @@ def gen_hex_grad(hex_col, vals, min_l=0.3): """Generates an array of hex color values based on a gradient defined by unit-normalized values.""" # Convert hex to rgb to hls - h, l, s = rgb_to_hls(*[int(hex_col.lstrip("#")[i : i + 2], 16) / 255 for i in (0, 2, 4)]) # noqa: E741 + h, l, s = rgb_to_hls( + *[int(hex_col.lstrip("#")[i : i + 2], 16) / 255 for i in (0, 2, 4)] + ) # noqa: E741 grad = np.empty(shape=(len(vals),), dtype=" 0: previous_chunk = all_chunks.iloc[i - 1] previous_chunk_path = pathlib.Path(previous_chunk.path) - previous_epoch_dir = pathlib.Path(previous_chunk_path.as_posix().split(device_name)[0]) + previous_epoch_dir = pathlib.Path( + previous_chunk_path.as_posix().split(device_name)[0] + ) previous_epoch_start = datetime.datetime.strptime( previous_epoch_dir.name, "%Y-%m-%dT%H-%M-%S" ) - previous_chunk_end = previous_chunk.name + datetime.timedelta(hours=io_api.CHUNK_DURATION) + previous_chunk_end = previous_chunk.name + datetime.timedelta( + hours=io_api.CHUNK_DURATION + ) previous_epoch_end = min(previous_chunk_end, epoch_start) previous_epoch_key = { "experiment_name": experiment_name, @@ -247,7 +256,9 @@ def ingest_epochs(cls, experiment_name): { **previous_epoch_key, "epoch_end": previous_epoch_end, - "epoch_duration": (previous_epoch_end - previous_epoch_start).total_seconds() + "epoch_duration": ( + previous_epoch_end - previous_epoch_start + ).total_seconds() / 3600, } ) @@ -288,7 +299,7 @@ class Meta(dj.Part): -> master --- bonsai_workflow: varchar(36) - commit: varchar(64) # e.g. git commit hash of aeon_experiment used to generated this particular epoch + commit: varchar(64) # e.g., git commit hash of aeon_experiment used to generate this epoch source='': varchar(16) # e.g. aeon_experiment or aeon_acquisition (or others) metadata: longblob metadata_file_path: varchar(255) # path of the file, relative to the experiment repository @@ -320,17 +331,23 @@ def make(self, key): experiment_name = key["experiment_name"] devices_schema = getattr( aeon_schemas, - (Experiment.DevicesSchema & {"experiment_name": experiment_name}).fetch1("devices_schema_name"), + (Experiment.DevicesSchema & {"experiment_name": experiment_name}).fetch1( + "devices_schema_name" + ), ) dir_type, epoch_dir = (Epoch & key).fetch1("directory_type", "epoch_dir") data_dir = Experiment.get_data_directory(key, dir_type) metadata_yml_filepath = data_dir / epoch_dir / "Metadata.yml" - epoch_config = extract_epoch_config(experiment_name, devices_schema, metadata_yml_filepath) + epoch_config = extract_epoch_config( + experiment_name, devices_schema, metadata_yml_filepath + ) epoch_config = { **epoch_config, - "metadata_file_path": metadata_yml_filepath.relative_to(data_dir).as_posix(), + "metadata_file_path": metadata_yml_filepath.relative_to( + data_dir + ).as_posix(), } # Insert new entries for streams.DeviceType, streams.Device. @@ -341,15 +358,20 @@ def make(self, key): # Define and instantiate new devices/stream tables under `streams` schema streams_maker.main() # Insert devices' installation/removal/settings - epoch_device_types = ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath) + epoch_device_types = ingest_epoch_metadata( + experiment_name, devices_schema, metadata_yml_filepath + ) self.insert1(key) self.Meta.insert1(epoch_config) - self.DeviceType.insert(key | {"device_type": n} for n in epoch_device_types or {}) + self.DeviceType.insert( + key | {"device_type": n} for n in epoch_device_types or {} + ) with metadata_yml_filepath.open("r") as f: metadata = json.load(f) self.ActiveRegion.insert( - {**key, "region_name": k, "region_data": v} for k, v in metadata["ActiveRegion"].items() + {**key, "region_name": k, "region_data": v} + for k, v in metadata["ActiveRegion"].items() ) @@ -388,7 +410,9 @@ def ingest_chunks(cls, experiment_name): for _, chunk in all_chunks.iterrows(): chunk_rep_file = pathlib.Path(chunk.path) epoch_dir = pathlib.Path(chunk_rep_file.as_posix().split(device_name)[0]) - epoch_start = datetime.datetime.strptime(epoch_dir.name, "%Y-%m-%dT%H-%M-%S") + epoch_start = datetime.datetime.strptime( + epoch_dir.name, "%Y-%m-%dT%H-%M-%S" + ) epoch_key = {"experiment_name": experiment_name, "epoch_start": epoch_start} if not (Epoch & epoch_key): @@ -396,7 +420,9 @@ def ingest_chunks(cls, experiment_name): continue chunk_start = chunk.name - chunk_start = max(chunk_start, epoch_start) # first chunk of the epoch starts at epoch_start + chunk_start = max( + chunk_start, epoch_start + ) # first chunk of the epoch starts at epoch_start chunk_end = chunk_start + datetime.timedelta(hours=io_api.CHUNK_DURATION) if EpochEnd & epoch_key: @@ -416,8 +442,12 @@ def ingest_chunks(cls, experiment_name): ) chunk_starts.append(chunk_key["chunk_start"]) - chunk_list.append({**chunk_key, **directory, "chunk_end": chunk_end, **epoch_key}) - file_name_list.append(chunk_rep_file.name) # handle duplicated files in different folders + chunk_list.append( + {**chunk_key, **directory, "chunk_end": chunk_end, **epoch_key} + ) + file_name_list.append( + chunk_rep_file.name + ) # handle duplicated files in different folders # -- files -- file_datetime_str = chunk_rep_file.stem.replace(f"{device_name}_", "") @@ -534,9 +564,9 @@ def make(self, key): data_dirs = Experiment.get_data_directories(key) devices_schema = getattr( aeon_schemas, - (Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) device = devices_schema.Environment @@ -596,12 +626,14 @@ def make(self, key): data_dirs = Experiment.get_data_directories(key) devices_schema = getattr( aeon_schemas, - (Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) device = devices_schema.Environment - stream_reader = device.EnvironmentActiveConfiguration # expecting columns: time, name, value + stream_reader = ( + device.EnvironmentActiveConfiguration + ) # expecting columns: time, name, value stream_data = io_api.load( root=data_dirs, reader=stream_reader, @@ -634,7 +666,9 @@ def _get_all_chunks(experiment_name, device_name): raw_data_dirs = {k: v for k, v in raw_data_dirs.items() if v} if not raw_data_dirs: - raise ValueError(f"No raw data directory found for experiment: {experiment_name}") + raise ValueError( + f"No raw data directory found for experiment: {experiment_name}" + ) chunkdata = io_api.load( root=list(raw_data_dirs.values()), @@ -656,7 +690,9 @@ def _match_experiment_directory(experiment_name, path, directories): repo_path = paths.get_repository_path(directory.pop("repository_name")) break else: - raise FileNotFoundError(f"Unable to identify the directory" f" where this chunk is from: {path}") + raise FileNotFoundError( + f"Unable to identify the directory" f" where this chunk is from: {path}" + ) return raw_data_dir, directory, repo_path diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index ed752644..ddc220e3 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -67,14 +67,18 @@ def make(self, key): # find the 0s in `pellet_ct` (these are times when the pellet count reset - i.e. new block) # that would mark the start of a new block - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) exp_key = {"experiment_name": key["experiment_name"]} chunk_restriction = acquisition.create_chunk_restriction( key["experiment_name"], chunk_start, chunk_end ) - block_state_query = acquisition.Environment.BlockState & exp_key & chunk_restriction + block_state_query = ( + acquisition.Environment.BlockState & exp_key & chunk_restriction + ) block_state_df = fetch_stream(block_state_query) if block_state_df.empty: self.insert1(key) @@ -97,8 +101,12 @@ def make(self, key): block_entries = [] if not blocks_df.empty: # calculate block end_times (use due_time) and durations - blocks_df["end_time"] = blocks_df["due_time"].apply(lambda x: io_api.aeon(x)) - blocks_df["duration"] = (blocks_df["end_time"] - blocks_df.index).dt.total_seconds() / 3600 + blocks_df["end_time"] = blocks_df["due_time"].apply( + lambda x: io_api.aeon(x) + ) + blocks_df["duration"] = ( + blocks_df["end_time"] - blocks_df.index + ).dt.total_seconds() / 3600 for _, row in blocks_df.iterrows(): block_entries.append( @@ -167,7 +175,10 @@ class Subject(dj.Part): """ def make(self, key): - """Restrict, fetch and aggregate data from different streams to produce intermediate data products at a per-block level (for different patches and different subjects). + """Restrict, fetch and aggregate data from different streams to + produce intermediate data products at a per-block level + (for different patches and different subjects). + 1. Query data for all chunks within the block. 2. Fetch streams, filter by maintenance period. 3. Fetch subject position data (SLEAP). @@ -188,19 +199,27 @@ def make(self, key): tracking.SLEAPTracking, ) for streams_table in streams_tables: - if len(streams_table & chunk_keys) < len(streams_table.key_source & chunk_keys): + if len(streams_table & chunk_keys) < len( + streams_table.key_source & chunk_keys + ): raise ValueError( - f"BlockAnalysis Not Ready - {streams_table.__name__} not yet fully ingested for block: {key}. Skipping (to retry later)..." + f"BlockAnalysis Not Ready - {streams_table.__name__}" + f"not yet fully ingested for block: {key}." + f"Skipping (to retry later)..." ) # Patch data - TriggerPellet, DepletionState, Encoder (distancetravelled) # For wheel data, downsample to 10Hz final_encoder_fs = 10 - maintenance_period = get_maintenance_periods(key["experiment_name"], block_start, block_end) + maintenance_period = get_maintenance_periods( + key["experiment_name"], block_start, block_end + ) patch_query = ( - streams.UndergroundFeeder.join(streams.UndergroundFeeder.RemovalTime, left=True) + streams.UndergroundFeeder.join( + streams.UndergroundFeeder.RemovalTime, left=True + ) & key & f'"{block_start}" >= underground_feeder_install_time' & f'"{block_end}" < IFNULL(underground_feeder_removal_time, "2200-01-01")' @@ -214,12 +233,14 @@ def make(self, key): streams.UndergroundFeederDepletionState & patch_key & chunk_restriction )[block_start:block_end] - pellet_ts_threshold_df = get_threshold_associated_pellets(patch_key, block_start, block_end) + pellet_ts_threshold_df = get_threshold_associated_pellets( + patch_key, block_start, block_end + ) # wheel encoder data - encoder_df = fetch_stream(streams.UndergroundFeederEncoder & patch_key & chunk_restriction)[ - block_start:block_end - ] + encoder_df = fetch_stream( + streams.UndergroundFeederEncoder & patch_key & chunk_restriction + )[block_start:block_end] # filter out maintenance period based on logs pellet_ts_threshold_df = filter_out_maintenance_periods( pellet_ts_threshold_df, @@ -238,17 +259,26 @@ def make(self, key): ) if depletion_state_df.empty: - raise ValueError(f"No depletion state data found for block {key} - patch: {patch_name}") + raise ValueError( + f"No depletion state data found for block {key} - patch: {patch_name}" + ) - encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled(encoder_df.angle) + encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled( + encoder_df.angle + ) if len(depletion_state_df.rate.unique()) > 1: - # multiple patch rates per block is unexpected, log a note and pick the first rate to move forward + # multiple patch rates per block is unexpected + # log a note and pick the first rate to move forward AnalysisNote.insert1( { "note_timestamp": datetime.now(timezone.utc), "note_type": "Multiple patch rates", - "note": f"Found multiple patch rates for block {key} - patch: {patch_name} - rates: {depletion_state_df.rate.unique()}", + "note": ( + f"Found multiple patch rates for block {key} " + f"- patch: {patch_name} " + f"- rates: {depletion_state_df.rate.unique()}" + ), } ) @@ -271,7 +301,9 @@ def make(self, key): "wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values[ ::wheel_downsampling_factor ], - "wheel_timestamps": encoder_df.index.values[::wheel_downsampling_factor], + "wheel_timestamps": encoder_df.index.values[ + ::wheel_downsampling_factor + ], "patch_threshold": pellet_ts_threshold_df.threshold.values, "patch_threshold_timestamps": pellet_ts_threshold_df.index.values, "patch_rate": patch_rate, @@ -303,7 +335,9 @@ def make(self, key): # positions - query for CameraTop, identity_name matches subject_name, pos_query = ( streams.SpinnakerVideoSource - * tracking.SLEAPTracking.PoseIdentity.proj("identity_name", part_name="anchor_part") + * tracking.SLEAPTracking.PoseIdentity.proj( + "identity_name", part_name="anchor_part" + ) * tracking.SLEAPTracking.Part & key & { @@ -313,18 +347,23 @@ def make(self, key): & chunk_restriction ) pos_df = fetch_stream(pos_query)[block_start:block_end] - pos_df = filter_out_maintenance_periods(pos_df, maintenance_period, block_end) + pos_df = filter_out_maintenance_periods( + pos_df, maintenance_period, block_end + ) if pos_df.empty: continue position_diff = np.sqrt( - np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float))) + np.square(np.diff(pos_df.x.astype(float))) + + np.square(np.diff(pos_df.y.astype(float))) ) cumsum_distance_travelled = np.concatenate([[0], np.cumsum(position_diff)]) # weights - weight_query = acquisition.Environment.SubjectWeight & key & chunk_restriction + weight_query = ( + acquisition.Environment.SubjectWeight & key & chunk_restriction + ) weight_df = fetch_stream(weight_query)[block_start:block_end] weight_df.query(f"subject_id == '{subject_name}'", inplace=True) @@ -377,7 +416,7 @@ class Patch(dj.Part): -> BlockAnalysis.Patch -> BlockAnalysis.Subject --- - in_patch_timestamps: longblob # timestamps in which a particular subject is spending time at a particular patch + in_patch_timestamps: longblob # timestamps when a subject spends time at a specific patch in_patch_time: float # total seconds spent in this patch for this block pellet_count: int pellet_timestamps: longblob @@ -412,7 +451,10 @@ def make(self, key): subjects_positions_df = pd.concat( [ pd.DataFrame( - {"subject_name": [s["subject_name"]] * len(s["position_timestamps"])} + { + "subject_name": [s["subject_name"]] + * len(s["position_timestamps"]) + } | { k: s[k] for k in ( @@ -440,7 +482,8 @@ def make(self, key): "cum_pref_time", ] all_subj_patch_pref_dict = { - p: {s: {a: pd.Series() for a in pref_attrs} for s in subject_names} for p in patch_names + p: {s: {a: pd.Series() for a in pref_attrs} for s in subject_names} + for p in patch_names } for patch in block_patches: @@ -463,11 +506,15 @@ def make(self, key): ).fetch1("attribute_value") patch_center = (int(patch_center["X"]), int(patch_center["Y"])) subjects_xy = subjects_positions_df[["position_x", "position_y"]].values - dist_to_patch = np.sqrt(np.sum((subjects_xy - patch_center) ** 2, axis=1).astype(float)) + dist_to_patch = np.sqrt( + np.sum((subjects_xy - patch_center) ** 2, axis=1).astype(float) + ) dist_to_patch_df = subjects_positions_df[["subject_name"]].copy() dist_to_patch_df["dist_to_patch"] = dist_to_patch - dist_to_patch_wheel_ts_id_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subject_names) + dist_to_patch_wheel_ts_id_df = pd.DataFrame( + index=cum_wheel_dist.index, columns=subject_names + ) dist_to_patch_pel_ts_id_df = pd.DataFrame( index=patch["pellet_timestamps"], columns=subject_names ) @@ -475,10 +522,12 @@ def make(self, key): # Find closest match between pose_df indices and wheel indices if not dist_to_patch_wheel_ts_id_df.empty: dist_to_patch_wheel_ts_subj = pd.merge_asof( - left=pd.DataFrame(dist_to_patch_wheel_ts_id_df[subject_name].copy()).reset_index( - names="time" - ), - right=dist_to_patch_df[dist_to_patch_df["subject_name"] == subject_name] + left=pd.DataFrame( + dist_to_patch_wheel_ts_id_df[subject_name].copy() + ).reset_index(names="time"), + right=dist_to_patch_df[ + dist_to_patch_df["subject_name"] == subject_name + ] .copy() .reset_index(names="time"), on="time", @@ -487,16 +536,18 @@ def make(self, key): direction="nearest", tolerance=pd.Timedelta("100ms"), ) - dist_to_patch_wheel_ts_id_df[subject_name] = dist_to_patch_wheel_ts_subj[ - "dist_to_patch" - ].values + dist_to_patch_wheel_ts_id_df[subject_name] = ( + dist_to_patch_wheel_ts_subj["dist_to_patch"].values + ) # Find closest match between pose_df indices and pel indices if not dist_to_patch_pel_ts_id_df.empty: dist_to_patch_pel_ts_subj = pd.merge_asof( - left=pd.DataFrame(dist_to_patch_pel_ts_id_df[subject_name].copy()).reset_index( - names="time" - ), - right=dist_to_patch_df[dist_to_patch_df["subject_name"] == subject_name] + left=pd.DataFrame( + dist_to_patch_pel_ts_id_df[subject_name].copy() + ).reset_index(names="time"), + right=dist_to_patch_df[ + dist_to_patch_df["subject_name"] == subject_name + ] .copy() .reset_index(names="time"), on="time", @@ -505,9 +556,9 @@ def make(self, key): direction="nearest", tolerance=pd.Timedelta("200ms"), ) - dist_to_patch_pel_ts_id_df[subject_name] = dist_to_patch_pel_ts_subj[ - "dist_to_patch" - ].values + dist_to_patch_pel_ts_id_df[subject_name] = ( + dist_to_patch_pel_ts_subj["dist_to_patch"].values + ) # Get closest subject to patch at each pellet timestep closest_subjects_pellet_ts = dist_to_patch_pel_ts_id_df.idxmin(axis=1) @@ -519,8 +570,12 @@ def make(self, key): wheel_dist = cum_wheel_dist.diff().fillna(cum_wheel_dist.iloc[0]) # Assign wheel dist to closest subject for each wheel timestep for subject_name in subject_names: - subj_idxs = cum_wheel_dist_subj_df[closest_subjects_wheel_ts == subject_name].index - cum_wheel_dist_subj_df.loc[subj_idxs, subject_name] = wheel_dist[subj_idxs] + subj_idxs = cum_wheel_dist_subj_df[ + closest_subjects_wheel_ts == subject_name + ].index + cum_wheel_dist_subj_df.loc[subj_idxs, subject_name] = wheel_dist[ + subj_idxs + ] cum_wheel_dist_subj_df = cum_wheel_dist_subj_df.cumsum(axis=0) # In patch time @@ -528,9 +583,9 @@ def make(self, key): dt = np.median(np.diff(cum_wheel_dist.index)).astype(int) / 1e9 # s # Fill in `all_subj_patch_pref` for subject_name in subject_names: - all_subj_patch_pref_dict[patch["patch_name"]][subject_name]["cum_dist"] = ( - cum_wheel_dist_subj_df[subject_name].values - ) + all_subj_patch_pref_dict[patch["patch_name"]][subject_name][ + "cum_dist" + ] = cum_wheel_dist_subj_df[subject_name].values subject_in_patch = in_patch[subject_name] subject_in_patch_cum_time = subject_in_patch.cumsum().values * dt all_subj_patch_pref_dict[patch["patch_name"]][subject_name][ @@ -551,7 +606,9 @@ def make(self, key): "pellet_count": len(subj_pellets), "pellet_timestamps": subj_pellets.index.values, "patch_threshold": subj_patch_thresh, - "wheel_cumsum_distance_travelled": cum_wheel_dist_subj_df[subject_name].values, + "wheel_cumsum_distance_travelled": cum_wheel_dist_subj_df[ + subject_name + ].values, } ) @@ -560,46 +617,72 @@ def make(self, key): for subject_name in subject_names: # Get sum of subj cum wheel dists and cum in patch time all_cum_dist = np.sum( - [all_subj_patch_pref_dict[p][subject_name]["cum_dist"][-1] for p in patch_names] + [ + all_subj_patch_pref_dict[p][subject_name]["cum_dist"][-1] + for p in patch_names + ] ) all_cum_time = np.sum( - [all_subj_patch_pref_dict[p][subject_name]["cum_time"][-1] for p in patch_names] + [ + all_subj_patch_pref_dict[p][subject_name]["cum_time"][-1] + for p in patch_names + ] ) for patch_name in patch_names: cum_pref_dist = ( - all_subj_patch_pref_dict[patch_name][subject_name]["cum_dist"] / all_cum_dist + all_subj_patch_pref_dict[patch_name][subject_name]["cum_dist"] + / all_cum_dist ) cum_pref_dist = np.where(cum_pref_dist < 1e-3, 0, cum_pref_dist) - all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_dist"] = cum_pref_dist + all_subj_patch_pref_dict[patch_name][subject_name][ + "cum_pref_dist" + ] = cum_pref_dist cum_pref_time = ( - all_subj_patch_pref_dict[patch_name][subject_name]["cum_time"] / all_cum_time + all_subj_patch_pref_dict[patch_name][subject_name]["cum_time"] + / all_cum_time ) - all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_time"] = cum_pref_time + all_subj_patch_pref_dict[patch_name][subject_name][ + "cum_pref_time" + ] = cum_pref_time # sum pref at each ts across patches for each subject total_dist_pref = np.sum( np.vstack( - [all_subj_patch_pref_dict[p][subject_name]["cum_pref_dist"] for p in patch_names] + [ + all_subj_patch_pref_dict[p][subject_name]["cum_pref_dist"] + for p in patch_names + ] ), axis=0, ) total_time_pref = np.sum( np.vstack( - [all_subj_patch_pref_dict[p][subject_name]["cum_pref_time"] for p in patch_names] + [ + all_subj_patch_pref_dict[p][subject_name]["cum_pref_time"] + for p in patch_names + ] ), axis=0, ) for patch_name in patch_names: - cum_pref_dist = all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_dist"] - all_subj_patch_pref_dict[patch_name][subject_name]["running_dist_pref"] = np.divide( + cum_pref_dist = all_subj_patch_pref_dict[patch_name][subject_name][ + "cum_pref_dist" + ] + all_subj_patch_pref_dict[patch_name][subject_name][ + "running_dist_pref" + ] = np.divide( cum_pref_dist, total_dist_pref, out=np.zeros_like(cum_pref_dist), where=total_dist_pref != 0, ) - cum_pref_time = all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_time"] - all_subj_patch_pref_dict[patch_name][subject_name]["running_time_pref"] = np.divide( + cum_pref_time = all_subj_patch_pref_dict[patch_name][subject_name][ + "cum_pref_time" + ] + all_subj_patch_pref_dict[patch_name][subject_name][ + "running_time_pref" + ] = np.divide( cum_pref_time, total_time_pref, out=np.zeros_like(cum_pref_time), @@ -611,12 +694,24 @@ def make(self, key): | { "patch_name": p, "subject_name": s, - "cumulative_preference_by_time": all_subj_patch_pref_dict[p][s]["cum_pref_time"], - "cumulative_preference_by_wheel": all_subj_patch_pref_dict[p][s]["cum_pref_dist"], - "running_preference_by_time": all_subj_patch_pref_dict[p][s]["running_time_pref"], - "running_preference_by_wheel": all_subj_patch_pref_dict[p][s]["running_dist_pref"], - "final_preference_by_time": all_subj_patch_pref_dict[p][s]["cum_pref_time"][-1], - "final_preference_by_wheel": all_subj_patch_pref_dict[p][s]["cum_pref_dist"][-1], + "cumulative_preference_by_time": all_subj_patch_pref_dict[p][s][ + "cum_pref_time" + ], + "cumulative_preference_by_wheel": all_subj_patch_pref_dict[p][s][ + "cum_pref_dist" + ], + "running_preference_by_time": all_subj_patch_pref_dict[p][s][ + "running_time_pref" + ], + "running_preference_by_wheel": all_subj_patch_pref_dict[p][s][ + "running_dist_pref" + ], + "final_preference_by_time": all_subj_patch_pref_dict[p][s][ + "cum_pref_time" + ][-1], + "final_preference_by_wheel": all_subj_patch_pref_dict[p][s][ + "cum_pref_dist" + ][-1], } for p, s in itertools.product(patch_names, subject_names) ) @@ -640,7 +735,9 @@ class BlockPatchPlots(dj.Computed): def make(self, key): """Compute and plot various block-level statistics and visualizations.""" # Define subject colors and patch styling for plotting - exp_subject_names = (acquisition.Experiment.Subject & key).fetch("subject", order_by="subject") + exp_subject_names = (acquisition.Experiment.Subject & key).fetch( + "subject", order_by="subject" + ) if not len(exp_subject_names): raise ValueError( "No subjects found in the `acquisition.Experiment.Subject`, missing a manual insert step?." @@ -659,7 +756,10 @@ def make(self, key): # Figure 1 - Patch stats: patch means and pellet threshold boxplots # --- subj_patch_info = ( - (BlockSubjectAnalysis.Patch.proj("pellet_timestamps", "patch_threshold") & key) + ( + BlockSubjectAnalysis.Patch.proj("pellet_timestamps", "patch_threshold") + & key + ) .fetch(format="frame") .reset_index() ) @@ -673,28 +773,46 @@ def make(self, key): ["patch_name", "subject_name", "pellet_timestamps", "patch_threshold"] ] min_subj_patch_info = ( - min_subj_patch_info.explode(["pellet_timestamps", "patch_threshold"], ignore_index=True) + min_subj_patch_info.explode( + ["pellet_timestamps", "patch_threshold"], ignore_index=True + ) .dropna() .reset_index(drop=True) ) # Rename and reindex columns min_subj_patch_info.columns = ["patch", "subject", "time", "threshold"] - min_subj_patch_info = min_subj_patch_info.reindex(columns=["time", "patch", "threshold", "subject"]) + min_subj_patch_info = min_subj_patch_info.reindex( + columns=["time", "patch", "threshold", "subject"] + ) # Add patch mean values and block-normalized delivery times to pellet info n_patches = len(patch_info) - patch_mean_info = pd.DataFrame(index=np.arange(n_patches), columns=min_subj_patch_info.columns) + patch_mean_info = pd.DataFrame( + index=np.arange(n_patches), columns=min_subj_patch_info.columns + ) patch_mean_info["subject"] = "mean" patch_mean_info["patch"] = [d["patch_name"] for d in patch_info] - patch_mean_info["threshold"] = [((1 / d["patch_rate"]) + d["patch_offset"]) for d in patch_info] + patch_mean_info["threshold"] = [ + ((1 / d["patch_rate"]) + d["patch_offset"]) for d in patch_info + ] patch_mean_info["time"] = subj_patch_info["block_start"][0] - min_subj_patch_info_plus = pd.concat((patch_mean_info, min_subj_patch_info)).reset_index(drop=True) + min_subj_patch_info_plus = pd.concat( + (patch_mean_info, min_subj_patch_info) + ).reset_index(drop=True) min_subj_patch_info_plus["norm_time"] = ( - (min_subj_patch_info_plus["time"] - min_subj_patch_info_plus["time"].iloc[0]) - / (min_subj_patch_info_plus["time"].iloc[-1] - min_subj_patch_info_plus["time"].iloc[0]) + ( + min_subj_patch_info_plus["time"] + - min_subj_patch_info_plus["time"].iloc[0] + ) + / ( + min_subj_patch_info_plus["time"].iloc[-1] + - min_subj_patch_info_plus["time"].iloc[0] + ) ).round(3) # Plot it - box_colors = ["#0A0A0A"] + list(subject_colors_dict.values()) # subject colors + mean color + box_colors = ["#0A0A0A"] + list( + subject_colors_dict.values() + ) # subject colors + mean color patch_stats_fig = px.box( min_subj_patch_info_plus.sort_values("patch"), x="patch", @@ -724,7 +842,9 @@ def make(self, key): .dropna() .reset_index(drop=True) ) - weights_block.drop(columns=["experiment_name", "block_start"], inplace=True, errors="ignore") + weights_block.drop( + columns=["experiment_name", "block_start"], inplace=True, errors="ignore" + ) weights_block.rename(columns={"weight_timestamps": "time"}, inplace=True) weights_block.set_index("time", inplace=True) weights_block.sort_index(inplace=True) @@ -748,13 +868,17 @@ def make(self, key): # Figure 3 - Cumulative pellet count: over time, per subject, markered by patch # --- # Create dataframe with cumulative pellet count per subject - cum_pel_ct = min_subj_patch_info_plus.sort_values("time").copy().reset_index(drop=True) + cum_pel_ct = ( + min_subj_patch_info_plus.sort_values("time").copy().reset_index(drop=True) + ) patch_means = cum_pel_ct.loc[0:3][["patch", "threshold"]].rename( columns={"threshold": "mean_thresh"} ) patch_means["mean_thresh"] = patch_means["mean_thresh"].astype(float).round(1) cum_pel_ct = cum_pel_ct.merge(patch_means, on="patch", how="left") - cum_pel_ct = cum_pel_ct[~cum_pel_ct["subject"].str.contains("mean")].reset_index(drop=True) + cum_pel_ct = cum_pel_ct[ + ~cum_pel_ct["subject"].str.contains("mean") + ].reset_index(drop=True) cum_pel_ct = ( cum_pel_ct.groupby("subject", group_keys=False) .apply(lambda group: group.assign(counter=np.arange(len(group)) + 1)) @@ -764,7 +888,9 @@ def make(self, key): make_float_cols = ["threshold", "mean_thresh", "norm_time"] cum_pel_ct[make_float_cols] = cum_pel_ct[make_float_cols].astype(float) cum_pel_ct["patch_label"] = ( - cum_pel_ct["patch"] + " μ: " + cum_pel_ct["mean_thresh"].astype(float).round(1).astype(str) + cum_pel_ct["patch"] + + " μ: " + + cum_pel_ct["mean_thresh"].astype(float).round(1).astype(str) ) cum_pel_ct["norm_thresh_val"] = ( (cum_pel_ct["threshold"] - cum_pel_ct["threshold"].min()) @@ -794,7 +920,9 @@ def make(self, key): mode="markers", marker={ "symbol": patch_markers_dict[patch_grp["patch"].iloc[0]], - "color": gen_hex_grad(pel_mrkr_col, patch_grp["norm_thresh_val"]), + "color": gen_hex_grad( + pel_mrkr_col, patch_grp["norm_thresh_val"] + ), "size": 8, }, name=patch_val, @@ -814,7 +942,9 @@ def make(self, key): cum_pel_per_subject_fig = go.Figure() for id_val, id_grp in cum_pel_ct.groupby("subject"): for patch_val, patch_grp in id_grp.groupby("patch"): - cur_p_mean = patch_means[patch_means["patch"] == patch_val]["mean_thresh"].values[0] + cur_p_mean = patch_means[patch_means["patch"] == patch_val][ + "mean_thresh" + ].values[0] cur_p = patch_val.replace("Patch", "P") cum_pel_per_subject_fig.add_trace( go.Scatter( @@ -829,7 +959,9 @@ def make(self, key): # line=dict(width=2, color=subject_colors_dict[id_val]), marker={ "symbol": patch_markers_dict[patch_val], - "color": gen_hex_grad(pel_mrkr_col, patch_grp["norm_thresh_val"]), + "color": gen_hex_grad( + pel_mrkr_col, patch_grp["norm_thresh_val"] + ), "size": 8, }, name=f"{id_val} - {cur_p} - μ: {cur_p_mean}", @@ -846,7 +978,9 @@ def make(self, key): # Figure 5 - Cumulative wheel distance: over time, per subject-patch # --- # Get wheel timestamps for each patch - wheel_ts = (BlockAnalysis.Patch & key).fetch("patch_name", "wheel_timestamps", as_dict=True) + wheel_ts = (BlockAnalysis.Patch & key).fetch( + "patch_name", "wheel_timestamps", as_dict=True + ) wheel_ts = {d["patch_name"]: d["wheel_timestamps"] for d in wheel_ts} # Get subject patch data subj_wheel_cumsum_dist = (BlockSubjectAnalysis.Patch & key).fetch( @@ -866,7 +1000,9 @@ def make(self, key): for subj in subject_names: for patch_name in patch_names: cur_cum_wheel_dist = subj_wheel_cumsum_dist[(subj, patch_name)] - cur_p_mean = patch_means[patch_means["patch"] == patch_name]["mean_thresh"].values[0] + cur_p_mean = patch_means[patch_means["patch"] == patch_name][ + "mean_thresh" + ].values[0] cur_p = patch_name.replace("Patch", "P") cum_wheel_dist_fig.add_trace( go.Scatter( @@ -883,7 +1019,10 @@ def make(self, key): ) # Add markers for each pellet cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[(cum_pel_ct["subject"] == subj) & (cum_pel_ct["patch"] == patch_name)], + cum_pel_ct[ + (cum_pel_ct["subject"] == subj) + & (cum_pel_ct["patch"] == patch_name) + ], pd.DataFrame( { "time": wheel_ts[patch_name], @@ -902,11 +1041,15 @@ def make(self, key): mode="markers", marker={ "symbol": patch_markers_dict[patch_name], - "color": gen_hex_grad(pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"]), + "color": gen_hex_grad( + pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"] + ), "size": 8, }, name=f"{subj} - {cur_p} pellets", - customdata=np.stack((cur_cum_pel_ct["threshold"],), axis=-1), + customdata=np.stack( + (cur_cum_pel_ct["threshold"],), axis=-1 + ), hovertemplate="Threshold: %{customdata[0]:.2f} cm", ) ) @@ -920,10 +1063,14 @@ def make(self, key): # --- # Get and format a dataframe with preference data patch_pref = (BlockSubjectAnalysis.Preference & key).fetch(format="frame") - patch_pref.reset_index(level=["experiment_name", "block_start"], drop=True, inplace=True) + patch_pref.reset_index( + level=["experiment_name", "block_start"], drop=True, inplace=True + ) # Replace small vals with 0 small_pref_thresh = 1e-3 - patch_pref["cumulative_preference_by_wheel"] = patch_pref["cumulative_preference_by_wheel"].apply( + patch_pref["cumulative_preference_by_wheel"] = patch_pref[ + "cumulative_preference_by_wheel" + ].apply( lambda arr: np.where(np.array(arr) < small_pref_thresh, 0, np.array(arr)) ) @@ -931,7 +1078,9 @@ def calculate_running_preference(group, pref_col, out_col): # Sum pref at each ts total_pref = np.sum(np.vstack(group[pref_col].values), axis=0) # Calculate running pref - group[out_col] = group[pref_col].apply(lambda x: np.nan_to_num(x / total_pref, 0.0)) + group[out_col] = group[pref_col].apply( + lambda x: np.nan_to_num(x / total_pref, 0.0) + ) return group patch_pref = ( @@ -960,8 +1109,12 @@ def calculate_running_preference(group, pref_col, out_col): # Add trace for each subject-patch combo for subj in subject_names: for patch_name in patch_names: - cur_run_wheel_pref = patch_pref.loc[patch_name].loc[subj]["running_preference_by_wheel"] - cur_p_mean = patch_means[patch_means["patch"] == patch_name]["mean_thresh"].values[0] + cur_run_wheel_pref = patch_pref.loc[patch_name].loc[subj][ + "running_preference_by_wheel" + ] + cur_p_mean = patch_means[patch_means["patch"] == patch_name][ + "mean_thresh" + ].values[0] cur_p = patch_name.replace("Patch", "P") running_pref_by_wheel_plot.add_trace( go.Scatter( @@ -978,7 +1131,10 @@ def calculate_running_preference(group, pref_col, out_col): ) # Add markers for each pellet cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[(cum_pel_ct["subject"] == subj) & (cum_pel_ct["patch"] == patch_name)], + cum_pel_ct[ + (cum_pel_ct["subject"] == subj) + & (cum_pel_ct["patch"] == patch_name) + ], pd.DataFrame( { "time": wheel_ts[patch_name], @@ -997,11 +1153,15 @@ def calculate_running_preference(group, pref_col, out_col): mode="markers", marker={ "symbol": patch_markers_dict[patch_name], - "color": gen_hex_grad(pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"]), + "color": gen_hex_grad( + pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"] + ), "size": 8, }, name=f"{subj} - {cur_p} pellets", - customdata=np.stack((cur_cum_pel_ct["threshold"],), axis=-1), + customdata=np.stack( + (cur_cum_pel_ct["threshold"],), axis=-1 + ), hovertemplate="Threshold: %{customdata[0]:.2f} cm", ) ) @@ -1017,8 +1177,12 @@ def calculate_running_preference(group, pref_col, out_col): # Add trace for each subject-patch combo for subj in subject_names: for patch_name in patch_names: - cur_run_time_pref = patch_pref.loc[patch_name].loc[subj]["running_preference_by_time"] - cur_p_mean = patch_means[patch_means["patch"] == patch_name]["mean_thresh"].values[0] + cur_run_time_pref = patch_pref.loc[patch_name].loc[subj][ + "running_preference_by_time" + ] + cur_p_mean = patch_means[patch_means["patch"] == patch_name][ + "mean_thresh" + ].values[0] cur_p = patch_name.replace("Patch", "P") running_pref_by_patch_fig.add_trace( go.Scatter( @@ -1035,7 +1199,10 @@ def calculate_running_preference(group, pref_col, out_col): ) # Add markers for each pellet cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[(cum_pel_ct["subject"] == subj) & (cum_pel_ct["patch"] == patch_name)], + cum_pel_ct[ + (cum_pel_ct["subject"] == subj) + & (cum_pel_ct["patch"] == patch_name) + ], pd.DataFrame( { "time": wheel_ts[patch_name], @@ -1054,11 +1221,15 @@ def calculate_running_preference(group, pref_col, out_col): mode="markers", marker={ "symbol": patch_markers_dict[patch_name], - "color": gen_hex_grad(pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"]), + "color": gen_hex_grad( + pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"] + ), "size": 8, }, name=f"{subj} - {cur_p} pellets", - customdata=np.stack((cur_cum_pel_ct["threshold"],), axis=-1), + customdata=np.stack( + (cur_cum_pel_ct["threshold"],), axis=-1 + ), hovertemplate="Threshold: %{customdata[0]:.2f} cm", ) ) @@ -1072,7 +1243,9 @@ def calculate_running_preference(group, pref_col, out_col): # Figure 8 - Weighted patch preference: weighted by 'wheel_dist_spun : pel_ct' ratio # --- # Create multi-indexed dataframe with weighted distance for each subject-patch pair - pel_patches = [p for p in patch_names if "dummy" not in p.lower()] # exclude dummy patches + pel_patches = [ + p for p in patch_names if "dummy" not in p.lower() + ] # exclude dummy patches data = [] for patch in pel_patches: for subject in subject_names: @@ -1085,12 +1258,16 @@ def calculate_running_preference(group, pref_col, out_col): } ) subj_wheel_pel_weighted_dist = pd.DataFrame(data) - subj_wheel_pel_weighted_dist.set_index(["patch_name", "subject_name"], inplace=True) + subj_wheel_pel_weighted_dist.set_index( + ["patch_name", "subject_name"], inplace=True + ) subj_wheel_pel_weighted_dist["weighted_dist"] = np.nan # Calculate weighted distance subject_patch_data = (BlockSubjectAnalysis.Patch() & key).fetch(format="frame") - subject_patch_data.reset_index(level=["experiment_name", "block_start"], drop=True, inplace=True) + subject_patch_data.reset_index( + level=["experiment_name", "block_start"], drop=True, inplace=True + ) subj_wheel_pel_weighted_dist = defaultdict(lambda: defaultdict(dict)) for s in subject_names: for p in pel_patches: @@ -1098,11 +1275,14 @@ def calculate_running_preference(group, pref_col, out_col): cur_wheel_cum_dist_df = pd.DataFrame(columns=["time", "cum_wheel_dist"]) cur_wheel_cum_dist_df["time"] = wheel_ts[p] cur_wheel_cum_dist_df["cum_wheel_dist"] = ( - subject_patch_data.loc[p].loc[s]["wheel_cumsum_distance_travelled"] + 1 + subject_patch_data.loc[p].loc[s]["wheel_cumsum_distance_travelled"] + + 1 ) # Get cumulative pellet count cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[(cum_pel_ct["subject"] == s) & (cum_pel_ct["patch"] == p)], + cum_pel_ct[ + (cum_pel_ct["subject"] == s) & (cum_pel_ct["patch"] == p) + ], cur_wheel_cum_dist_df.sort_values("time"), on="time", direction="forward", @@ -1121,7 +1301,9 @@ def calculate_running_preference(group, pref_col, out_col): on="time", direction="forward", ) - max_weight = cur_cum_pel_ct.iloc[-1]["counter"] + 1 # for values after last pellet + max_weight = ( + cur_cum_pel_ct.iloc[-1]["counter"] + 1 + ) # for values after last pellet merged_df["counter"] = merged_df["counter"].fillna(max_weight) merged_df["weighted_cum_wheel_dist"] = ( merged_df.groupby("counter") @@ -1132,7 +1314,9 @@ def calculate_running_preference(group, pref_col, out_col): else: weighted_dist = cur_wheel_cum_dist_df["cum_wheel_dist"].values # Assign to dict - subj_wheel_pel_weighted_dist[p][s]["time"] = cur_wheel_cum_dist_df["time"].values + subj_wheel_pel_weighted_dist[p][s]["time"] = cur_wheel_cum_dist_df[ + "time" + ].values subj_wheel_pel_weighted_dist[p][s]["weighted_dist"] = weighted_dist # Convert back to dataframe data = [] @@ -1143,11 +1327,15 @@ def calculate_running_preference(group, pref_col, out_col): "patch_name": p, "subject_name": s, "time": subj_wheel_pel_weighted_dist[p][s]["time"], - "weighted_dist": subj_wheel_pel_weighted_dist[p][s]["weighted_dist"], + "weighted_dist": subj_wheel_pel_weighted_dist[p][s][ + "weighted_dist" + ], } ) subj_wheel_pel_weighted_dist = pd.DataFrame(data) - subj_wheel_pel_weighted_dist.set_index(["patch_name", "subject_name"], inplace=True) + subj_wheel_pel_weighted_dist.set_index( + ["patch_name", "subject_name"], inplace=True + ) # Calculate normalized weighted value def norm_inv_norm(group): @@ -1156,20 +1344,28 @@ def norm_inv_norm(group): inv_norm_dist = 1 / norm_dist inv_norm_dist = inv_norm_dist / (np.sum(inv_norm_dist, axis=0)) # Map each inv_norm_dist back to patch name. - return pd.Series(inv_norm_dist.tolist(), index=group.index, name="norm_value") + return pd.Series( + inv_norm_dist.tolist(), index=group.index, name="norm_value" + ) subj_wheel_pel_weighted_dist["norm_value"] = ( subj_wheel_pel_weighted_dist.groupby("subject_name") .apply(norm_inv_norm) .reset_index(level=0, drop=True) ) - subj_wheel_pel_weighted_dist["wheel_pref"] = patch_pref["running_preference_by_wheel"] + subj_wheel_pel_weighted_dist["wheel_pref"] = patch_pref[ + "running_preference_by_wheel" + ] # Plot it weighted_patch_pref_fig = make_subplots( rows=len(pel_patches), cols=len(subject_names), - subplot_titles=[f"{patch} - {subject}" for patch in pel_patches for subject in subject_names], + subplot_titles=[ + f"{patch} - {subject}" + for patch in pel_patches + for subject in subject_names + ], specs=[[{"secondary_y": True}] * len(subject_names)] * len(pel_patches), shared_xaxes=True, vertical_spacing=0.1, @@ -1352,7 +1548,9 @@ def make(self, key): for id_val, id_grp in centroid_df.groupby("identity_name"): # Add counts of x,y points to a grid that will be used for heatmap img_grid = np.zeros((max_x + 1, max_y + 1)) - points, counts = np.unique(id_grp[["x", "y"]].values, return_counts=True, axis=0) + points, counts = np.unique( + id_grp[["x", "y"]].values, return_counts=True, axis=0 + ) for point, count in zip(points, counts, strict=True): img_grid[point[0], point[1]] = count img_grid /= img_grid.max() # normalize @@ -1361,7 +1559,9 @@ def make(self, key): # so 45 cm/frame ~= 9 px/frame win_sz = 9 # in pixels (ensure odd for centering) kernel = np.ones((win_sz, win_sz)) / win_sz**2 # moving avg kernel - img_grid_p = np.pad(img_grid, win_sz // 2, mode="edge") # pad for full output from convolution + img_grid_p = np.pad( + img_grid, win_sz // 2, mode="edge" + ) # pad for full output from convolution img_grid_smooth = conv2d(img_grid_p, kernel) heatmaps.append((id_val, img_grid_smooth)) @@ -1390,11 +1590,17 @@ def make(self, key): # Figure 3 - Position ethogram # --- # Get Active Region (ROI) locations - epoch_query = acquisition.Epoch & (acquisition.Chunk & key & chunk_restriction).proj("epoch_start") + epoch_query = acquisition.Epoch & ( + acquisition.Chunk & key & chunk_restriction + ).proj("epoch_start") active_region_query = acquisition.EpochConfig.ActiveRegion & epoch_query - roi_locs = dict(zip(*active_region_query.fetch("region_name", "region_data"), strict=True)) + roi_locs = dict( + zip(*active_region_query.fetch("region_name", "region_data"), strict=True) + ) # get RFID reader locations - recent_rfid_query = (acquisition.Experiment.proj() * streams.Device.proj() & key).aggr( + recent_rfid_query = ( + acquisition.Experiment.proj() * streams.Device.proj() & key + ).aggr( streams.RfidReader & f"rfid_reader_install_time <= '{block_start}'", rfid_reader_install_time="max(rfid_reader_install_time)", ) @@ -1432,18 +1638,30 @@ def make(self, key): # For each ROI, compute if within ROI for roi in rois: - if roi == "Corridor": # special case for corridor, based on between inner and outer radius + if ( + roi == "Corridor" + ): # special case for corridor, based on between inner and outer radius dist = np.linalg.norm( (np.vstack((centroid_df["x"], centroid_df["y"])).T) - arena_center, axis=1, ) - pos_eth_df[roi] = (dist >= arena_inner_radius) & (dist <= arena_outer_radius) + pos_eth_df[roi] = (dist >= arena_inner_radius) & ( + dist <= arena_outer_radius + ) elif roi == "Nest": # special case for nest, based on 4 corners nest_corners = roi_locs["NestRegion"]["ArrayOfPoint"] - nest_br_x, nest_br_y = int(nest_corners[0]["X"]), int(nest_corners[0]["Y"]) - nest_bl_x, nest_bl_y = int(nest_corners[1]["X"]), int(nest_corners[1]["Y"]) - nest_tl_x, nest_tl_y = int(nest_corners[2]["X"]), int(nest_corners[2]["Y"]) - nest_tr_x, nest_tr_y = int(nest_corners[3]["X"]), int(nest_corners[3]["Y"]) + nest_br_x, nest_br_y = int(nest_corners[0]["X"]), int( + nest_corners[0]["Y"] + ) + nest_bl_x, nest_bl_y = int(nest_corners[1]["X"]), int( + nest_corners[1]["Y"] + ) + nest_tl_x, nest_tl_y = int(nest_corners[2]["X"]), int( + nest_corners[2]["Y"] + ) + nest_tr_x, nest_tr_y = int(nest_corners[3]["X"]), int( + nest_corners[3]["Y"] + ) pos_eth_df[roi] = ( (centroid_df["x"] <= nest_br_x) & (centroid_df["y"] >= nest_br_y) @@ -1457,10 +1675,13 @@ def make(self, key): else: roi_radius = gate_radius if roi == "Gate" else patch_radius # Get ROI coords - roi_x, roi_y = int(rfid_locs[roi + "Rfid"]["X"]), int(rfid_locs[roi + "Rfid"]["Y"]) + roi_x, roi_y = int(rfid_locs[roi + "Rfid"]["X"]), int( + rfid_locs[roi + "Rfid"]["Y"] + ) # Check if in ROI dist = np.linalg.norm( - (np.vstack((centroid_df["x"], centroid_df["y"])).T) - (roi_x, roi_y), + (np.vstack((centroid_df["x"], centroid_df["y"])).T) + - (roi_x, roi_y), axis=1, ) pos_eth_df[roi] = dist < roi_radius @@ -1566,7 +1787,8 @@ class AnalysisNote(dj.Manual): def get_threshold_associated_pellets(patch_key, start, end): - """Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time. + """Retrieve the pellet delivery timestamps associated with each patch threshold update + within the specified start-end time. 1. Get all patch state update timestamps (DepletionState): let's call these events "A" - Remove all events within 1 second of each other @@ -1574,8 +1796,10 @@ def get_threshold_associated_pellets(patch_key, start, end): 2. Get all pellet delivery timestamps (DeliverPellet): let's call these events "B" - Find matching beam break timestamps within 1.2s after each pellet delivery 3. For each event "A", find the nearest event "B" within 100ms before or after the event "A" - - These are the pellet delivery events "B" associated with the previous threshold update event "A" - 4. Shift back the pellet delivery timestamps by 1 to match the pellet delivery with the previous threshold update + - These are the pellet delivery events "B" associated with the previous threshold update + event "A" + 4. Shift back the pellet delivery timestamps by 1 to match the pellet delivery with the + previous threshold update 5. Remove all threshold updates events "A" without a corresponding pellet delivery event "B" Args: patch_key (dict): primary key for the patch @@ -1589,7 +1813,9 @@ def get_threshold_associated_pellets(patch_key, start, end): - offset - rate """ - chunk_restriction = acquisition.create_chunk_restriction(patch_key["experiment_name"], start, end) + chunk_restriction = acquisition.create_chunk_restriction( + patch_key["experiment_name"], start, end + ) # Step 1 - fetch data # pellet delivery trigger @@ -1597,9 +1823,9 @@ def get_threshold_associated_pellets(patch_key, start, end): streams.UndergroundFeederDeliverPellet & patch_key & chunk_restriction )[start:end] # beambreak - beambreak_df = fetch_stream(streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction)[ - start:end - ] + beambreak_df = fetch_stream( + streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction + )[start:end] # patch threshold depletion_state_df = fetch_stream( streams.UndergroundFeederDepletionState & patch_key & chunk_restriction @@ -1651,14 +1877,18 @@ def get_threshold_associated_pellets(patch_key, start, end): .set_index("time") .dropna(subset=["beam_break_timestamp"]) ) - pellet_beam_break_df.drop_duplicates(subset="beam_break_timestamp", keep="last", inplace=True) + pellet_beam_break_df.drop_duplicates( + subset="beam_break_timestamp", keep="last", inplace=True + ) # Find pellet delivery triggers that approximately coincide with each threshold update # i.e. nearest pellet delivery within 100ms before or after threshold update pellet_ts_threshold_df = ( pd.merge_asof( depletion_state_df.reset_index(), - pellet_beam_break_df.reset_index().rename(columns={"time": "pellet_timestamp"}), + pellet_beam_break_df.reset_index().rename( + columns={"time": "pellet_timestamp"} + ), left_on="time", right_on="pellet_timestamp", tolerance=pd.Timedelta("100ms"), @@ -1671,8 +1901,12 @@ def get_threshold_associated_pellets(patch_key, start, end): # Clean up the df pellet_ts_threshold_df = pellet_ts_threshold_df.drop(columns=["event_x", "event_y"]) # Shift back the pellet_timestamp values by 1 to match with the previous threshold update - pellet_ts_threshold_df.pellet_timestamp = pellet_ts_threshold_df.pellet_timestamp.shift(-1) - pellet_ts_threshold_df.beam_break_timestamp = pellet_ts_threshold_df.beam_break_timestamp.shift(-1) + pellet_ts_threshold_df.pellet_timestamp = ( + pellet_ts_threshold_df.pellet_timestamp.shift(-1) + ) + pellet_ts_threshold_df.beam_break_timestamp = ( + pellet_ts_threshold_df.beam_break_timestamp.shift(-1) + ) pellet_ts_threshold_df = pellet_ts_threshold_df.dropna( subset=["pellet_timestamp", "beam_break_timestamp"] ) @@ -1699,8 +1933,12 @@ def get_foraging_bouts( Returns: DataFrame containing foraging bouts. Columns: duration, n_pellets, cum_wheel_dist, subject. """ - max_inactive_time = pd.Timedelta(seconds=60) if max_inactive_time is None else max_inactive_time - bout_data = pd.DataFrame(columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"]) + max_inactive_time = ( + pd.Timedelta(seconds=60) if max_inactive_time is None else max_inactive_time + ) + bout_data = pd.DataFrame( + columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"] + ) subject_patch_data = (BlockSubjectAnalysis.Patch() & key).fetch(format="frame") if subject_patch_data.empty: return bout_data @@ -1744,34 +1982,52 @@ def get_foraging_bouts( wheel_s_r = pd.Timedelta(wheel_ts[1] - wheel_ts[0], unit="ns") max_inactive_win_len = int(max_inactive_time / wheel_s_r) # Find times when foraging - max_windowed_wheel_vals = patch_spun_df["cum_wheel_dist"].shift(-(max_inactive_win_len - 1)).ffill() - foraging_mask = max_windowed_wheel_vals > (patch_spun_df["cum_wheel_dist"] + min_wheel_movement) + max_windowed_wheel_vals = ( + patch_spun_df["cum_wheel_dist"].shift(-(max_inactive_win_len - 1)).ffill() + ) + foraging_mask = max_windowed_wheel_vals > ( + patch_spun_df["cum_wheel_dist"] + min_wheel_movement + ) # Discretize into foraging bouts - bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + (max_inactive_win_len - 1) + bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + ( + max_inactive_win_len - 1 + ) n_samples_in_1s = int(1 / wheel_s_r.total_seconds()) bout_end_indxs = ( np.where(np.diff(foraging_mask, prepend=0) == -1)[0] + (max_inactive_win_len - 1) + n_samples_in_1s ) - bout_end_indxs[-1] = min(bout_end_indxs[-1], len(wheel_ts) - 1) # ensure last bout ends in block + bout_end_indxs[-1] = min( + bout_end_indxs[-1], len(wheel_ts) - 1 + ) # ensure last bout ends in block # Remove bout that starts at block end if bout_start_indxs[-1] >= len(wheel_ts): bout_start_indxs = bout_start_indxs[:-1] bout_end_indxs = bout_end_indxs[:-1] if len(bout_start_indxs) != len(bout_end_indxs): - raise ValueError("Mismatch between the lengths of bout_start_indxs and bout_end_indxs.") - bout_durations = (wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs]).astype( # in seconds + raise ValueError( + "Mismatch between the lengths of bout_start_indxs and bout_end_indxs." + ) + bout_durations = ( + wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs] + ).astype( # in seconds "timedelta64[ns]" - ).astype(float) / 1e9 + ).astype( + float + ) / 1e9 bout_starts_ends = np.array( [ (wheel_ts[start_idx], wheel_ts[end_idx]) - for start_idx, end_idx in zip(bout_start_indxs, bout_end_indxs, strict=True) + for start_idx, end_idx in zip( + bout_start_indxs, bout_end_indxs, strict=True + ) ] ) all_pel_ts = np.sort( - np.concatenate([arr for arr in cur_subject_data["pellet_timestamps"] if len(arr) > 0]) + np.concatenate( + [arr for arr in cur_subject_data["pellet_timestamps"] if len(arr) > 0] + ) ) bout_pellets = np.array( [ @@ -1785,7 +2041,8 @@ def get_foraging_bouts( bout_pellets = bout_pellets[bout_pellets >= min_pellets] bout_cum_wheel_dist = np.array( [ - patch_spun_df.loc[end, "cum_wheel_dist"] - patch_spun_df.loc[start, "cum_wheel_dist"] + patch_spun_df.loc[end, "cum_wheel_dist"] + - patch_spun_df.loc[start, "cum_wheel_dist"] for start, end in bout_starts_ends ] ) diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index 7c6e6077..f4d69153 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -70,15 +70,15 @@ class Visit(dj.Part): @property def key_source(self): """Key source for OverlapVisit.""" - return dj.U("experiment_name", "place", "overlap_start") & (Visit & VisitEnd).proj( - overlap_start="visit_start" - ) + return dj.U("experiment_name", "place", "overlap_start") & ( + Visit & VisitEnd + ).proj(overlap_start="visit_start") def make(self, key): """Populate OverlapVisit table with overlapping visits.""" - visit_starts, visit_ends = (Visit * VisitEnd & key & {"visit_start": key["overlap_start"]}).fetch( - "visit_start", "visit_end" - ) + visit_starts, visit_ends = ( + Visit * VisitEnd & key & {"visit_start": key["overlap_start"]} + ).fetch("visit_start", "visit_end") visit_start = min(visit_starts) visit_end = max(visit_ends) @@ -92,7 +92,9 @@ def make(self, key): if len(overlap_query) <= 1: break overlap_visits.extend( - overlap_query.proj(overlap_start=f'"{key["overlap_start"]}"').fetch(as_dict=True) + overlap_query.proj(overlap_start=f'"{key["overlap_start"]}"').fetch( + as_dict=True + ) ) visit_starts, visit_ends = overlap_query.fetch("visit_start", "visit_end") if visit_start == max(visit_starts) and visit_end == max(visit_ends): @@ -106,7 +108,10 @@ def make(self, key): { **key, "overlap_end": visit_end, - "overlap_duration": (visit_end - key["overlap_start"]).total_seconds() / 3600, + "overlap_duration": ( + visit_end - key["overlap_start"] + ).total_seconds() + / 3600, "subject_count": len({v["subject"] for v in overlap_visits}), } ) @@ -117,10 +122,14 @@ def make(self, key): def ingest_environment_visits(experiment_names: list | None = None): - """Function to populate into `Visit` and `VisitEnd` for specified experiments (default: 'exp0.2-r0'). This ingestion routine handles only those "complete" visits, not ingesting any "on-going" visits using "analyze" method: `aeon.analyze.utils.visits()`. + """Function to populate into `Visit` and `VisitEnd` for specified + experiments (default: 'exp0.2-r0'). This ingestion routine handles + only those "complete" visits, not ingesting any "on-going" visits + using "analyze" method: `aeon.analyze.utils.visits()`. Args: - experiment_names (list, optional): list of names of the experiment to populate into the Visit table. Defaults to None. + experiment_names (list, optional): list of names of the experiment + to populate into the Visit table. Defaults to None. """ if experiment_names is None: @@ -193,16 +202,22 @@ def ingest_environment_visits(experiment_names: list | None = None): def get_maintenance_periods(experiment_name, start, end): """Get maintenance periods for the specified experiment and time range.""" # get states from acquisition.Environment.EnvironmentState - chunk_restriction = acquisition.create_chunk_restriction(experiment_name, start, end) + chunk_restriction = acquisition.create_chunk_restriction( + experiment_name, start, end + ) state_query = ( - acquisition.Environment.EnvironmentState & {"experiment_name": experiment_name} & chunk_restriction + acquisition.Environment.EnvironmentState + & {"experiment_name": experiment_name} + & chunk_restriction ) env_state_df = fetch_stream(state_query)[start:end] if env_state_df.empty: return deque([]) env_state_df.reset_index(inplace=True) - env_state_df = env_state_df[env_state_df["state"].shift() != env_state_df["state"]].reset_index( + env_state_df = env_state_df[ + env_state_df["state"].shift() != env_state_df["state"] + ].reset_index( drop=True ) # remove duplicates and keep the first one # An experiment starts with visit start (anything before the first maintenance is experiment) @@ -218,8 +233,12 @@ def get_maintenance_periods(experiment_name, start, end): env_state_df = pd.concat([env_state_df, log_df_end]) env_state_df.reset_index(drop=True, inplace=True) - maintenance_starts = env_state_df.loc[env_state_df["state"] == "Maintenance", "time"].values - maintenance_ends = env_state_df.loc[env_state_df["state"] != "Maintenance", "time"].values + maintenance_starts = env_state_df.loc[ + env_state_df["state"] == "Maintenance", "time" + ].values + maintenance_ends = env_state_df.loc[ + env_state_df["state"] != "Maintenance", "time" + ].values return deque( [ @@ -236,7 +255,9 @@ def filter_out_maintenance_periods(data_df, maintenance_period, end_time, dropna (maintenance_start, maintenance_end) = maint_period[0] if end_time < maintenance_start: # no more maintenance for this date break - maintenance_filter = (data_df.index >= maintenance_start) & (data_df.index <= maintenance_end) + maintenance_filter = (data_df.index >= maintenance_start) & ( + data_df.index <= maintenance_end + ) data_df[maintenance_filter] = np.nan if end_time >= maintenance_end: # remove this range maint_period.popleft() diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 88922f05..0f5047f9 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -91,19 +91,23 @@ def key_source(self): + chunk starts after visit_start and ends before visit_end (or NOW() - i.e. ongoing visits). """ return ( - Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") * acquisition.Chunk + Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") + * acquisition.Chunk & acquisition.SubjectEnterExit & [ "visit_start BETWEEN chunk_start AND chunk_end", "visit_end BETWEEN chunk_start AND chunk_end", "chunk_start >= visit_start AND chunk_end <= visit_end", ] - & "chunk_start < chunk_end" # in some chunks, end timestamp comes before start (timestamp error) + & "chunk_start < chunk_end" + # in some chunks, end timestamp comes before start (timestamp error) ) def make(self, key): """Populate VisitSubjectPosition for each visit""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) # -- Determine the time to start time_slicing in this chunk if chunk_start < key["visit_start"] < chunk_end: @@ -142,7 +146,8 @@ def make(self, key): object_id = (tracking.CameraTracking.Object & key).fetch1("object_id") else: logger.info( - '"More than one unique object ID found - using animal/object mapping from AnimalObjectMapping"' + "More than one unique object ID found - " + "using animal/object mapping from AnimalObjectMapping" ) if not (AnimalObjectMapping & key): raise ValueError( @@ -172,8 +177,12 @@ def make(self, key): end_time = np.array(end_time, dtype="datetime64[ns]") while time_slice_start < end_time: - time_slice_end = time_slice_start + min(self._time_slice_duration, end_time - time_slice_start) - in_time_slice = np.logical_and(timestamps >= time_slice_start, timestamps < time_slice_end) + time_slice_end = time_slice_start + min( + self._time_slice_duration, end_time - time_slice_start + ) + in_time_slice = np.logical_and( + timestamps >= time_slice_start, timestamps < time_slice_end + ) chunk_time_slices.append( { **key, @@ -193,12 +202,18 @@ def make(self, key): @classmethod def get_position(cls, visit_key=None, subject=None, start=None, end=None): - """Given a key to a single Visit, return a Pandas DataFrame for the position data of the subject for the specified Visit time period.""" + """Given a key to a single Visit, return a Pandas DataFrame for + the position data of the subject for the specified Visit time period.""" if visit_key is not None: if len(Visit & visit_key) != 1: - raise ValueError("The `visit_key` must correspond to exactly one Visit.") + raise ValueError( + "The `visit_key` must correspond to exactly one Visit." + ) start, end = ( - Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") & visit_key + Visit.join(VisitEnd, left=True).proj( + visit_end="IFNULL(visit_end, NOW())" + ) + & visit_key ).fetch1("visit_start", "visit_end") subject = visit_key["subject"] elif all((subject, start, end)): @@ -246,7 +261,7 @@ class Nest(dj.Part): -> lab.ArenaNest --- time_fraction_in_nest: float # fraction of time the animal spent in this nest in this visit - in_nest: longblob # array of indices for when the animal is in this nest (index into the position data) + in_nest: longblob # indices array marking when the animal is in this nest """ class FoodPatch(dj.Part): @@ -259,7 +274,9 @@ class FoodPatch(dj.Part): """ # Work on finished visits - key_source = Visit & (VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end") + key_source = Visit & ( + VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end" + ) def make(self, key): """Populate VisitTimeDistribution for each visit""" @@ -267,7 +284,9 @@ def make(self, key): visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) ) - maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) + maintenance_period = get_maintenance_periods( + key["experiment_name"], visit_start, visit_end + ) for visit_date in visit_dates: day_start = datetime.datetime.combine(visit_date.date(), time.min) @@ -287,12 +306,16 @@ def make(self, key): subject=key["subject"], start=day_start, end=day_end ) # filter out maintenance period based on logs - position = filter_out_maintenance_periods(position, maintenance_period, day_end) + position = filter_out_maintenance_periods( + position, maintenance_period, day_end + ) # filter for objects of the correct size valid_position = (position.area > 0) & (position.area < 1000) position[~valid_position] = np.nan - position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) + position.rename( + columns={"position_x": "x", "position_y": "y"}, inplace=True + ) # in corridor distance_from_center = tracking.compute_distance( position[["x", "y"]], @@ -336,9 +359,9 @@ def make(self, key): in_food_patch_times = [] for food_patch_key in food_patch_keys: # wheel data - food_patch_description = (acquisition.ExperimentFoodPatch & food_patch_key).fetch1( - "food_patch_description" - ) + food_patch_description = ( + acquisition.ExperimentFoodPatch & food_patch_key + ).fetch1("food_patch_description") wheel_data = acquisition.FoodPatchWheel.get_wheel_data( experiment_name=key["experiment_name"], start=pd.Timestamp(day_start), @@ -347,10 +370,12 @@ def make(self, key): using_aeon_io=True, ) # filter out maintenance period based on logs - wheel_data = filter_out_maintenance_periods(wheel_data, maintenance_period, day_end) - patch_position = (acquisition.ExperimentFoodPatch.Position & food_patch_key).fetch1( - "food_patch_position_x", "food_patch_position_y" + wheel_data = filter_out_maintenance_periods( + wheel_data, maintenance_period, day_end ) + patch_position = ( + acquisition.ExperimentFoodPatch.Position & food_patch_key + ).fetch1("food_patch_position_x", "food_patch_position_y") in_patch = tracking.is_position_in_patch( position, patch_position, @@ -391,7 +416,7 @@ class VisitSummary(dj.Computed): --- day_duration: float # total duration (in hours) total_distance_travelled: float # (m) total distance the animal travelled during this visit - total_pellet_count: int # total pellet delivered (triggered) for all patches during this visit + total_pellet_count: int # total pellet triggered for all patches during this visit total_wheel_distance_travelled: float # total wheel travelled distance for all patches """ @@ -400,12 +425,14 @@ class FoodPatch(dj.Part): -> master -> acquisition.ExperimentFoodPatch --- - pellet_count: int # number of pellets being delivered (triggered) by this patch during this visit + pellet_count: int # number of pellets delivered by this patch during this visit wheel_distance_travelled: float # wheel travelled distance during this visit for this patch """ # Work on finished visits - key_source = Visit & (VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end") + key_source = Visit & ( + VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end" + ) def make(self, key): """Populate VisitSummary for each visit""" @@ -413,7 +440,9 @@ def make(self, key): visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) ) - maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) + maintenance_period = get_maintenance_periods( + key["experiment_name"], visit_start, visit_end + ) for visit_date in visit_dates: day_start = datetime.datetime.combine(visit_date.date(), time.min) @@ -434,12 +463,18 @@ def make(self, key): subject=key["subject"], start=day_start, end=day_end ) # filter out maintenance period based on logs - position = filter_out_maintenance_periods(position, maintenance_period, day_end) + position = filter_out_maintenance_periods( + position, maintenance_period, day_end + ) # filter for objects of the correct size valid_position = (position.area > 0) & (position.area < 1000) position[~valid_position] = np.nan - position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) - position_diff = np.sqrt(np.square(np.diff(position.x)) + np.square(np.diff(position.y))) + position.rename( + columns={"position_x": "x", "position_y": "y"}, inplace=True + ) + position_diff = np.sqrt( + np.square(np.diff(position.x)) + np.square(np.diff(position.y)) + ) total_distance_travelled = np.nansum(position_diff) # in food patches - loop through all in-use patches during this visit @@ -475,9 +510,9 @@ def make(self, key): dropna=True, ).index.values # wheel data - food_patch_description = (acquisition.ExperimentFoodPatch & food_patch_key).fetch1( - "food_patch_description" - ) + food_patch_description = ( + acquisition.ExperimentFoodPatch & food_patch_key + ).fetch1("food_patch_description") wheel_data = acquisition.FoodPatchWheel.get_wheel_data( experiment_name=key["experiment_name"], start=pd.Timestamp(day_start), @@ -486,7 +521,9 @@ def make(self, key): using_aeon_io=True, ) # filter out maintenance period based on logs - wheel_data = filter_out_maintenance_periods(wheel_data, maintenance_period, day_end) + wheel_data = filter_out_maintenance_periods( + wheel_data, maintenance_period, day_end + ) food_patch_statistics.append( { @@ -494,11 +531,15 @@ def make(self, key): **food_patch_key, "visit_date": visit_date.date(), "pellet_count": len(pellet_events), - "wheel_distance_travelled": wheel_data.distance_travelled.values[-1], + "wheel_distance_travelled": wheel_data.distance_travelled.values[ + -1 + ], } ) - total_pellet_count = np.sum([p["pellet_count"] for p in food_patch_statistics]) + total_pellet_count = np.sum( + [p["pellet_count"] for p in food_patch_statistics] + ) total_wheel_distance_travelled = np.sum( [p["wheel_distance_travelled"] for p in food_patch_statistics] ) @@ -519,7 +560,12 @@ def make(self, key): @schema class VisitForagingBout(dj.Computed): - definition = """ # A time period spanning the time when the animal enters a food patch and moves the wheel to when it leaves the food patch + """ + A time period spanning the time when the animal enters a food patch and moves + the wheel to when it leaves the food patch. + """ + + definition = """ -> Visit -> acquisition.ExperimentFoodPatch bout_start: datetime(6) # start time of bout @@ -532,7 +578,10 @@ class VisitForagingBout(dj.Computed): # Work on 24/7 experiments key_source = ( - Visit & VisitSummary & (VisitEnd & "visit_duration > 24") & "experiment_name= 'exp0.2-r0'" + Visit + & VisitSummary + & (VisitEnd & "visit_duration > 24") + & "experiment_name= 'exp0.2-r0'" ) * acquisition.ExperimentFoodPatch def make(self, key): @@ -540,13 +589,17 @@ def make(self, key): visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") # get in_patch timestamps - food_patch_description = (acquisition.ExperimentFoodPatch & key).fetch1("food_patch_description") + food_patch_description = (acquisition.ExperimentFoodPatch & key).fetch1( + "food_patch_description" + ) in_patch_times = np.concatenate( - (VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & key).fetch( - "in_patch", order_by="visit_date" - ) + ( + VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & key + ).fetch("in_patch", order_by="visit_date") + ) + maintenance_period = get_maintenance_periods( + key["experiment_name"], visit_start, visit_end ) - maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) in_patch_times = filter_out_maintenance_periods( pd.DataFrame( [[food_patch_description]] * len(in_patch_times), @@ -574,8 +627,12 @@ def make(self, key): .set_index("event_time") ) # TODO: handle multiple retries of pellet delivery - maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) - patch = filter_out_maintenance_periods(patch, maintenance_period, visit_end, True) + maintenance_period = get_maintenance_periods( + key["experiment_name"], visit_start, visit_end + ) + patch = filter_out_maintenance_periods( + patch, maintenance_period, visit_end, True + ) if len(in_patch_times): change_ind = ( @@ -591,7 +648,9 @@ def make(self, key): ts_array = in_patch_times[change_ind[i - 1] : change_ind[i]] wheel_start, wheel_end = ts_array[0], ts_array[-1] - if wheel_start >= wheel_end: # skip if timestamps were misaligned or a single timestamp + if ( + wheel_start >= wheel_end + ): # skip if timestamps were misaligned or a single timestamp continue wheel_data = acquisition.FoodPatchWheel.get_wheel_data( @@ -601,14 +660,19 @@ def make(self, key): patch_name=food_patch_description, using_aeon_io=True, ) - maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) - wheel_data = filter_out_maintenance_periods(wheel_data, maintenance_period, visit_end, True) + maintenance_period = get_maintenance_periods( + key["experiment_name"], visit_start, visit_end + ) + wheel_data = filter_out_maintenance_periods( + wheel_data, maintenance_period, visit_end, True + ) self.insert1( { **key, "bout_start": ts_array[0], "bout_end": ts_array[-1], - "bout_duration": (ts_array[-1] - ts_array[0]) / np.timedelta64(1, "s"), + "bout_duration": (ts_array[-1] - ts_array[0]) + / np.timedelta64(1, "s"), "wheel_distance_travelled": wheel_data.distance_travelled[-1], "pellet_count": len(patch.loc[wheel_start:wheel_end]), } diff --git a/aeon/dj_pipeline/create_experiments/create_experiment_01.py b/aeon/dj_pipeline/create_experiments/create_experiment_01.py index 18edb4c3..cb66455d 100644 --- a/aeon/dj_pipeline/create_experiments/create_experiment_01.py +++ b/aeon/dj_pipeline/create_experiments/create_experiment_01.py @@ -255,7 +255,9 @@ def add_arena_setup(): # manually update coordinates of foodpatch and nest patch_coordinates = {"Patch1": (1.13, 1.59, 0), "Patch2": (1.19, 0.50, 0)} - for patch_key in (acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name}).fetch("KEY"): + for patch_key in (acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name}).fetch( + "KEY" + ): patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1("food_patch_description") x, y, z = patch_coordinates[patch] acquisition.ExperimentFoodPatch.Position.update1( diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py index 497cc9e9..f0f0dde8 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py @@ -38,7 +38,10 @@ def create_new_experiment(): skip_duplicates=True, ) acquisition.Experiment.Subject.insert( - [{"experiment_name": experiment_name, "subject": s["subject"]} for s in subject_list], + [ + {"experiment_name": experiment_name, "subject": s["subject"]} + for s in subject_list + ], skip_duplicates=True, ) @@ -94,8 +97,12 @@ def add_arena_setup(): # manually update coordinates of foodpatch and nest patch_coordinates = {"Patch1": (1.13, 1.59, 0), "Patch2": (1.19, 0.50, 0)} - for patch_key in (acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name}).fetch("KEY"): - patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1("food_patch_description") + for patch_key in ( + acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name} + ).fetch("KEY"): + patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1( + "food_patch_description" + ) x, y, z = patch_coordinates[patch] acquisition.ExperimentFoodPatch.Position.update1( { @@ -155,13 +162,15 @@ def fixID(subjid, valid_ids=None, valid_id_file=None): if ";" in subjid: subjidA, subjidB = subjid.split(";") return ( - f"{fixID(subjidA.strip(), valid_ids=valid_ids)};{fixID(subjidB.strip(), valid_ids=valid_ids)}" + f"{fixID(subjidA.strip(), valid_ids=valid_ids)};" + f"{fixID(subjidB.strip(), valid_ids=valid_ids)}" ) if "vs" in subjid: subjidA, tmp, subjidB = subjid.split(" ")[1:] return ( - f"{fixID(subjidA.strip(), valid_ids=valid_ids)};{fixID(subjidB.strip(), valid_ids=valid_ids)}" + f"{fixID(subjidA.strip(), valid_ids=valid_ids)};" + f"{fixID(subjidB.strip(), valid_ids=valid_ids)}" ) try: diff --git a/aeon/dj_pipeline/populate/process.py b/aeon/dj_pipeline/populate/process.py index 5c2e4d15..d49fead5 100644 --- a/aeon/dj_pipeline/populate/process.py +++ b/aeon/dj_pipeline/populate/process.py @@ -1,8 +1,12 @@ """Start an Aeon ingestion process. -This script defines auto-processing routines to operate the DataJoint pipeline for the Aeon project. Three separate "process" functions are defined to call `populate()` for different groups of tables, depending on their priority in the ingestion routines (high, mid, low). +This script defines auto-processing routines to operate the DataJoint pipeline +for the Aeon project. Three separate "process" functions are defined to call +`populate()` for different groups of tables, depending on their priority in +the ingestion routines (high, mid, low). -Each process function is run in a while-loop with the total run-duration configurable via command line argument '--duration' (if not set, runs perpetually) +Each process function is run in a while-loop with the total run-duration configurable +via command line argument '--duration' (if not set, runs perpetually) - the loop will not begin a new cycle after this period of time (in seconds) - the loop will run perpetually if duration<0 or if duration==None - the script will not be killed _at_ this limit, it will keep executing, @@ -22,7 +26,8 @@ Usage from python: - `from aeon.dj_pipeline.populate.process import run; run(worker_name='high_priority', duration=20, sleep=5)` + `from aeon.dj_pipeline.populate.process import run + run(worker_name='high_priority', duration=20, sleep=5)` """ @@ -76,7 +81,9 @@ def run(**kwargs): try: worker.run() except Exception: - logger.exception("action '{}' encountered an exception:".format(kwargs["worker_name"])) + logger.exception( + "action '{}' encountered an exception:".format(kwargs["worker_name"]) + ) logger.info("Ingestion process ended.") diff --git a/aeon/dj_pipeline/populate/worker.py b/aeon/dj_pipeline/populate/worker.py index c93a9bce..35e93da6 100644 --- a/aeon/dj_pipeline/populate/worker.py +++ b/aeon/dj_pipeline/populate/worker.py @@ -117,4 +117,6 @@ def get_workflow_operation_overview(): """Get the workflow operation overview for the worker schema.""" from datajoint_utilities.dj_worker.utils import get_workflow_operation_overview - return get_workflow_operation_overview(worker_schema_name=worker_schema_name, db_prefixes=[db_prefix]) + return get_workflow_operation_overview( + worker_schema_name=worker_schema_name, db_prefixes=[db_prefix] + ) diff --git a/aeon/dj_pipeline/qc.py b/aeon/dj_pipeline/qc.py index 5fa101ef..cc665885 100644 --- a/aeon/dj_pipeline/qc.py +++ b/aeon/dj_pipeline/qc.py @@ -31,7 +31,7 @@ class QCRoutine(dj.Lookup): --- qc_routine_order: int # the order in which this qc routine is executed qc_routine_description: varchar(255) # description of this QC routine - qc_module: varchar(64) # the module where the qc_function can be imported from - e.g. aeon.analysis.quality_control + qc_module: varchar(64) # module path, e.g., aeon.analysis.quality_control qc_function: varchar(64) # the function used to evaluate this QC - e.g. check_drop_frame """ @@ -62,7 +62,9 @@ def key_source(self): return ( acquisition.Chunk * ( - streams.SpinnakerVideoSource.join(streams.SpinnakerVideoSource.RemovalTime, left=True) + streams.SpinnakerVideoSource.join( + streams.SpinnakerVideoSource.RemovalTime, left=True + ) & "spinnaker_video_source_name='CameraTop'" ) & "chunk_start >= spinnaker_video_source_install_time" @@ -71,16 +73,21 @@ def key_source(self): def make(self, key): """Perform quality control checks on the CameraTop stream""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) - device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name") + device_name = (streams.SpinnakerVideoSource & key).fetch1( + "spinnaker_video_source_name" + ) data_dirs = acquisition.Experiment.get_data_directories(key) devices_schema = getattr( acquisition.aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(devices_schema, device_name).Video @@ -107,9 +114,11 @@ def make(self, key): **key, "drop_count": deltas.frame_offset.iloc[-1], "max_harp_delta": deltas.time_delta.max().total_seconds(), - "max_camera_delta": deltas.hw_timestamp_delta.max() / 1e9, # convert to seconds + "max_camera_delta": deltas.hw_timestamp_delta.max() + / 1e9, # convert to seconds "timestamps": videodata.index.values, - "time_delta": deltas.time_delta.values / np.timedelta64(1, "s"), # convert to seconds + "time_delta": deltas.time_delta.values + / np.timedelta64(1, "s"), # convert to seconds "frame_delta": deltas.frame_delta.values, "hw_counter_delta": deltas.hw_counter_delta.values, "hw_timestamp_delta": deltas.hw_timestamp_delta.values, diff --git a/aeon/dj_pipeline/report.py b/aeon/dj_pipeline/report.py index d09bfae2..b016fde7 100644 --- a/aeon/dj_pipeline/report.py +++ b/aeon/dj_pipeline/report.py @@ -32,7 +32,9 @@ class InArenaSummaryPlot(dj.Computed): summary_plot_png: attach """ - key_source = analysis.InArena & analysis.InArenaTimeDistribution & analysis.InArenaSummary + key_source = ( + analysis.InArena & analysis.InArenaTimeDistribution & analysis.InArenaSummary + ) color_code = { "Patch1": "b", @@ -44,15 +46,17 @@ class InArenaSummaryPlot(dj.Computed): def make(self, key): """Make method for InArenaSummaryPlot table""" - in_arena_start, in_arena_end = (analysis.InArena * analysis.InArenaEnd & key).fetch1( - "in_arena_start", "in_arena_end" - ) + in_arena_start, in_arena_end = ( + analysis.InArena * analysis.InArenaEnd & key + ).fetch1("in_arena_start", "in_arena_end") # subject's position data in the time_slices position = analysis.InArenaSubjectPosition.get_position(key) position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) - position_minutes_elapsed = (position.index - in_arena_start).total_seconds() / 60 + position_minutes_elapsed = ( + position.index - in_arena_start + ).total_seconds() / 60 # figure fig = plt.figure(figsize=(20, 9)) @@ -67,12 +71,16 @@ def make(self, key): # position plot non_nan = np.logical_and(~np.isnan(position.x), ~np.isnan(position.y)) - analysis_plotting.heatmap(position[non_nan], 50, ax=position_ax, bins=500, alpha=0.5) + analysis_plotting.heatmap( + position[non_nan], 50, ax=position_ax, bins=500, alpha=0.5 + ) # event rate plots in_arena_food_patches = ( analysis.InArena - * acquisition.ExperimentFoodPatch.join(acquisition.ExperimentFoodPatch.RemovalTime, left=True) + * acquisition.ExperimentFoodPatch.join( + acquisition.ExperimentFoodPatch.RemovalTime, left=True + ) & key & "in_arena_start >= food_patch_install_time" & 'in_arena_start < IFNULL(food_patch_remove_time, "2200-01-01")' @@ -139,7 +147,9 @@ def make(self, key): color=self.color_code[food_patch_key["food_patch_description"]], alpha=0.3, ) - threshold_change_ind = np.where(wheel_threshold[:-1] != wheel_threshold[1:])[0] + threshold_change_ind = np.where( + wheel_threshold[:-1] != wheel_threshold[1:] + )[0] threshold_ax.vlines( wheel_time[threshold_change_ind + 1], ymin=wheel_threshold[threshold_change_ind], @@ -151,17 +161,20 @@ def make(self, key): ) # ethogram - in_arena, in_corridor, arena_time, corridor_time = (analysis.InArenaTimeDistribution & key).fetch1( + in_arena, in_corridor, arena_time, corridor_time = ( + analysis.InArenaTimeDistribution & key + ).fetch1( "in_arena", "in_corridor", "time_fraction_in_arena", "time_fraction_in_corridor", ) - nest_keys, in_nests, nests_times = (analysis.InArenaTimeDistribution.Nest & key).fetch( - "KEY", "in_nest", "time_fraction_in_nest" - ) + nest_keys, in_nests, nests_times = ( + analysis.InArenaTimeDistribution.Nest & key + ).fetch("KEY", "in_nest", "time_fraction_in_nest") patch_names, in_patches, patches_times = ( - analysis.InArenaTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & key + analysis.InArenaTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch + & key ).fetch("food_patch_description", "in_patch", "time_fraction_in_patch") ethogram_ax.plot( @@ -192,7 +205,9 @@ def make(self, key): alpha=0.6, label="nest", ) - for patch_idx, (patch_name, in_patch) in enumerate(zip(patch_names, in_patches)): + for patch_idx, (patch_name, in_patch) in enumerate( + zip(patch_names, in_patches) + ): ethogram_ax.plot( position_minutes_elapsed[in_patch], np.full_like(position_minutes_elapsed[in_patch], (patch_idx + 3)), @@ -233,7 +248,9 @@ def make(self, key): rate_ax.set_title("foraging rate (bin size = 10 min)") distance_ax.set_ylabel("distance travelled (m)") threshold_ax.set_ylabel("threshold") - threshold_ax.set_ylim([threshold_ax.get_ylim()[0] - 100, threshold_ax.get_ylim()[1] + 100]) + threshold_ax.set_ylim( + [threshold_ax.get_ylim()[0] - 100, threshold_ax.get_ylim()[1] + 100] + ) ethogram_ax.set_xlabel("time (min)") analysis_plotting.set_ymargin(distance_ax, 0.2, 0.1) for ax in (rate_ax, distance_ax, pellet_ax, time_dist_ax, threshold_ax): @@ -262,7 +279,9 @@ def make(self, key): # ---- Save fig and insert ---- save_dir = _make_path(key) - fig_dict = _save_figs((fig,), ("summary_plot_png",), save_dir=save_dir, prefix=save_dir.name) + fig_dict = _save_figs( + (fig,), ("summary_plot_png",), save_dir=save_dir, prefix=save_dir.name + ) self.insert1({**key, **fig_dict}) @@ -276,7 +295,7 @@ class SubjectRewardRateDifference(dj.Computed): -> acquisition.Experiment.Subject --- in_arena_count: int - reward_rate_difference_plotly: longblob # dictionary storing the plotly object (from fig.to_plotly_json()) + reward_rate_difference_plotly: longblob # dict storing the plotly object (from fig.to_plotly_json()) """ key_source = acquisition.Experiment.Subject & analysis.InArenaRewardRate @@ -299,7 +318,10 @@ def make(self, key): @classmethod def delete_outdated_entries(cls): - """Each entry in this table correspond to one subject. However, the plot is capturing data for all sessions.Hence a dynamic update routine is needed to recompute the plot as new sessions become available.""" + """Each entry in this table correspond to one subject. + However, the plot is capturing data for all sessions. + Hence a dynamic update routine is needed to recompute + the plot as new sessions become available.""" outdated_entries = ( cls * ( @@ -320,7 +342,7 @@ class SubjectWheelTravelledDistance(dj.Computed): -> acquisition.Experiment.Subject --- in_arena_count: int - wheel_travelled_distance_plotly: longblob # dictionary storing the plotly object (from fig.to_plotly_json()) + wheel_travelled_distance_plotly: longblob # dict storing the plotly object (from fig.to_plotly_json()) """ key_source = acquisition.Experiment.Subject & analysis.InArenaSummary @@ -345,7 +367,10 @@ def make(self, key): @classmethod def delete_outdated_entries(cls): - """Each entry in this table correspond to one subject. However the plot is capturing data for all sessions. Hence a dynamic update routine is needed to recompute the plot as new sessions become available.""" + """Each entry in this table correspond to one subject. + However the plot is capturing data for all sessions. + Hence a dynamic update routine is needed to recompute + the plot as new sessions become available.""" outdated_entries = ( cls * ( @@ -389,7 +414,10 @@ def make(self, key): @classmethod def delete_outdated_entries(cls): - """Each entry in this table correspond to one subject. However the plot is capturing data for all sessions. Hence a dynamic update routine is needed to recompute the plot as new sessions become available.""" + """Each entry in this table correspond to one subject. + However the plot is capturing data for all sessions. + Hence a dynamic update routine is needed to recompute + the plot as new sessions become available.""" outdated_entries = ( cls * ( @@ -419,7 +447,7 @@ class VisitDailySummaryPlot(dj.Computed): definition = """ -> Visit --- - pellet_count_plotly: longblob # Dictionary storing the plotly object (from fig.to_plotly_json()) + pellet_count_plotly: longblob # Dict storing the plotly object (from fig.to_plotly_json()) wheel_distance_travelled_plotly: longblob total_distance_travelled_plotly: longblob weight_patch_plotly: longblob @@ -431,7 +459,10 @@ class VisitDailySummaryPlot(dj.Computed): """ key_source = ( - Visit & analysis.VisitSummary & (VisitEnd & "visit_duration > 24") & "experiment_name= 'exp0.2-r0'" + Visit + & analysis.VisitSummary + & (VisitEnd & "visit_duration > 24") + & "experiment_name= 'exp0.2-r0'" ) def make(self, key): @@ -540,7 +571,12 @@ def _make_path(in_arena_key): experiment_name, subject, in_arena_start = (analysis.InArena & in_arena_key).fetch1( "experiment_name", "subject", "in_arena_start" ) - output_dir = store_stage / experiment_name / subject / in_arena_start.strftime("%y%m%d_%H%M%S_%f") + output_dir = ( + store_stage + / experiment_name + / subject + / in_arena_start.strftime("%y%m%d_%H%M%S_%f") + ) output_dir.mkdir(parents=True, exist_ok=True) return output_dir diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index f6896c2d..f1e9ad88 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -86,7 +86,9 @@ def make(self, key): ) return elif len(animal_resp) > 1: - raise ValueError(f"Found {len(animal_resp)} with eartag {eartag_or_id}, expect one") + raise ValueError( + f"Found {len(animal_resp)} with eartag {eartag_or_id}, expect one" + ) else: animal_resp = animal_resp[0] @@ -185,17 +187,21 @@ def get_reference_weight(cls, subject_name): """Get the reference weight for the subject.""" subj_key = {"subject": subject_name} - food_restrict_query = SubjectProcedure & subj_key & "procedure_name = 'R02 - food restriction'" + food_restrict_query = ( + SubjectProcedure & subj_key & "procedure_name = 'R02 - food restriction'" + ) if food_restrict_query: - ref_date = food_restrict_query.fetch("procedure_date", order_by="procedure_date DESC", limit=1)[ - 0 - ] + ref_date = food_restrict_query.fetch( + "procedure_date", order_by="procedure_date DESC", limit=1 + )[0] else: ref_date = datetime.now(timezone.utc).date() weight_query = SubjectWeight & subj_key & f"weight_time < '{ref_date}'" ref_weight = ( - weight_query.fetch("weight", order_by="weight_time DESC", limit=1)[0] if weight_query else -1 + weight_query.fetch("weight", order_by="weight_time DESC", limit=1)[0] + if weight_query + else -1 ) entry = { @@ -253,7 +259,9 @@ def _auto_schedule(self): ): return - PyratIngestionTask.insert1({"pyrat_task_scheduled_time": next_task_schedule_time}) + PyratIngestionTask.insert1( + {"pyrat_task_scheduled_time": next_task_schedule_time} + ) def make(self, key): """Automatically import or update entries in the Subject table.""" @@ -261,11 +269,15 @@ def make(self, key): new_eartags = [] for responsible_id in lab.User.fetch("responsible_id"): # 1 - retrieve all animals from this user - animal_resp = get_pyrat_data(endpoint="animals", params={"responsible_id": responsible_id}) + animal_resp = get_pyrat_data( + endpoint="animals", params={"responsible_id": responsible_id} + ) for animal_entry in animal_resp: # 2 - find animal with comment - Project Aeon eartag_or_id = animal_entry["eartag_or_id"] - comment_resp = get_pyrat_data(endpoint=f"animals/{eartag_or_id}/comments") + comment_resp = get_pyrat_data( + endpoint=f"animals/{eartag_or_id}/comments" + ) for comment in comment_resp: if comment["attributes"]: first_attr = comment["attributes"][0] @@ -294,7 +306,9 @@ def make(self, key): { **key, "execution_time": execution_time, - "execution_duration": (completion_time - execution_time).total_seconds(), + "execution_duration": ( + completion_time - execution_time + ).total_seconds(), "new_pyrat_entry_count": new_entry_count, } ) @@ -340,7 +354,9 @@ def make(self, key): for cmt in comment_resp: cmt["subject"] = eartag_or_id cmt["attributes"] = json.dumps(cmt["attributes"], default=str) - SubjectComment.insert(comment_resp, skip_duplicates=True, allow_direct_insert=True) + SubjectComment.insert( + comment_resp, skip_duplicates=True, allow_direct_insert=True + ) weight_resp = get_pyrat_data(endpoint=f"animals/{eartag_or_id}/weights") SubjectWeight.insert( @@ -349,7 +365,9 @@ def make(self, key): allow_direct_insert=True, ) - procedure_resp = get_pyrat_data(endpoint=f"animals/{eartag_or_id}/procedures") + procedure_resp = get_pyrat_data( + endpoint=f"animals/{eartag_or_id}/procedures" + ) SubjectProcedure.insert( [{**v, "subject": eartag_or_id} for v in procedure_resp], skip_duplicates=True, @@ -364,7 +382,9 @@ def make(self, key): { **key, "execution_time": execution_time, - "execution_duration": (completion_time - execution_time).total_seconds(), + "execution_duration": ( + completion_time - execution_time + ).total_seconds(), } ) @@ -377,7 +397,9 @@ class CreatePyratIngestionTask(dj.Computed): def make(self, key): """Create one new PyratIngestionTask for every newly added users.""" - PyratIngestionTask.insert1({"pyrat_task_scheduled_time": datetime.now(timezone.utc)}) + PyratIngestionTask.insert1( + {"pyrat_task_scheduled_time": datetime.now(timezone.utc)} + ) time.sleep(1) self.insert1(key) @@ -454,7 +476,8 @@ def get_pyrat_data(endpoint: str, params: dict = None, **kwargs): if pyrat_system_token is None or pyrat_user_token is None: raise ValueError( - "The PYRAT tokens must be defined as an environment variable named 'PYRAT_SYSTEM_TOKEN' and 'PYRAT_USER_TOKEN'" + "The PYRAT tokens must be defined as an environment \ + variable named 'PYRAT_SYSTEM_TOKEN' and 'PYRAT_USER_TOKEN'" ) session = requests.Session() diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 4969cf76..fd3d5117 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -83,14 +83,18 @@ def insert_new_params( ): """Insert a new set of parameters for a given tracking method.""" if tracking_paramset_id is None: - tracking_paramset_id = (dj.U().aggr(cls, n="max(tracking_paramset_id)").fetch1("n") or 0) + 1 + tracking_paramset_id = ( + dj.U().aggr(cls, n="max(tracking_paramset_id)").fetch1("n") or 0 + ) + 1 param_dict = { "tracking_method": tracking_method, "tracking_paramset_id": tracking_paramset_id, "paramset_description": paramset_description, "params": params, - "param_set_hash": dict_to_uuid({**params, "tracking_method": tracking_method}), + "param_set_hash": dict_to_uuid( + {**params, "tracking_method": tracking_method} + ), } param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} @@ -119,7 +123,13 @@ def insert_new_params( @schema class SLEAPTracking(dj.Imported): - definition = """ # Tracked objects position data from a particular VideoSource for multi-animal experiment using the SLEAP tracking method per chunk + """ + Tracked objects position data from a particular + VideoSource for multi-animal experiment using the SLEAP tracking + method per chunk. + """ + + definition = """ -> acquisition.Chunk -> streams.SpinnakerVideoSource -> TrackingParamSet @@ -153,7 +163,9 @@ def key_source(self): return ( acquisition.Chunk * ( - streams.SpinnakerVideoSource.join(streams.SpinnakerVideoSource.RemovalTime, left=True) + streams.SpinnakerVideoSource.join( + streams.SpinnakerVideoSource.RemovalTime, left=True + ) & "spinnaker_video_source_name='CameraTop'" ) * (TrackingParamSet & "tracking_paramset_id = 1") @@ -163,17 +175,22 @@ def key_source(self): def make(self, key): """Ingest SLEAP tracking data for a given chunk.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) - device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name") + device_name = (streams.SpinnakerVideoSource & key).fetch1( + "spinnaker_video_source_name" + ) devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(devices_schema, device_name).Pose @@ -185,7 +202,9 @@ def make(self, key): ) if not len(pose_data): - raise ValueError(f"No SLEAP data found for {key['experiment_name']} - {device_name}") + raise ValueError( + f"No SLEAP data found for {key['experiment_name']} - {device_name}" + ) # get identity names class_names = np.unique(pose_data.identity) @@ -218,7 +237,9 @@ def make(self, key): if part == anchor_part: identity_likelihood = part_position.identity_likelihood.values if isinstance(identity_likelihood[0], dict): - identity_likelihood = np.array([v[identity] for v in identity_likelihood]) + identity_likelihood = np.array( + [v[identity] for v in identity_likelihood] + ) pose_identity_entries.append( { @@ -263,7 +284,9 @@ def is_position_in_nest(position_df, nest_key, xcol="x", ycol="y") -> pd.Series: """Given the session key and the position data - arrays of x and y return an array of boolean indicating whether or not a position is inside the nest. """ - nest_vertices = list(zip(*(lab.ArenaNest.Vertex & nest_key).fetch("vertex_x", "vertex_y"))) + nest_vertices = list( + zip(*(lab.ArenaNest.Vertex & nest_key).fetch("vertex_x", "vertex_y")) + ) nest_path = matplotlib.path.Path(nest_vertices) position_df["in_nest"] = nest_path.contains_points(position_df[[xcol, ycol]]) return position_df["in_nest"] @@ -290,7 +313,9 @@ def _get_position( start_query = table & obj_restriction & start_restriction end_query = table & obj_restriction & end_restriction if not (start_query and end_query): - raise ValueError(f"No position data found for {object_name} between {start} and {end}") + raise ValueError( + f"No position data found for {object_name} between {start} and {end}" + ) time_restriction = ( f'{start_attr} >= "{min(start_query.fetch(start_attr))}"' @@ -298,10 +323,14 @@ def _get_position( ) # subject's position data in the time slice - fetched_data = (table & obj_restriction & time_restriction).fetch(*fetch_attrs, order_by=start_attr) + fetched_data = (table & obj_restriction & time_restriction).fetch( + *fetch_attrs, order_by=start_attr + ) if not len(fetched_data[0]): - raise ValueError(f"No position data found for {object_name} between {start} and {end}") + raise ValueError( + f"No position data found for {object_name} between {start} and {end}" + ) timestamp_attr = next(attr for attr in fetch_attrs if "timestamps" in attr) diff --git a/aeon/dj_pipeline/utils/load_metadata.py b/aeon/dj_pipeline/utils/load_metadata.py index 4c3b7315..2c7f2aa3 100644 --- a/aeon/dj_pipeline/utils/load_metadata.py +++ b/aeon/dj_pipeline/utils/load_metadata.py @@ -42,7 +42,9 @@ def insert_stream_types(): else: # If the existed stream type does not have the same name: # human error, trying to add the same content with different name - raise dj.DataJointError(f"The specified stream type already exists - name: {pname}") + raise dj.DataJointError( + f"The specified stream type already exists - name: {pname}" + ) else: streams.StreamType.insert1(entry) @@ -55,7 +57,9 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): streams = dj.VirtualModule("streams", streams_maker.schema_name) device_info: dict[dict] = get_device_info(devices_schema) - device_type_mapper, device_sn = get_device_mapper(devices_schema, metadata_yml_filepath) + device_type_mapper, device_sn = get_device_mapper( + devices_schema, metadata_yml_filepath + ) # Add device type to device_info. Only add if device types that are defined in Metadata.yml device_info = { @@ -92,7 +96,8 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): {"device_type": device_type, "stream_type": stream_type} for device_type, stream_list in device_stream_map.items() for stream_type in stream_list - if not streams.DeviceType.Stream & {"device_type": device_type, "stream_type": stream_type} + if not streams.DeviceType.Stream + & {"device_type": device_type, "stream_type": stream_type} ] new_devices = [ @@ -101,7 +106,8 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): "device_type": device_config["device_type"], } for device_name, device_config in device_info.items() - if device_sn[device_name] and not streams.Device & {"device_serial_number": device_sn[device_name]} + if device_sn[device_name] + and not streams.Device & {"device_serial_number": device_sn[device_name]} ] # Insert new entries. @@ -119,7 +125,9 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): streams.Device.insert(new_devices) -def extract_epoch_config(experiment_name: str, devices_schema: DotMap, metadata_yml_filepath: str) -> dict: +def extract_epoch_config( + experiment_name: str, devices_schema: DotMap, metadata_yml_filepath: str +) -> dict: """Parse experiment metadata YAML file and extract epoch configuration. Args: @@ -131,7 +139,9 @@ def extract_epoch_config(experiment_name: str, devices_schema: DotMap, metadata_ dict: epoch_config [dict] """ metadata_yml_filepath = pathlib.Path(metadata_yml_filepath) - epoch_start = datetime.datetime.strptime(metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S") + epoch_start = datetime.datetime.strptime( + metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S" + ) epoch_config: dict = ( io_api.load( metadata_yml_filepath.parent.as_posix(), @@ -146,15 +156,22 @@ def extract_epoch_config(experiment_name: str, devices_schema: DotMap, metadata_ commit = epoch_config["metadata"]["Revision"] if not commit: - raise ValueError(f'Neither "Commit" nor "Revision" found in {metadata_yml_filepath}') + raise ValueError( + f'Neither "Commit" nor "Revision" found in {metadata_yml_filepath}' + ) devices: list[dict] = json.loads( - json.dumps(epoch_config["metadata"]["Devices"], default=lambda x: x.__dict__, indent=4) + json.dumps( + epoch_config["metadata"]["Devices"], default=lambda x: x.__dict__, indent=4 + ) ) - # Maintain backward compatibility - In exp02, it is a list of dict. From presocial onward, it's a dict of dict. + # Maintain backward compatibility - In exp02, it is a list of dict. + # From presocial onward, it's a dict of dict. if isinstance(devices, list): - devices: dict = {d.pop("Name"): d for d in devices} # {deivce_name: device_config} + devices: dict = { + d.pop("Name"): d for d in devices + } # {deivce_name: device_config} return { "experiment_name": experiment_name, @@ -178,15 +195,17 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath experiment_key = {"experiment_name": experiment_name} metadata_yml_filepath = pathlib.Path(metadata_yml_filepath) - epoch_config = extract_epoch_config(experiment_name, devices_schema, metadata_yml_filepath) + epoch_config = extract_epoch_config( + experiment_name, devices_schema, metadata_yml_filepath + ) previous_epoch = (acquisition.Experiment & experiment_key).aggr( acquisition.Epoch & f'epoch_start < "{epoch_config["epoch_start"]}"', epoch_start="MAX(epoch_start)", ) - if len(acquisition.EpochConfig.Meta & previous_epoch) and epoch_config["commit"] == ( - acquisition.EpochConfig.Meta & previous_epoch - ).fetch1("commit"): + if len(acquisition.EpochConfig.Meta & previous_epoch) and epoch_config[ + "commit" + ] == (acquisition.EpochConfig.Meta & previous_epoch).fetch1("commit"): # if identical commit -> no changes return @@ -207,7 +226,9 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath if not (streams.Device & device_key): logger.warning( - f"Device {device_name} (serial number: {device_sn}) is not yet registered in streams.Device.\nThis should not happen - check if metadata.yml and schemas dotmap are consistent. Skipping..." + f"Device {device_name} (serial number: {device_sn}) is not \ + yet registered in streams.Device.\nThis should not happen - \ + check if metadata.yml and schemas dotmap are consistent. Skipping..." ) # skip if this device (with a serial number) is not yet inserted in streams.Device continue @@ -218,7 +239,9 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath table_entry = { "experiment_name": experiment_name, **device_key, - f"{dj.utils.from_camel_case(table.__name__)}_install_time": epoch_config["epoch_start"], + f"{dj.utils.from_camel_case(table.__name__)}_install_time": epoch_config[ + "epoch_start" + ], f"{dj.utils.from_camel_case(table.__name__)}_name": device_name, } @@ -235,15 +258,23 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath { **table_entry, "attribute_name": "SamplingFrequency", - "attribute_value": video_controller[device_config["TriggerFrequency"]], + "attribute_value": video_controller[ + device_config["TriggerFrequency"] + ], } ) - """Check if this device is currently installed. If the same device serial number is currently installed check for any changes in configuration. If not, skip this""" - current_device_query = table - table.RemovalTime & experiment_key & device_key + """Check if this device is currently installed. If the same + device serial number is currently installed check for any + changes in configuration. If not, skip this""" + current_device_query = ( + table - table.RemovalTime & experiment_key & device_key + ) if current_device_query: - current_device_config: list[dict] = (table.Attribute & current_device_query).fetch( + current_device_config: list[dict] = ( + table.Attribute & current_device_query + ).fetch( "experiment_name", "device_serial_number", "attribute_name", @@ -251,7 +282,11 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath as_dict=True, ) new_device_config: list[dict] = [ - {k: v for k, v in entry.items() if dj.utils.from_camel_case(table.__name__) not in k} + { + k: v + for k, v in entry.items() + if dj.utils.from_camel_case(table.__name__) not in k + } for entry in table_attribute_entry ] @@ -261,7 +296,10 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath for config in current_device_config } ) == dict_to_uuid( - {config["attribute_name"]: config["attribute_value"] for config in new_device_config} + { + config["attribute_name"]: config["attribute_value"] + for config in new_device_config + } ): # Skip if none of the configuration has changed. continue @@ -379,10 +417,14 @@ def _get_class_path(obj): "aeon.schema.social", ]: device_info[device_name]["stream_type"].append(stream_type) - device_info[device_name]["stream_reader"].append(_get_class_path(stream_obj)) + device_info[device_name]["stream_reader"].append( + _get_class_path(stream_obj) + ) required_args = [ - k for k in inspect.signature(stream_obj.__init__).parameters if k != "self" + k + for k in inspect.signature(stream_obj.__init__).parameters + if k != "self" ] pattern = schema_dict[device_name][stream_type].get("pattern") schema_dict[device_name][stream_type]["pattern"] = pattern.replace( @@ -390,23 +432,35 @@ def _get_class_path(obj): ) kwargs = { - k: v for k, v in schema_dict[device_name][stream_type].items() if k in required_args + k: v + for k, v in schema_dict[device_name][stream_type].items() + if k in required_args } device_info[device_name]["stream_reader_kwargs"].append(kwargs) # Add hash device_info[device_name]["stream_hash"].append( - dict_to_uuid({**kwargs, "stream_reader": _get_class_path(stream_obj)}) + dict_to_uuid( + {**kwargs, "stream_reader": _get_class_path(stream_obj)} + ) ) else: stream_type = device.__class__.__name__ device_info[device_name]["stream_type"].append(stream_type) device_info[device_name]["stream_reader"].append(_get_class_path(device)) - required_args = {k: None for k in inspect.signature(device.__init__).parameters if k != "self"} + required_args = { + k: None + for k in inspect.signature(device.__init__).parameters + if k != "self" + } pattern = schema_dict[device_name].get("pattern") - schema_dict[device_name]["pattern"] = pattern.replace(device_name, "{pattern}") + schema_dict[device_name]["pattern"] = pattern.replace( + device_name, "{pattern}" + ) - kwargs = {k: v for k, v in schema_dict[device_name].items() if k in required_args} + kwargs = { + k: v for k, v in schema_dict[device_name].items() if k in required_args + } device_info[device_name]["stream_reader_kwargs"].append(kwargs) # Add hash device_info[device_name]["stream_hash"].append( @@ -416,7 +470,10 @@ def _get_class_path(obj): def get_device_mapper(devices_schema: DotMap, metadata_yml_filepath: Path): - """Returns a mapping dictionary between device name and device type based on the dataset schema and metadata.yml from the experiment. Store the mapper dictionary and read from it if the type info doesn't exist in Metadata.yml. + """Returns a mapping dictionary between device name and device type + based on the dataset schema and metadata.yml from the experiment. + Store the mapper dictionary and read from it if the type info doesn't + exist in Metadata.yml. Args: devices_schema (DotMap): DotMap object (e.g., exp02) @@ -456,7 +513,8 @@ def get_device_mapper(devices_schema: DotMap, metadata_yml_filepath: Path): device_type_mapper[item.Name] = item.Type device_sn[item.Name] = ( item.SerialNumber or item.PortName or None - ) # assign either the serial number (if it exists) or port name. If neither exists, assign None + ) # assign either the serial number (if it exists) or port name. + # If neither exists, assign None elif isinstance(item, str): # presocial if meta_data.Devices[item].get("Type"): device_type_mapper[item] = meta_data.Devices[item].get("Type") @@ -496,7 +554,9 @@ def ingest_epoch_metadata_octagon(experiment_name, metadata_yml_filepath): ("Wall8", "Wall"), ] - epoch_start = datetime.datetime.strptime(metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S") + epoch_start = datetime.datetime.strptime( + metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S" + ) for device_idx, (device_name, device_type) in enumerate(oct01_devices): device_sn = f"oct01_{device_idx}" @@ -505,8 +565,13 @@ def ingest_epoch_metadata_octagon(experiment_name, metadata_yml_filepath): skip_duplicates=True, ) experiment_table = getattr(streams, f"Experiment{device_type}") - if not (experiment_table & {"experiment_name": experiment_name, "device_serial_number": device_sn}): - experiment_table.insert1((experiment_name, device_sn, epoch_start, device_name)) + if not ( + experiment_table + & {"experiment_name": experiment_name, "device_serial_number": device_sn} + ): + experiment_table.insert1( + (experiment_name, device_sn, epoch_start, device_name) + ) # endregion diff --git a/aeon/dj_pipeline/utils/paths.py b/aeon/dj_pipeline/utils/paths.py index 1df21e64..b2b38d9a 100644 --- a/aeon/dj_pipeline/utils/paths.py +++ b/aeon/dj_pipeline/utils/paths.py @@ -36,7 +36,8 @@ def get_repository_path(repository_name: str) -> pathlib.Path: def find_root_directory( root_directories: str | pathlib.Path, full_path: str | pathlib.Path ) -> pathlib.Path: - """Given multiple potential root directories and a full-path, search and return one directory that is the parent of the given path. + """Given multiple potential root directories and a full-path, + search and return one directory that is the parent of the given path. Args: root_directories (str | pathlib.Path): A list of potential root directories. @@ -67,5 +68,6 @@ def find_root_directory( except StopIteration: raise FileNotFoundError( - f"No valid root directory found (from {root_directories})" f" for {full_path}" + f"No valid root directory found (from {root_directories})" + f" for {full_path}" ) diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index a8826455..be0ac80b 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -25,24 +25,30 @@ def plot_reward_rate_differences(subject_keys): - """Plotting the reward rate differences between food patches (Patch 2 - Patch 1) for all sessions from all subjects specified in "subject_keys". + """Plotting the reward rate differences between food patches + (Patch 2 - Patch 1) for all sessions from all subjects specified in "subject_keys". Examples: ``` - subject_keys = (acquisition.Experiment.Subject & 'experiment_name = "exp0.1-r0"').fetch('KEY') + subject_keys = + (acquisition.Experiment.Subject & 'experiment_name = "exp0.1-r0"').fetch('KEY') fig = plot_reward_rate_differences(subject_keys) ``` """ subj_names, sess_starts, rate_timestamps, rate_diffs = ( analysis.InArenaRewardRate & subject_keys - ).fetch("subject", "in_arena_start", "pellet_rate_timestamps", "patch2_patch1_rate_diff") + ).fetch( + "subject", "in_arena_start", "pellet_rate_timestamps", "patch2_patch1_rate_diff" + ) nSessions = len(sess_starts) longest_rateDiff = np.max([len(t) for t in rate_timestamps]) max_session_idx = np.argmax([len(t) for t in rate_timestamps]) - max_session_elapsed_times = rate_timestamps[max_session_idx] - rate_timestamps[max_session_idx][0] + max_session_elapsed_times = ( + rate_timestamps[max_session_idx] - rate_timestamps[max_session_idx][0] + ) x_labels = [t.total_seconds() / 60 for t in max_session_elapsed_times] y_labels = [ @@ -76,7 +82,8 @@ def plot_reward_rate_differences(subject_keys): def plot_wheel_travelled_distance(session_keys): - """Plotting the wheel travelled distance for different patches for all sessions specified in "session_keys". + """Plotting the wheel travelled distance for different patches + for all sessions specified in "session_keys". Examples: ``` @@ -87,12 +94,15 @@ def plot_wheel_travelled_distance(session_keys): ``` """ distance_travelled_query = ( - analysis.InArenaSummary.FoodPatch * acquisition.ExperimentFoodPatch.proj("food_patch_description") + analysis.InArenaSummary.FoodPatch + * acquisition.ExperimentFoodPatch.proj("food_patch_description") & session_keys ) distance_travelled_df = ( - distance_travelled_query.proj("food_patch_description", "wheel_distance_travelled") + distance_travelled_query.proj( + "food_patch_description", "wheel_distance_travelled" + ) .fetch(format="frame") .reset_index() ) @@ -154,7 +164,8 @@ def plot_average_time_distribution(session_keys): & session_keys ) .aggr( - analysis.InArenaTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch, + analysis.InArenaTimeDistribution.FoodPatch + * acquisition.ExperimentFoodPatch, avg_in_patch="AVG(time_fraction_in_patch)", ) .fetch("subject", "food_patch_description", "avg_in_patch") @@ -208,16 +219,21 @@ def plot_visit_daily_summary( Args: visit_key (dict) : Key from the VisitSummary table - attr (str): Name of the attribute to plot (e.g., 'pellet_count', 'wheel_distance_travelled', 'total_distance_travelled') - per_food_patch (bool, optional): Separately plot results from different food patches. Defaults to False. + attr (str): Name of the attribute to plot (e.g., 'pellet_count', + 'wheel_distance_travelled', 'total_distance_travelled') + per_food_patch (bool, optional): Separately plot results from + different food patches. Defaults to False. Returns: fig: Figure object Examples: - >>> fig = plot_visit_daily_summary(visit_key, attr='pellet_count', per_food_patch=True) - >>> fig = plot_visit_daily_summary(visit_key, attr='wheel_distance_travelled', per_food_patch=True) - >>> fig = plot_visit_daily_summary(visit_key, attr='total_distance_travelled') + >>> fig = plot_visit_daily_summary(visit_key, attr='pellet_count', + per_food_patch=True) + >>> fig = plot_visit_daily_summary(visit_key, + attr='wheel_distance_travelled', per_food_patch=True) + >>> fig = plot_visit_daily_summary(visit_key, + attr='total_distance_travelled') """ per_food_patch = not attr.startswith("total") color = "food_patch_description" if per_food_patch else None @@ -232,11 +248,15 @@ def plot_visit_daily_summary( .reset_index() ) else: - visit_per_day_df = (VisitSummary & visit_key).fetch(format="frame").reset_index() + visit_per_day_df = ( + (VisitSummary & visit_key).fetch(format="frame").reset_index() + ) if not attr.startswith("total"): attr = "total_" + attr - visit_per_day_df["day"] = visit_per_day_df["visit_date"] - visit_per_day_df["visit_date"].min() + visit_per_day_df["day"] = ( + visit_per_day_df["visit_date"] - visit_per_day_df["visit_date"].min() + ) visit_per_day_df["day"] = visit_per_day_df["day"].dt.days fig = px.bar( @@ -285,8 +305,10 @@ def plot_foraging_bouts_count( Args: visit_key (dict): Key from the Visit table - freq (str): Frequency level at which the visit time distribution is plotted. Corresponds to pandas freq. - per_food_patch (bool, optional): Separately plot results from different food patches. Defaults to False. + freq (str): Frequency level at which the visit time + distribution is plotted. Corresponds to pandas freq. + per_food_patch (bool, optional): Separately plot results from + different food patches. Defaults to False. min_bout_duration (int): Minimum foraging bout duration (in seconds) min_pellet_count (int): Minimum number of pellets min_wheel_dist (int): Minimum wheel distance travelled (in cm) @@ -295,7 +317,8 @@ def plot_foraging_bouts_count( fig: Figure object Examples: - >>> fig = plot_foraging_bouts_count(visit_key, freq="D", per_food_patch=True, min_bout_duration=1, min_wheel_dist=1) + >>> fig = plot_foraging_bouts_count(visit_key, freq="D", + per_food_patch=True, min_bout_duration=1, min_wheel_dist=1) """ # Get all foraging bouts for the visit foraging_bouts = ( @@ -327,10 +350,14 @@ def plot_foraging_bouts_count( else [foraging_bouts["bout_start"].dt.floor("D")] ) - foraging_bouts_count = foraging_bouts.groupby(group_by_attrs).size().reset_index(name="count") + foraging_bouts_count = ( + foraging_bouts.groupby(group_by_attrs).size().reset_index(name="count") + ) visit_start = (VisitEnd & visit_key).fetch1("visit_start") - foraging_bouts_count["day"] = (foraging_bouts_count["bout_start"].dt.date - visit_start.date()).dt.days + foraging_bouts_count["day"] = ( + foraging_bouts_count["bout_start"].dt.date - visit_start.date() + ).dt.days fig = px.bar( foraging_bouts_count, @@ -344,7 +371,10 @@ def plot_foraging_bouts_count( width=700, height=400, template="simple_white", - title=visit_key["subject"] + "
Foraging bouts: count (freq='" + freq + "')", + title=visit_key["subject"] + + "
Foraging bouts: count (freq='" + + freq + + "')", ) fig.update_layout( @@ -376,8 +406,10 @@ def plot_foraging_bouts_distribution( Args: visit_key (dict): Key from the Visit table - attr (str): Options include: pellet_count, bout_duration, wheel_distance_travelled - per_food_patch (bool, optional): Separately plot results from different food patches. Defaults to False. + attr (str): Options include: pellet_count, bout_duration, + wheel_distance_travelled + per_food_patch (bool, optional): Separately plot results from + different food patches. Defaults to False. min_bout_duration (int): Minimum foraging bout duration (in seconds) min_pellet_count (int): Minimum number of pellets min_wheel_dist (int): Minimum wheel distance travelled (in cm) @@ -416,7 +448,9 @@ def plot_foraging_bouts_distribution( fig = go.Figure() if per_food_patch: - patch_names = (acquisition.ExperimentFoodPatch & visit_key).fetch("food_patch_description") + patch_names = (acquisition.ExperimentFoodPatch & visit_key).fetch( + "food_patch_description" + ) for patch in patch_names: bouts = foraging_bouts[foraging_bouts["food_patch_description"] == patch] fig.add_trace( @@ -443,7 +477,9 @@ def plot_foraging_bouts_distribution( ) fig.update_layout( - title_text=visit_key["subject"] + "
Foraging bouts: " + attr.replace("_", " "), + title_text=visit_key["subject"] + + "
Foraging bouts: " + + attr.replace("_", " "), xaxis_title="date", yaxis_title=attr.replace("_", " "), violingap=0, @@ -469,7 +505,8 @@ def plot_visit_time_distribution(visit_key, freq="D"): Args: visit_key (dict): Key from the Visit table - freq (str): Frequency level at which the visit time distribution is plotted. Corresponds to pandas freq. + freq (str): Frequency level at which the visit time distribution + is plotted. Corresponds to pandas freq. Returns: fig: Figure object @@ -481,11 +518,17 @@ def plot_visit_time_distribution(visit_key, freq="D"): region = _get_region_data(visit_key) # Compute time spent per region - time_spent = region.groupby([region.index.floor(freq), "region"]).size().reset_index(name="count") - time_spent["time_fraction"] = time_spent["count"] / time_spent.groupby("timestamps")["count"].transform( - "sum" + time_spent = ( + region.groupby([region.index.floor(freq), "region"]) + .size() + .reset_index(name="count") ) - time_spent["day"] = (time_spent["timestamps"] - time_spent["timestamps"].min()).dt.days + time_spent["time_fraction"] = time_spent["count"] / time_spent.groupby( + "timestamps" + )["count"].transform("sum") + time_spent["day"] = ( + time_spent["timestamps"] - time_spent["timestamps"].min() + ).dt.days fig = px.bar( time_spent, @@ -497,7 +540,10 @@ def plot_visit_time_distribution(visit_key, freq="D"): "time_fraction": "time fraction", "timestamps": "date" if freq == "D" else "time", }, - title=visit_key["subject"] + "
Fraction of time spent in each region (freq='" + freq + "')", + title=visit_key["subject"] + + "
Fraction of time spent in each region (freq='" + + freq + + "')", width=700, height=400, template="simple_white", @@ -541,7 +587,9 @@ def _get_region_data(visit_key, attrs=None): for attr in attrs: if attr == "in_nest": # Nest in_nest = np.concatenate( - (VisitTimeDistribution.Nest & visit_key).fetch(attr, order_by="visit_date") + (VisitTimeDistribution.Nest & visit_key).fetch( + attr, order_by="visit_date" + ) ) region = pd.concat( [ @@ -556,14 +604,16 @@ def _get_region_data(visit_key, attrs=None): elif attr == "in_patch": # Food patch # Find all patches patches = np.unique( - (VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & visit_key).fetch( - "food_patch_description" - ) + ( + VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch + & visit_key + ).fetch("food_patch_description") ) for patch in patches: in_patch = np.concatenate( ( - VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch + VisitTimeDistribution.FoodPatch + * acquisition.ExperimentFoodPatch & visit_key & f"food_patch_description = '{patch}'" ).fetch("in_patch", order_by="visit_date") @@ -595,13 +645,19 @@ def _get_region_data(visit_key, attrs=None): region = region.sort_index().rename_axis("timestamps") # Exclude data during maintenance - maintenance_period = get_maintenance_periods(visit_key["experiment_name"], visit_start, visit_end) - region = filter_out_maintenance_periods(region, maintenance_period, visit_end, dropna=True) + maintenance_period = get_maintenance_periods( + visit_key["experiment_name"], visit_start, visit_end + ) + region = filter_out_maintenance_periods( + region, maintenance_period, visit_end, dropna=True + ) return region -def plot_weight_patch_data(visit_key, freq="H", smooth_weight=True, min_weight=0, max_weight=35): +def plot_weight_patch_data( + visit_key, freq="H", smooth_weight=True, min_weight=0, max_weight=35 +): """Plot subject weight and patch data (pellet trigger count) per visit. Args: @@ -618,7 +674,9 @@ def plot_weight_patch_data(visit_key, freq="H", smooth_weight=True, min_weight=0 >>> fig = plot_weight_patch_data(visit_key, freq="H", smooth_weight=True) >>> fig = plot_weight_patch_data(visit_key, freq="D") """ - subject_weight = _get_filtered_subject_weight(visit_key, smooth_weight, min_weight, max_weight) + subject_weight = _get_filtered_subject_weight( + visit_key, smooth_weight, min_weight, max_weight + ) # Count pellet trigger per patch per day/hour/... patch = _get_patch_data(visit_key) @@ -646,8 +704,12 @@ def plot_weight_patch_data(visit_key, freq="H", smooth_weight=True, min_weight=0 for p in patch_names: fig.add_trace( go.Bar( - x=patch_summary[patch_summary["food_patch_description"] == p]["event_time"], - y=patch_summary[patch_summary["food_patch_description"] == p]["event_type"], + x=patch_summary[patch_summary["food_patch_description"] == p][ + "event_time" + ], + y=patch_summary[patch_summary["food_patch_description"] == p][ + "event_type" + ], name=p, ), secondary_y=False, @@ -672,7 +734,10 @@ def plot_weight_patch_data(visit_key, freq="H", smooth_weight=True, min_weight=0 fig.update_layout( barmode="stack", hovermode="x", - title_text=visit_key["subject"] + "
Weight and pellet count (freq='" + freq + "')", + title_text=visit_key["subject"] + + "
Weight and pellet count (freq='" + + freq + + "')", xaxis_title="date" if freq == "D" else "time", yaxis={"title": "pellet count"}, yaxis2={"title": "weight"}, @@ -693,7 +758,9 @@ def plot_weight_patch_data(visit_key, freq="H", smooth_weight=True, min_weight=0 return fig -def _get_filtered_subject_weight(visit_key, smooth_weight=True, min_weight=0, max_weight=35): +def _get_filtered_subject_weight( + visit_key, smooth_weight=True, min_weight=0, max_weight=35 +): """Retrieve subject weight from WeightMeasurementFiltered table. Args: @@ -732,7 +799,9 @@ def _get_filtered_subject_weight(visit_key, smooth_weight=True, min_weight=0, ma subject_weight = subject_weight.loc[visit_start:visit_end] # Exclude data during maintenance - maintenance_period = get_maintenance_periods(visit_key["experiment_name"], visit_start, visit_end) + maintenance_period = get_maintenance_periods( + visit_key["experiment_name"], visit_start, visit_end + ) subject_weight = filter_out_maintenance_periods( subject_weight, maintenance_period, visit_end, dropna=True ) @@ -749,7 +818,9 @@ def _get_filtered_subject_weight(visit_key, smooth_weight=True, min_weight=0, ma subject_weight = subject_weight.resample("1T").mean().dropna() if smooth_weight: - subject_weight["weight_subject"] = savgol_filter(subject_weight["weight_subject"], 10, 3) + subject_weight["weight_subject"] = savgol_filter( + subject_weight["weight_subject"], 10, 3 + ) return subject_weight @@ -770,7 +841,9 @@ def _get_patch_data(visit_key): ( dj.U("event_time", "event_type", "food_patch_description") & ( - acquisition.FoodPatchEvent * acquisition.EventType * acquisition.ExperimentFoodPatch + acquisition.FoodPatchEvent + * acquisition.EventType + * acquisition.ExperimentFoodPatch & f'event_time BETWEEN "{visit_start}" AND "{visit_end}"' & 'event_type = "TriggerPellet"' ) @@ -783,7 +856,11 @@ def _get_patch_data(visit_key): # TODO: handle repeat attempts (pellet delivery trigger and beam break) # Exclude data during maintenance - maintenance_period = get_maintenance_periods(visit_key["experiment_name"], visit_start, visit_end) - patch = filter_out_maintenance_periods(patch, maintenance_period, visit_end, dropna=True) + maintenance_period = get_maintenance_periods( + visit_key["experiment_name"], visit_start, visit_end + ) + patch = filter_out_maintenance_periods( + patch, maintenance_period, visit_end, dropna=True + ) return patch diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index e3d3259f..f86b1e1d 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -114,10 +114,7 @@ def get_device_stream_template(device_type: str, stream_type: str, streams_modul # DeviceDataStream table(s) stream_detail = ( streams_module.StreamType - & ( - streams_module.DeviceType.Stream - & {"device_type": device_type, "stream_type": stream_type} - ) + & (streams_module.DeviceType.Stream & {"device_type": device_type, "stream_type": stream_type}) ).fetch1() for i, n in enumerate(stream_detail["stream_reader"].split(".")): @@ -159,8 +156,7 @@ def key_source(self): """ key_source_query = ( - acquisition.Chunk - * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) + acquisition.Chunk * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) & f"chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time" & f'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time,\ "2200-01-01")' @@ -170,9 +166,7 @@ def key_source(self): def make(self, key): """Load and insert the data for the DeviceDataStream table.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) device_name = (ExperimentDevice & key).fetch1( @@ -182,13 +176,10 @@ def make(self, key): devices_schema = getattr( aeon_schemas, ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} + acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]} ).fetch1("devices_schema_name"), ) - stream_reader = getattr( - getattr(devices_schema, device_name), "{stream_type}" - ) + stream_reader = getattr(getattr(devices_schema, device_name), "{stream_type}") stream_data = io_api.load( root=data_dirs, diff --git a/aeon/io/reader.py b/aeon/io/reader.py index 144f6d3c..cdce7eb7 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -500,7 +500,9 @@ def from_dict(data, pattern=None): def to_dict(dotmap): """Converts a DotMap object to a dictionary.""" if isinstance(dotmap, Reader): - kwargs = {k: v for k, v in vars(dotmap).items() if k not in ["pattern"] and not k.startswith("_")} + kwargs = { + k: v for k, v in vars(dotmap).items() if k not in ["pattern"] and not k.startswith("_") + } kwargs["type"] = type(dotmap).__name__ return kwargs return {k: to_dict(v) for k, v in dotmap.items()} diff --git a/tests/dj_pipeline/test_pipeline_instantiation.py b/tests/dj_pipeline/test_pipeline_instantiation.py index f53bde20..087d0c05 100644 --- a/tests/dj_pipeline/test_pipeline_instantiation.py +++ b/tests/dj_pipeline/test_pipeline_instantiation.py @@ -45,7 +45,9 @@ def test_experiment_creation(test_params, pipeline, experiment_creation): if raw_dir != test_params["raw_dir"]: raise AssertionError(f"Expected raw directory '{test_params['raw_dir']}', but got '{raw_dir}'.") - exp_subjects = (acquisition.Experiment.Subject & {"experiment_name": experiment_name}).fetch("subject") + exp_subjects = (acquisition.Experiment.Subject & {"experiment_name": experiment_name}).fetch( + "subject" + ) if len(exp_subjects) != test_params["subject_count"]: raise AssertionError( f"Expected subject count {test_params['subject_count']}, but got {len(exp_subjects)}." From 4d76c773143bdf43c86b583750cf1538aeae4ee0 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 23:40:41 +0000 Subject: [PATCH 028/143] fix: resolve F401 issues --- aeon/dj_pipeline/acquisition.py | 4 +--- aeon/dj_pipeline/analysis/visit.py | 2 +- aeon/dj_pipeline/analysis/visit_analysis.py | 1 - .../dj_pipeline/create_experiments/create_socialexperiment.py | 1 - aeon/dj_pipeline/populate/worker.py | 1 - aeon/dj_pipeline/report.py | 2 +- aeon/dj_pipeline/tracking.py | 2 -- 7 files changed, 3 insertions(+), 10 deletions(-) diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py index ee5fa322..68cc0efc 100644 --- a/aeon/dj_pipeline/acquisition.py +++ b/aeon/dj_pipeline/acquisition.py @@ -6,11 +6,9 @@ import re import datajoint as dj -import numpy as np import pandas as pd -from aeon.analysis import utils as analysis_utils -from aeon.dj_pipeline import get_schema_name, lab, subject +from aeon.dj_pipeline import get_schema_name from aeon.dj_pipeline.utils import paths from aeon.io import api as io_api from aeon.io import reader as io_reader diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index f4d69153..010f3824 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -9,7 +9,7 @@ from aeon.analysis import utils as analysis_utils from aeon.dj_pipeline import get_schema_name, fetch_stream -from aeon.dj_pipeline import acquisition, lab, qc, tracking +from aeon.dj_pipeline import acquisition schema = dj.schema(get_schema_name("analysis")) diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 0f5047f9..5d900d0a 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -7,7 +7,6 @@ import numpy as np import pandas as pd -from aeon.dj_pipeline import get_schema_name from aeon.dj_pipeline import acquisition, lab, tracking from aeon.dj_pipeline.analysis.visit import ( Visit, diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py index 68925d29..581a8999 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py @@ -1,6 +1,5 @@ """Function to create new social experiments""" -from pathlib import Path from datetime import datetime from aeon.dj_pipeline import acquisition from aeon.dj_pipeline.utils.paths import get_repository_path diff --git a/aeon/dj_pipeline/populate/worker.py b/aeon/dj_pipeline/populate/worker.py index 35e93da6..a4605273 100644 --- a/aeon/dj_pipeline/populate/worker.py +++ b/aeon/dj_pipeline/populate/worker.py @@ -7,7 +7,6 @@ DataJointWorker, ErrorLog, WorkerLog, - RegisteredWorker, ) from datajoint_utilities.dj_worker.worker_schema import is_djtable diff --git a/aeon/dj_pipeline/report.py b/aeon/dj_pipeline/report.py index b016fde7..9e60a3b2 100644 --- a/aeon/dj_pipeline/report.py +++ b/aeon/dj_pipeline/report.py @@ -14,7 +14,7 @@ from aeon.dj_pipeline.analysis.visit import Visit, VisitEnd from aeon.dj_pipeline.analysis.visit_analysis import * -from . import acquisition, analysis, get_schema_name +from . import acquisition, analysis logger = dj.logger diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index fd3d5117..b36222c0 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -1,6 +1,5 @@ """DataJoint schema for tracking data.""" -from pathlib import Path import datajoint as dj import matplotlib.path @@ -12,7 +11,6 @@ dict_to_uuid, get_schema_name, lab, - qc, streams, ) from aeon.io import api as io_api From 163ff9fd1e32e72b255a56056d9cc18765e301e0 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 23:43:02 +0000 Subject: [PATCH 029/143] fix: resolve B905 issues add strict=False to zip --- aeon/dj_pipeline/acquisition.py | 8 ++++---- aeon/dj_pipeline/analysis/block_analysis.py | 2 +- aeon/dj_pipeline/analysis/visit.py | 4 ++-- aeon/dj_pipeline/create_experiments/create_presocial.py | 2 +- aeon/dj_pipeline/lab.py | 2 +- aeon/dj_pipeline/report.py | 4 ++-- aeon/dj_pipeline/tracking.py | 4 ++-- aeon/dj_pipeline/utils/plotting.py | 4 ++-- 8 files changed, 15 insertions(+), 15 deletions(-) diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py index 68cc0efc..69e5a142 100644 --- a/aeon/dj_pipeline/acquisition.py +++ b/aeon/dj_pipeline/acquisition.py @@ -35,7 +35,7 @@ class ExperimentType(dj.Lookup): experiment_type: varchar(32) """ - contents = zip(["foraging", "social"]) + contents = zip(["foraging", "social"], strict=False) @schema @@ -63,7 +63,7 @@ class DevicesSchema(dj.Lookup): devices_schema_name: varchar(32) """ - contents = zip(aeon_schemas.__all__) + contents = zip(aeon_schemas.__all__, strict=False) # ------------------- Data repository/directory ------------------------ @@ -75,7 +75,7 @@ class PipelineRepository(dj.Lookup): repository_name: varchar(16) """ - contents = zip(["ceph_aeon"]) + contents = zip(["ceph_aeon"], strict=False) @schema @@ -84,7 +84,7 @@ class DirectoryType(dj.Lookup): directory_type: varchar(16) """ - contents = zip(["raw", "processed", "qc"]) + contents = zip(["raw", "processed", "qc"], strict=False) # ------------------- GENERAL INFORMATION ABOUT AN EXPERIMENT -------------------- diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index ddc220e3..1ae8972d 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -227,7 +227,7 @@ def make(self, key): patch_keys, patch_names = patch_query.fetch("KEY", "underground_feeder_name") block_patch_entries = [] - for patch_key, patch_name in zip(patch_keys, patch_names): + for patch_key, patch_name in zip(patch_keys, patch_names, strict=False): # pellet delivery and patch threshold data depletion_state_df = fetch_stream( streams.UndergroundFeederDepletionState & patch_key & chunk_restriction diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index 010f3824..45ac73b8 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -163,7 +163,7 @@ def ingest_environment_visits(experiment_names: list | None = None): "enter_exit_time", "event_type", order_by="enter_exit_time", - ) + ), strict=False ) ) enter_exit_df.columns = ["id", "time", "event"] @@ -243,7 +243,7 @@ def get_maintenance_periods(experiment_name, start, end): return deque( [ (pd.Timestamp(start), pd.Timestamp(end)) - for start, end in zip(maintenance_starts, maintenance_ends) + for start, end in zip(maintenance_starts, maintenance_ends, strict=False) ] ) # queue object. pop out from left after use diff --git a/aeon/dj_pipeline/create_experiments/create_presocial.py b/aeon/dj_pipeline/create_experiments/create_presocial.py index c66d7725..19e9cc84 100644 --- a/aeon/dj_pipeline/create_experiments/create_presocial.py +++ b/aeon/dj_pipeline/create_experiments/create_presocial.py @@ -47,7 +47,7 @@ def create_new_experiment(): "directory_type": "raw", "directory_path": f"aeon/data/raw/{computer}/{experiment_type}", } - for experiment_name, computer in zip(experiment_names, computers) + for experiment_name, computer in zip(experiment_names, computers, strict=False) ], skip_duplicates=True, ) diff --git a/aeon/dj_pipeline/lab.py b/aeon/dj_pipeline/lab.py index b0c6204b..413d2763 100644 --- a/aeon/dj_pipeline/lab.py +++ b/aeon/dj_pipeline/lab.py @@ -85,7 +85,7 @@ class ArenaShape(dj.Lookup): definition = """ arena_shape: varchar(32) """ - contents = zip(["square", "circular", "rectangular", "linear", "octagon"]) + contents = zip(["square", "circular", "rectangular", "linear", "octagon"], strict=False) @schema diff --git a/aeon/dj_pipeline/report.py b/aeon/dj_pipeline/report.py index 9e60a3b2..fbc5cc3b 100644 --- a/aeon/dj_pipeline/report.py +++ b/aeon/dj_pipeline/report.py @@ -206,7 +206,7 @@ def make(self, key): label="nest", ) for patch_idx, (patch_name, in_patch) in enumerate( - zip(patch_names, in_patches) + zip(patch_names, in_patches, strict=False) ): ethogram_ax.plot( position_minutes_elapsed[in_patch], @@ -584,7 +584,7 @@ def _make_path(in_arena_key): def _save_figs(figs, fig_names, save_dir, prefix, extension=".png"): """Save figures and return a dictionary with figure names and file paths""" fig_dict = {} - for fig, figname in zip(figs, fig_names): + for fig, figname in zip(figs, fig_names, strict=False): fig_fp = save_dir / (prefix + "_" + figname + extension) fig.tight_layout() fig.savefig(fig_fp, dpi=300) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index b36222c0..2e8a6c39 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -283,7 +283,7 @@ def is_position_in_nest(position_df, nest_key, xcol="x", ycol="y") -> pd.Series: return an array of boolean indicating whether or not a position is inside the nest. """ nest_vertices = list( - zip(*(lab.ArenaNest.Vertex & nest_key).fetch("vertex_x", "vertex_y")) + zip(*(lab.ArenaNest.Vertex & nest_key).fetch("vertex_x", "vertex_y"), strict=False) ) nest_path = matplotlib.path.Path(nest_vertices) position_df["in_nest"] = nest_path.contains_points(position_df[[xcol, ycol]]) @@ -336,7 +336,7 @@ def _get_position( position = pd.DataFrame( { k: np.hstack(v) * scale_factor if k in attrs_to_scale else np.hstack(v) - for k, v in zip(fetch_attrs, fetched_data) + for k, v in zip(fetch_attrs, fetched_data, strict=False) } ) position.set_index(timestamp_attr, inplace=True) diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index be0ac80b..d9a3d93c 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -53,7 +53,7 @@ def plot_reward_rate_differences(subject_keys): y_labels = [ f'{subj_name}_{sess_start.strftime("%m/%d/%Y")}' - for subj_name, sess_start in zip(subj_names, sess_starts) + for subj_name, sess_start in zip(subj_names, sess_starts, strict=False) ] rateDiffs_matrix = np.full((nSessions, longest_rateDiff), np.nan) @@ -110,7 +110,7 @@ def plot_wheel_travelled_distance(session_keys): distance_travelled_df["in_arena"] = [ f'{subj_name}_{sess_start.strftime("%m/%d/%Y")}' for subj_name, sess_start in zip( - distance_travelled_df.subject, distance_travelled_df.in_arena_start + distance_travelled_df.subject, distance_travelled_df.in_arena_start, strict=False ) ] From e320939c56fb0b3746850fdf399fcc6deff1dc75 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 23:43:34 +0000 Subject: [PATCH 030/143] fix: add to previous commit another change --- aeon/dj_pipeline/tracking.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 2e8a6c39..fe7ccce9 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -1,6 +1,5 @@ """DataJoint schema for tracking data.""" - import datajoint as dj import matplotlib.path import numpy as np @@ -283,7 +282,10 @@ def is_position_in_nest(position_df, nest_key, xcol="x", ycol="y") -> pd.Series: return an array of boolean indicating whether or not a position is inside the nest. """ nest_vertices = list( - zip(*(lab.ArenaNest.Vertex & nest_key).fetch("vertex_x", "vertex_y"), strict=False) + zip( + *(lab.ArenaNest.Vertex & nest_key).fetch("vertex_x", "vertex_y"), + strict=False, + ) ) nest_path = matplotlib.path.Path(nest_vertices) position_df["in_nest"] = nest_path.contains_points(position_df[[xcol, ycol]]) From 3ac365282996fd6cf28a7a3a1035f2c286e38e1c Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 23:51:18 +0000 Subject: [PATCH 031/143] fix: resolve S324 issue replaced `hashlib.md5` with `haslib.sha256` to improve security --- aeon/dj_pipeline/__init__.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/__init__.py b/aeon/dj_pipeline/__init__.py index 27bff7c8..3b39c1ba 100644 --- a/aeon/dj_pipeline/__init__.py +++ b/aeon/dj_pipeline/__init__.py @@ -15,7 +15,9 @@ db_prefix = dj.config["custom"].get("database.prefix", _default_database_prefix) -repository_config = dj.config["custom"].get("repository_config", _default_repository_config) +repository_config = dj.config["custom"].get( + "repository_config", _default_repository_config +) def get_schema_name(name) -> str: @@ -25,11 +27,11 @@ def get_schema_name(name) -> str: def dict_to_uuid(key) -> uuid.UUID: """Given a dictionary `key`, returns a hash string as UUID.""" - hashed = hashlib.md5() + hashed = hashlib.sha256() for k, v in sorted(key.items()): hashed.update(str(k).encode()) hashed.update(str(v).encode()) - return uuid.UUID(hex=hashed.hexdigest()) + return uuid.UUID(hex=hashed.hexdigest()[:32]) def fetch_stream(query, drop_pk=True): @@ -40,7 +42,9 @@ def fetch_stream(query, drop_pk=True): """ df = (query & "sample_count > 0").fetch(format="frame").reset_index() cols2explode = [ - c for c in query.heading.secondary_attributes if query.heading.attributes[c].type == "longblob" + c + for c in query.heading.secondary_attributes + if query.heading.attributes[c].type == "longblob" ] df = df.explode(column=cols2explode) cols2drop = ["sample_count"] + (query.primary_key if drop_pk else []) From 71f100d946f7df7e7f018c81b1622efdfbffd255 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Mon, 28 Oct 2024 23:58:10 +0000 Subject: [PATCH 032/143] fix: resolve E722 issues Avoid using bare `except` clauses --- aeon/dj_pipeline/__init__.py | 2 +- .../dj_pipeline/create_experiments/create_socialexperiment_0.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/__init__.py b/aeon/dj_pipeline/__init__.py index 3b39c1ba..0cfeef23 100644 --- a/aeon/dj_pipeline/__init__.py +++ b/aeon/dj_pipeline/__init__.py @@ -68,5 +68,5 @@ def fetch_stream(query, drop_pk=True): from .utils import streams_maker streams = dj.VirtualModule("streams", streams_maker.schema_name) - except: + except ImportError: pass diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py index f0f0dde8..3623ce4f 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py @@ -176,7 +176,7 @@ def fixID(subjid, valid_ids=None, valid_id_file=None): try: ld = [jl.levenshtein_distance(subjid, x[-len(subjid) :]) for x in valid_ids] return valid_ids[np.argmin(ld)] - except: + except ValueError: return subjid From 6ab8badb43df38b1ec2802f14dd6646c7e689285 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 00:08:39 +0000 Subject: [PATCH 033/143] fix: resolve B904 issue raise exception with `raise...from err` --- aeon/dj_pipeline/utils/paths.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/utils/paths.py b/aeon/dj_pipeline/utils/paths.py index b2b38d9a..14376506 100644 --- a/aeon/dj_pipeline/utils/paths.py +++ b/aeon/dj_pipeline/utils/paths.py @@ -66,8 +66,8 @@ def find_root_directory( if pathlib.Path(root_dir) in set(full_path.parents) ) - except StopIteration: + except StopIteration as err: raise FileNotFoundError( f"No valid root directory found (from {root_directories})" f" for {full_path}" - ) + ) from err From 07d6a35336bd77c52dfd97eae8a148365c5f4d2b Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 00:12:31 +0000 Subject: [PATCH 034/143] fix: resolve UP038 issues use the new `X | Y` syntax in `isinstance` calls instead of using (X, Y) (in py3.10) --- aeon/dj_pipeline/subject.py | 2 +- aeon/dj_pipeline/utils/paths.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index f1e9ad88..cfcbd475 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -486,7 +486,7 @@ def get_pyrat_data(endpoint: str, params: dict = None, **kwargs): if params is not None: params_str_list = [] for k, v in params.items(): - if isinstance(v, (list, tuple)): + if isinstance(v, (list | tuple)): for i in v: params_str_list.append(f"{k}={i}") else: diff --git a/aeon/dj_pipeline/utils/paths.py b/aeon/dj_pipeline/utils/paths.py index 14376506..3271fc30 100644 --- a/aeon/dj_pipeline/utils/paths.py +++ b/aeon/dj_pipeline/utils/paths.py @@ -56,7 +56,7 @@ def find_root_directory( raise FileNotFoundError(f"{full_path} does not exist!") # turn to list if only a single root directory is provided - if isinstance(root_directories, (str, pathlib.Path)): + if isinstance(root_directories, (str | pathlib.Path)): root_directories = [root_directories] try: From 445770a7669d3b655419f79de00da848f6a882be Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 00:31:22 +0000 Subject: [PATCH 035/143] fix: resolve D205 issues blank lines between summary line and description --- aeon/dj_pipeline/analysis/block_analysis.py | 16 +++- aeon/dj_pipeline/analysis/visit.py | 15 ++-- aeon/dj_pipeline/analysis/visit_analysis.py | 15 ++-- aeon/dj_pipeline/lab.py | 5 +- aeon/dj_pipeline/report.py | 25 ++++-- .../scripts/clone_and_freeze_exp01.py | 4 +- .../scripts/clone_and_freeze_exp02.py | 8 +- .../scripts/update_timestamps_longblob.py | 16 +++- aeon/dj_pipeline/tracking.py | 7 +- aeon/dj_pipeline/utils/load_metadata.py | 11 ++- aeon/dj_pipeline/utils/paths.py | 1 + aeon/dj_pipeline/utils/plotting.py | 6 +- aeon/dj_pipeline/utils/streams_maker.py | 27 +++++-- aeon/io/reader.py | 79 ++++++++++++++----- 14 files changed, 181 insertions(+), 54 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 1ae8972d..e272690d 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -138,7 +138,9 @@ class BlockAnalysis(dj.Computed): @property def key_source(self): """ + Ensure that the chunk ingestion has caught up with this block before processing + (there exists a chunk that ends after the block end time) """ ks = Block.aggr(acquisition.Chunk, latest_chunk_end="MAX(chunk_end)") @@ -175,8 +177,12 @@ class Subject(dj.Part): """ def make(self, key): - """Restrict, fetch and aggregate data from different streams to + """ + + Restrict, fetch and aggregate data from different streams to + produce intermediate data products at a per-block level + (for different patches and different subjects). 1. Query data for all chunks within the block. @@ -1787,9 +1793,13 @@ class AnalysisNote(dj.Manual): def get_threshold_associated_pellets(patch_key, start, end): - """Retrieve the pellet delivery timestamps associated with each patch threshold update + """ + Retrieve the pellet delivery timestamps associated with each patch threshold update + within the specified start-end time. + + Notes: 1. Get all patch state update timestamps (DepletionState): let's call these events "A" - Remove all events within 1 second of each other - Remove all events without threshold value (NaN) @@ -1801,10 +1811,12 @@ def get_threshold_associated_pellets(patch_key, start, end): 4. Shift back the pellet delivery timestamps by 1 to match the pellet delivery with the previous threshold update 5. Remove all threshold updates events "A" without a corresponding pellet delivery event "B" + Args: patch_key (dict): primary key for the patch start (datetime): start timestamp end (datetime): end timestamp + Returns: pd.DataFrame: DataFrame with the following columns: - threshold_update_timestamp (index) diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index 45ac73b8..ef223143 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -122,10 +122,14 @@ def make(self, key): def ingest_environment_visits(experiment_names: list | None = None): - """Function to populate into `Visit` and `VisitEnd` for specified - experiments (default: 'exp0.2-r0'). This ingestion routine handles - only those "complete" visits, not ingesting any "on-going" visits - using "analyze" method: `aeon.analyze.utils.visits()`. + """ + Function to populate into `Visit` and `VisitEnd` for specified + + experiments (default: 'exp0.2-r0'). + + This ingestion routine handles only those "complete" visits, + not ingesting any "on-going" visits using "analyze" method: + `aeon.analyze.utils.visits()`. Args: experiment_names (list, optional): list of names of the experiment @@ -163,7 +167,8 @@ def ingest_environment_visits(experiment_names: list | None = None): "enter_exit_time", "event_type", order_by="enter_exit_time", - ), strict=False + ), + strict=False, ) ) enter_exit_df.columns = ["id", "time", "event"] diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 5d900d0a..05205d1b 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -84,7 +84,9 @@ class TimeSlice(dj.Part): @property def key_source(self): - """Chunk for all visits: + """ + Chunk for all visits: + + visit_start during this Chunk - i.e. first chunk of the visit + visit_end during this Chunk - i.e. last chunk of the visit + chunk starts after visit_start and ends before visit_end (or NOW() - i.e. ongoing visits). @@ -201,8 +203,12 @@ def make(self, key): @classmethod def get_position(cls, visit_key=None, subject=None, start=None, end=None): - """Given a key to a single Visit, return a Pandas DataFrame for - the position data of the subject for the specified Visit time period.""" + """ + Return a Pandas df of the subject's position data for a specified Visit given its key. + + Given a key to a single Visit, return a Pandas DataFrame for + the position data of the subject for the specified Visit time period. + """ if visit_key is not None: if len(Visit & visit_key) != 1: raise ValueError( @@ -560,8 +566,7 @@ def make(self, key): @schema class VisitForagingBout(dj.Computed): """ - A time period spanning the time when the animal enters a food patch and moves - the wheel to when it leaves the food patch. + Time period from when the animal enters to when it leaves a food patch while moving the wheel. """ definition = """ diff --git a/aeon/dj_pipeline/lab.py b/aeon/dj_pipeline/lab.py index 413d2763..9f0966b2 100644 --- a/aeon/dj_pipeline/lab.py +++ b/aeon/dj_pipeline/lab.py @@ -85,12 +85,15 @@ class ArenaShape(dj.Lookup): definition = """ arena_shape: varchar(32) """ - contents = zip(["square", "circular", "rectangular", "linear", "octagon"], strict=False) + contents = zip( + ["square", "circular", "rectangular", "linear", "octagon"], strict=False + ) @schema class Arena(dj.Lookup): """Coordinate frame convention: + + x-dimension: x=0 is the left most point of the bounding box of the arena + y-dimension: y=0 is the top most point of the bounding box of the arena + z-dimension: z=0 is the lowest point of the arena (e.g. the ground) diff --git a/aeon/dj_pipeline/report.py b/aeon/dj_pipeline/report.py index fbc5cc3b..22b11c95 100644 --- a/aeon/dj_pipeline/report.py +++ b/aeon/dj_pipeline/report.py @@ -318,10 +318,15 @@ def make(self, key): @classmethod def delete_outdated_entries(cls): - """Each entry in this table correspond to one subject. + """ + Dynamically update the plot for all sessions. + + Each entry in this table correspond to one subject. However, the plot is capturing data for all sessions. Hence a dynamic update routine is needed to recompute - the plot as new sessions become available.""" + the plot as new sessions become available. + + """ outdated_entries = ( cls * ( @@ -367,10 +372,14 @@ def make(self, key): @classmethod def delete_outdated_entries(cls): - """Each entry in this table correspond to one subject. + """ + Dynamically update the plot for all sessions. + + Each entry in this table correspond to one subject. However the plot is capturing data for all sessions. Hence a dynamic update routine is needed to recompute - the plot as new sessions become available.""" + the plot as new sessions become available. + """ outdated_entries = ( cls * ( @@ -414,10 +423,14 @@ def make(self, key): @classmethod def delete_outdated_entries(cls): - """Each entry in this table correspond to one subject. + """ + Dynamically update the plot for all sessions. + + Each entry in this table correspond to one subject. However the plot is capturing data for all sessions. Hence a dynamic update routine is needed to recompute - the plot as new sessions become available.""" + the plot as new sessions become available. + """ outdated_entries = ( cls * ( diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py index a85042b4..3859c51e 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py @@ -1,4 +1,6 @@ -"""March 2022 +""" +March 2022 + Cloning and archiving schemas and data for experiment 0.1. """ diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py index 024d0900..bc8494a7 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py @@ -1,4 +1,7 @@ -"""Jan 2024: Cloning and archiving schemas and data for experiment 0.2. +""" +Jan 2024 + +Cloning and archiving schemas and data for experiment 0.2. The pipeline code associated with this archived data pipeline is here https://github.com/SainsburyWellcomeCentre/aeon_mecha/releases/tag/dj_exp02_stable @@ -108,7 +111,8 @@ def validate(): target_entry_count = len(target_tbl()) missing_entries[orig_schema_name][source_tbl.__name__] = { "entry_count_diff": source_entry_count - target_entry_count, - "db_size_diff": source_tbl().size_on_disk - target_tbl().size_on_disk, + "db_size_diff": source_tbl().size_on_disk + - target_tbl().size_on_disk, } return { diff --git a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py index 9cd845ef..0946dff9 100644 --- a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py +++ b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py @@ -1,4 +1,6 @@ -"""July 2022 +""" +July 2022 + Upgrade all timestamps longblob fields with datajoint 0.13.7. """ @@ -12,7 +14,9 @@ from tqdm import tqdm if dj.__version__ < "0.13.7": - raise ImportError(f"DataJoint version must be at least 0.13.7, but found {dj.__version__}.") + raise ImportError( + f"DataJoint version must be at least 0.13.7, but found {dj.__version__}." + ) schema = dj.schema("u_thinh_aeonfix") @@ -40,7 +44,13 @@ def main(): for schema_name in schema_names: vm = dj.create_virtual_module(schema_name, schema_name) table_names = [ - ".".join([dj.utils.to_camel_case(s) for s in tbl_name.strip("`").split("__") if s]) + ".".join( + [ + dj.utils.to_camel_case(s) + for s in tbl_name.strip("`").split("__") + if s + ] + ) for tbl_name in vm.schema.list_tables() ] for table_name in table_names: diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index fe7ccce9..c3af9474 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -121,6 +121,8 @@ def insert_new_params( @schema class SLEAPTracking(dj.Imported): """ + Tracking data from SLEAP for multi-animal experiments. + Tracked objects position data from a particular VideoSource for multi-animal experiment using the SLEAP tracking method per chunk. @@ -278,7 +280,10 @@ def is_position_in_patch( def is_position_in_nest(position_df, nest_key, xcol="x", ycol="y") -> pd.Series: - """Given the session key and the position data - arrays of x and y + """ + Check if a position is inside the nest. + + Notes: Given the session key and the position data - arrays of x and y return an array of boolean indicating whether or not a position is inside the nest. """ nest_vertices = list( diff --git a/aeon/dj_pipeline/utils/load_metadata.py b/aeon/dj_pipeline/utils/load_metadata.py index 2c7f2aa3..2046c072 100644 --- a/aeon/dj_pipeline/utils/load_metadata.py +++ b/aeon/dj_pipeline/utils/load_metadata.py @@ -50,7 +50,11 @@ def insert_stream_types(): def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): - """Use aeon.schema.schemas and metadata.yml to insert into streams.DeviceType and streams.Device. + """ + + Insert device types into streams.DeviceType and streams.Device. + + Notes: Use aeon.schema.schemas and metadata.yml to insert into streams.DeviceType and streams.Device. Only insert device types that were defined both in the device schema (e.g., exp02) and Metadata.yml. It then creates new device tables under streams schema. """ @@ -470,7 +474,10 @@ def _get_class_path(obj): def get_device_mapper(devices_schema: DotMap, metadata_yml_filepath: Path): - """Returns a mapping dictionary between device name and device type + """ + Returns a mapping dictionary of device names to types based on the dataset schema and metadata.yml. + + Notes: Returns a mapping dictionary between device name and device type based on the dataset schema and metadata.yml from the experiment. Store the mapper dictionary and read from it if the type info doesn't exist in Metadata.yml. diff --git a/aeon/dj_pipeline/utils/paths.py b/aeon/dj_pipeline/utils/paths.py index 3271fc30..d2f9bb75 100644 --- a/aeon/dj_pipeline/utils/paths.py +++ b/aeon/dj_pipeline/utils/paths.py @@ -37,6 +37,7 @@ def find_root_directory( root_directories: str | pathlib.Path, full_path: str | pathlib.Path ) -> pathlib.Path: """Given multiple potential root directories and a full-path, + search and return one directory that is the parent of the given path. Args: diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index d9a3d93c..af1c475a 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -26,6 +26,7 @@ def plot_reward_rate_differences(subject_keys): """Plotting the reward rate differences between food patches + (Patch 2 - Patch 1) for all sessions from all subjects specified in "subject_keys". Examples: @@ -83,6 +84,7 @@ def plot_reward_rate_differences(subject_keys): def plot_wheel_travelled_distance(session_keys): """Plotting the wheel travelled distance for different patches + for all sessions specified in "session_keys". Examples: @@ -110,7 +112,9 @@ def plot_wheel_travelled_distance(session_keys): distance_travelled_df["in_arena"] = [ f'{subj_name}_{sess_start.strftime("%m/%d/%Y")}' for subj_name, sess_start in zip( - distance_travelled_df.subject, distance_travelled_df.in_arena_start, strict=False + distance_travelled_df.subject, + distance_travelled_df.in_arena_start, + strict=False, ) ] diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index f86b1e1d..aab8915e 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -23,9 +23,13 @@ class StreamType(dj.Lookup): - """Catalog of all steam types for the different device types used across + """ + Catalog of all steam types for the different device types used across + Project Aeon. One StreamType corresponds to one reader class in `aeon.io.reader`. + The combination of `stream_reader` and `stream_reader_kwargs` should fully + specify the data loading routine for a particular device, using the `aeon.io.utils`. """ @@ -114,7 +118,10 @@ def get_device_stream_template(device_type: str, stream_type: str, streams_modul # DeviceDataStream table(s) stream_detail = ( streams_module.StreamType - & (streams_module.DeviceType.Stream & {"device_type": device_type, "stream_type": stream_type}) + & ( + streams_module.DeviceType.Stream + & {"device_type": device_type, "stream_type": stream_type} + ) ).fetch1() for i, n in enumerate(stream_detail["stream_reader"].split(".")): @@ -151,12 +158,15 @@ class DeviceDataStream(dj.Imported): def key_source(self): """ Only the combination of Chunk and device_type with overlapping time + + Chunk(s) that started after device_type install time and ended before device_type remove time + + Chunk(s) that started after device_type install time for device_type that are not yet removed """ key_source_query = ( - acquisition.Chunk * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) + acquisition.Chunk + * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) & f"chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time" & f'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time,\ "2200-01-01")' @@ -166,7 +176,9 @@ def key_source(self): def make(self, key): """Load and insert the data for the DeviceDataStream table.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) device_name = (ExperimentDevice & key).fetch1( @@ -176,10 +188,13 @@ def make(self, key): devices_schema = getattr( aeon_schemas, ( - acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]} + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} ).fetch1("devices_schema_name"), ) - stream_reader = getattr(getattr(devices_schema, device_name), "{stream_type}") + stream_reader = getattr( + getattr(devices_schema, device_name), "{stream_type}" + ) stream_data = io_api.load( root=data_dirs, diff --git a/aeon/io/reader.py b/aeon/io/reader.py index cdce7eb7..4d598db5 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -70,8 +70,12 @@ def read(self, file): payloadtype = _payloadtypes[data[4] & ~0x10] elementsize = payloadtype.itemsize payloadshape = (length, payloadsize // elementsize) - seconds = np.ndarray(length, dtype=np.uint32, buffer=data, offset=5, strides=stride) - ticks = np.ndarray(length, dtype=np.uint16, buffer=data, offset=9, strides=stride) + seconds = np.ndarray( + length, dtype=np.uint32, buffer=data, offset=5, strides=stride + ) + ticks = np.ndarray( + length, dtype=np.uint16, buffer=data, offset=9, strides=stride + ) seconds = ticks * _SECONDS_PER_TICK + seconds payload = np.ndarray( payloadshape, @@ -82,7 +86,9 @@ def read(self, file): ) if self.columns is not None and payloadshape[1] < len(self.columns): - data = pd.DataFrame(payload, index=seconds, columns=self.columns[: payloadshape[1]]) + data = pd.DataFrame( + payload, index=seconds, columns=self.columns[: payloadshape[1]] + ) data[self.columns[payloadshape[1] :]] = math.nan return data else: @@ -111,13 +117,17 @@ class Metadata(Reader): def __init__(self, pattern="Metadata"): """Initialize the object with the specified pattern.""" - super().__init__(pattern, columns=["workflow", "commit", "metadata"], extension="yml") + super().__init__( + pattern, columns=["workflow", "commit", "metadata"], extension="yml" + ) def read(self, file): """Returns metadata for the specified epoch.""" epoch_str = file.parts[-2] date_str, time_str = epoch_str.split("T") - time = datetime.datetime.fromisoformat(date_str + "T" + time_str.replace("-", ":")) + time = datetime.datetime.fromisoformat( + date_str + "T" + time_str.replace("-", ":") + ) with open(file) as fp: metadata = json.load(fp) workflow = metadata.pop("Workflow") @@ -151,6 +161,7 @@ def read(self, file): class JsonList(Reader): """ Extracts data from json list (.jsonl) files, + where the key "seconds" stores the Aeon timestamp, in seconds. """ @@ -260,7 +271,9 @@ class Position(Harp): def __init__(self, pattern): """Initialize the object with a specified pattern and columns.""" - super().__init__(pattern, columns=["x", "y", "angle", "major", "minor", "area", "id"]) + super().__init__( + pattern, columns=["x", "y", "angle", "major", "minor", "area", "id"] + ) class BitmaskEvent(Harp): @@ -319,7 +332,9 @@ class Video(Csv): def __init__(self, pattern): """Initialize the object with a specified pattern.""" - super().__init__(pattern, columns=["hw_counter", "hw_timestamp", "_frame", "_path", "_epoch"]) + super().__init__( + pattern, columns=["hw_counter", "hw_timestamp", "_frame", "_path", "_epoch"] + ) self._rawcolumns = ["time"] + self.columns[0:2] def read(self, file): @@ -333,7 +348,9 @@ def read(self, file): class Pose(Harp): - """Reader for Harp-binarized tracking data given a model + """ + Reader for Harp-binarized tracking data given a model + that outputs id, parts, and likelihoods. Columns: @@ -345,7 +362,9 @@ class (int): Int ID of a subject in the environment. y (float): Y-coordinate of the bodypart. """ - def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/processed"): + def __init__( + self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/processed" + ): """Pose reader constructor.""" # `pattern` for this reader should typically be '_*' super().__init__(pattern, columns=None) @@ -384,10 +403,16 @@ def read(self, file: Path) -> pd.DataFrame: # Drop any repeat parts. unique_parts, unique_idxs = np.unique(parts, return_index=True) repeat_idxs = np.setdiff1d(np.arange(len(parts)), unique_idxs) - if repeat_idxs: # drop x, y, and likelihood cols for repeat parts (skip first 5 cols) + if ( + repeat_idxs + ): # drop x, y, and likelihood cols for repeat parts (skip first 5 cols) init_rep_part_col_idx = (repeat_idxs - 1) * 3 + 5 - rep_part_col_idxs = np.concatenate([np.arange(i, i + 3) for i in init_rep_part_col_idx]) - keep_part_col_idxs = np.setdiff1d(np.arange(len(data.columns)), rep_part_col_idxs) + rep_part_col_idxs = np.concatenate( + [np.arange(i, i + 3) for i in init_rep_part_col_idx] + ) + keep_part_col_idxs = np.setdiff1d( + np.arange(len(data.columns)), rep_part_col_idxs + ) data = data.iloc[:, keep_part_col_idxs] parts = unique_parts @@ -395,18 +420,25 @@ def read(self, file: Path) -> pd.DataFrame: data = self.class_int2str(data, config_file) n_parts = len(parts) part_data_list = [pd.DataFrame()] * n_parts - new_columns = pd.Series(["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]) + new_columns = pd.Series( + ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"] + ) new_data = pd.DataFrame(columns=new_columns) for i, part in enumerate(parts): part_columns = ( - columns[0 : (len(identities) + 1)] if bonsai_sleap_v == BONSAI_SLEAP_V3 else columns[0:2] + columns[0 : (len(identities) + 1)] + if bonsai_sleap_v == BONSAI_SLEAP_V3 + else columns[0:2] ) part_columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"]) part_data = pd.DataFrame(data[part_columns]) if bonsai_sleap_v == BONSAI_SLEAP_V3: # combine all identity_likelihood cols into a single col as dict part_data["identity_likelihood"] = part_data.apply( - lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, + lambda row: { + identity: row[f"{identity}_likelihood"] + for identity in identities + }, axis=1, ) part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True) @@ -471,10 +503,14 @@ def class_int2str(data: pd.DataFrame, config_file: Path) -> pd.DataFrame: return data @classmethod - def get_config_file(cls, config_file_dir: Path, config_file_names: None | list[str] = None) -> Path: + def get_config_file( + cls, config_file_dir: Path, config_file_names: None | list[str] = None + ) -> Path: """Returns the config file from a model's config directory.""" if config_file_names is None: - config_file_names = ["confmap_config.json"] # SLEAP (add for other trackers to this list) + config_file_names = [ + "confmap_config.json" + ] # SLEAP (add for other trackers to this list) config_file = None for f in config_file_names: if (config_file_dir / f).exists(): @@ -493,7 +529,10 @@ def from_dict(data, pattern=None): return globals()[reader_type](pattern=pattern, **kwargs) return DotMap( - {k: from_dict(v, f"{pattern}_{k}" if pattern is not None else k) for k, v in data.items()} + { + k: from_dict(v, f"{pattern}_{k}" if pattern is not None else k) + for k, v in data.items() + } ) @@ -501,7 +540,9 @@ def to_dict(dotmap): """Converts a DotMap object to a dictionary.""" if isinstance(dotmap, Reader): kwargs = { - k: v for k, v in vars(dotmap).items() if k not in ["pattern"] and not k.startswith("_") + k: v + for k, v in vars(dotmap).items() + if k not in ["pattern"] and not k.startswith("_") } kwargs["type"] = type(dotmap).__name__ return kwargs From 4c82c4068108d273398b4fe491ca60146dd5bb4d Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 00:32:50 +0000 Subject: [PATCH 036/143] fix: resolve D202 issues --- aeon/dj_pipeline/analysis/block_analysis.py | 1 - aeon/dj_pipeline/analysis/visit.py | 1 - aeon/dj_pipeline/create_experiments/create_presocial.py | 1 - aeon/dj_pipeline/utils/streams_maker.py | 1 - 4 files changed, 4 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index e272690d..5d2698d5 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -1482,7 +1482,6 @@ class BlockSubjectPositionPlots(dj.Computed): def make(self, key): """Compute and plot various block-level statistics and visualizations.""" - # Get some block info block_start, block_end = (Block & key).fetch1("block_start", "block_end") chunk_restriction = acquisition.create_chunk_restriction( diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index ef223143..8adcde7f 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -135,7 +135,6 @@ def ingest_environment_visits(experiment_names: list | None = None): experiment_names (list, optional): list of names of the experiment to populate into the Visit table. Defaults to None. """ - if experiment_names is None: experiment_names = ["exp0.2-r0"] place_key = {"place": "environment"} diff --git a/aeon/dj_pipeline/create_experiments/create_presocial.py b/aeon/dj_pipeline/create_experiments/create_presocial.py index 19e9cc84..4df4453b 100644 --- a/aeon/dj_pipeline/create_experiments/create_presocial.py +++ b/aeon/dj_pipeline/create_experiments/create_presocial.py @@ -55,7 +55,6 @@ def create_new_experiment(): def main(): """Main function to create a new experiment.""" - create_new_experiment() diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index aab8915e..c745c7e0 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -163,7 +163,6 @@ def key_source(self): + Chunk(s) that started after device_type install time for device_type that are not yet removed """ - key_source_query = ( acquisition.Chunk * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) From 7ddb5cdc1db71c3e9e9da1fcf5c35d6d4d92f97a Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 00:37:43 +0000 Subject: [PATCH 037/143] fix: resolve F403 unused import * --- aeon/dj_pipeline/analysis/visit_analysis.py | 1 - aeon/dj_pipeline/report.py | 1 - 2 files changed, 2 deletions(-) diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 05205d1b..aa3a2f8b 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -19,7 +19,6 @@ # schema = dj.schema(get_schema_name("analysis")) schema = dj.schema() - # ---------- Position Filtering Method ------------------ diff --git a/aeon/dj_pipeline/report.py b/aeon/dj_pipeline/report.py index 22b11c95..e2c444c6 100644 --- a/aeon/dj_pipeline/report.py +++ b/aeon/dj_pipeline/report.py @@ -12,7 +12,6 @@ from aeon.analysis import plotting as analysis_plotting from aeon.dj_pipeline.analysis.visit import Visit, VisitEnd -from aeon.dj_pipeline.analysis.visit_analysis import * from . import acquisition, analysis From 8b58eeeb0674cd82eabcdd1c14f71fe56d4fd5aa Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 00:50:12 +0000 Subject: [PATCH 038/143] fix: Resolve PLR2004 issues Replace magic values by constant values --- aeon/dj_pipeline/analysis/block_analysis.py | 29 +++++++++++++++------ aeon/dj_pipeline/analysis/visit_analysis.py | 8 ++++-- aeon/dj_pipeline/subject.py | 4 ++- aeon/dj_pipeline/tracking.py | 6 +++-- 4 files changed, 34 insertions(+), 13 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 5d2698d5..fe2e2f3a 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -639,7 +639,10 @@ def make(self, key): all_subj_patch_pref_dict[patch_name][subject_name]["cum_dist"] / all_cum_dist ) - cum_pref_dist = np.where(cum_pref_dist < 1e-3, 0, cum_pref_dist) + CUM_PREF_DIST_MIN = 1e-3 + cum_pref_dist = np.where( + cum_pref_dist < CUM_PREF_DIST_MIN, 0, cum_pref_dist + ) all_subj_patch_pref_dict[patch_name][subject_name][ "cum_pref_dist" ] = cum_pref_dist @@ -1853,19 +1856,29 @@ def get_threshold_associated_pellets(patch_key, start, end): ) # Step 2 - Remove invalid rows (back-to-back events) - # pellet delivery trigger - time difference is less than 1.2 seconds - invalid_rows = delivered_pellet_df.index.to_series().diff().dt.total_seconds() < 1.2 + BTB_TIME_DIFF = ( + 1.2 # pellet delivery trigger - time difference is less than 1.2 seconds + ) + invalid_rows = ( + delivered_pellet_df.index.to_series().diff().dt.total_seconds() < BTB_TIME_DIFF + ) delivered_pellet_df = delivered_pellet_df[~invalid_rows] # exclude manual deliveries delivered_pellet_df = delivered_pellet_df.loc[ delivered_pellet_df.index.difference(manual_delivery_df.index) ] - # beambreak - time difference is less than 1 seconds - invalid_rows = beambreak_df.index.to_series().diff().dt.total_seconds() < 1 + + BB_TIME_DIFF = 1.0 # beambreak - time difference is less than 1 seconds + invalid_rows = ( + beambreak_df.index.to_series().diff().dt.total_seconds() < BB_TIME_DIFF + ) beambreak_df = beambreak_df[~invalid_rows] - # patch threshold - time difference is less than 1 seconds + + PT_TIME_DIFF = 1.0 # patch threshold - time difference is less than 1 seconds depletion_state_df = depletion_state_df.dropna(subset=["threshold"]) - invalid_rows = depletion_state_df.index.to_series().diff().dt.total_seconds() < 1 + invalid_rows = ( + depletion_state_df.index.to_series().diff().dt.total_seconds() < PT_TIME_DIFF + ) depletion_state_df = depletion_state_df[~invalid_rows] # Return empty if no data @@ -1882,7 +1895,7 @@ def get_threshold_associated_pellets(patch_key, start, end): beambreak_df.reset_index().rename(columns={"time": "beam_break_timestamp"}), left_on="time", right_on="beam_break_timestamp", - tolerance=pd.Timedelta("1.2s"), + tolerance=pd.Timedelta("{BTB_TIME_DIFF}s"), direction="forward", ) .set_index("time") diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index aa3a2f8b..b0406b95 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -19,6 +19,10 @@ # schema = dj.schema(get_schema_name("analysis")) schema = dj.schema() +# Constants values +MIN_AREA = 0 +MAX_AREA = 1000 + # ---------- Position Filtering Method ------------------ @@ -315,7 +319,7 @@ def make(self, key): ) # filter for objects of the correct size - valid_position = (position.area > 0) & (position.area < 1000) + valid_position = (position.area > MIN_AREA) & (position.area < MAX_AREA) position[~valid_position] = np.nan position.rename( columns={"position_x": "x", "position_y": "y"}, inplace=True @@ -471,7 +475,7 @@ def make(self, key): position, maintenance_period, day_end ) # filter for objects of the correct size - valid_position = (position.area > 0) & (position.area < 1000) + valid_position = (position.area > MIN_AREA) & (position.area < MAX_AREA) position[~valid_position] = np.nan position.rename( columns={"position_x": "x", "position_y": "y"}, inplace=True diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index cfcbd475..6bb69214 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -497,7 +497,9 @@ def get_pyrat_data(endpoint: str, params: dict = None, **kwargs): response = session.get(base_url + endpoint + params_str, **kwargs) - if response.status_code != 200: + RESPONSE_STATUS_CODE_OK = 200 + + if response.status_code != RESPONSE_STATUS_CODE_OK: raise requests.exceptions.HTTPError( f"PyRat API errored out with response code: {response.status_code}" ) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index c3af9474..de0759c7 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -257,11 +257,13 @@ def make(self, key): # ---------- HELPER ------------------ +TARGET_LENGTH = 2 + def compute_distance(position_df, target, xcol="x", ycol="y"): """Compute the distance of the position data from a target point.""" - if len(target) != 2: - raise ValueError("Target must be a list of tuple of length 2.") + if len(target) != TARGET_LENGTH: + raise ValueError(f"Target must be a list of tuple of length {TARGET_LENGTH}.") return np.sqrt(np.square(position_df[[xcol, ycol]] - target).sum(axis=1)) From 4f66d8cc1a86a8da35d897c5359aa11fc6c8eed9 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 00:57:41 +0000 Subject: [PATCH 039/143] fix: resolve SIM108 issue by refactoring code --- aeon/dj_pipeline/analysis/visit_analysis.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index b0406b95..5c3ea591 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -114,13 +114,11 @@ def make(self, key): ) # -- Determine the time to start time_slicing in this chunk - if chunk_start < key["visit_start"] < chunk_end: - # For chunk containing the visit_start - i.e. first chunk of this visit - start_time = key["visit_start"] - else: - # For chunks after the first chunk of this visit - start_time = chunk_start - + start_time = ( + key["visit_start"] + if chunk_start < key["visit_start"] < chunk_end + else chunk_start # Use chunk_start if the visit has not yet started. + ) # -- Determine the time to end time_slicing in this chunk if VisitEnd & key: # finished visit visit_end = (VisitEnd & key).fetch1("visit_end") From 3303c582d9176503f53c811862304cdb193f839a Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 01:02:16 +0000 Subject: [PATCH 040/143] fix: resolve I001 issues --- aeon/dj_pipeline/analysis/visit.py | 9 ++++----- aeon/dj_pipeline/analysis/visit_analysis.py | 2 +- .../create_experiments/create_socialexperiment.py | 2 +- aeon/dj_pipeline/populate/process.py | 4 ++-- aeon/dj_pipeline/populate/worker.py | 3 +-- aeon/dj_pipeline/qc.py | 4 +--- 6 files changed, 10 insertions(+), 14 deletions(-) diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index 8adcde7f..64275bf8 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -1,15 +1,14 @@ """Module for visit-related tables in the analysis schema.""" import datetime +from collections import deque + import datajoint as dj -import pandas as pd import numpy as np -from collections import deque +import pandas as pd from aeon.analysis import utils as analysis_utils - -from aeon.dj_pipeline import get_schema_name, fetch_stream -from aeon.dj_pipeline import acquisition +from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name schema = dj.schema(get_schema_name("analysis")) diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 5c3ea591..bd7a07a1 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -11,8 +11,8 @@ from aeon.dj_pipeline.analysis.visit import ( Visit, VisitEnd, - get_maintenance_periods, filter_out_maintenance_periods, + get_maintenance_periods, ) logger = dj.logger diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py index 581a8999..57540085 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py @@ -1,10 +1,10 @@ """Function to create new social experiments""" from datetime import datetime + from aeon.dj_pipeline import acquisition from aeon.dj_pipeline.utils.paths import get_repository_path - # ---- Programmatic creation of a new social experiment ---- # Infer experiment metadata from the experiment name # User-specified "experiment_name" (everything else should be automatically inferred) diff --git a/aeon/dj_pipeline/populate/process.py b/aeon/dj_pipeline/populate/process.py index d49fead5..d9dbd808 100644 --- a/aeon/dj_pipeline/populate/process.py +++ b/aeon/dj_pipeline/populate/process.py @@ -38,10 +38,10 @@ from aeon.dj_pipeline.populate.worker import ( acquisition_worker, - logger, analysis_worker, - streams_worker, + logger, pyrat_worker, + streams_worker, ) # ---- some wrappers to support execution as script or CLI diff --git a/aeon/dj_pipeline/populate/worker.py b/aeon/dj_pipeline/populate/worker.py index a4605273..f7232693 100644 --- a/aeon/dj_pipeline/populate/worker.py +++ b/aeon/dj_pipeline/populate/worker.py @@ -10,8 +10,7 @@ ) from datajoint_utilities.dj_worker.worker_schema import is_djtable -from aeon.dj_pipeline import db_prefix -from aeon.dj_pipeline import subject, acquisition, tracking, qc +from aeon.dj_pipeline import acquisition, db_prefix, qc, subject, tracking from aeon.dj_pipeline.analysis import block_analysis from aeon.dj_pipeline.utils import streams_maker diff --git a/aeon/dj_pipeline/qc.py b/aeon/dj_pipeline/qc.py index cc665885..cc3b23b3 100644 --- a/aeon/dj_pipeline/qc.py +++ b/aeon/dj_pipeline/qc.py @@ -4,11 +4,9 @@ import numpy as np import pandas as pd +from aeon.dj_pipeline import acquisition, get_schema_name, streams from aeon.io import api as io_api -from aeon.dj_pipeline import get_schema_name -from aeon.dj_pipeline import acquisition, streams - schema = dj.schema(get_schema_name("qc")) logger = dj.logger From d44fed18bb407fc44ea35030c9958b366f0edc68 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 01:04:50 +0000 Subject: [PATCH 041/143] fix: another round of global ruff check --- aeon/__init__.py | 2 +- aeon/analysis/__init__.py | 2 +- aeon/analysis/block_plotting.py | 4 +--- aeon/analysis/utils.py | 2 +- aeon/dj_pipeline/__init__.py | 2 +- aeon/dj_pipeline/acquisition.py | 2 +- aeon/dj_pipeline/analysis/__init__.py | 2 +- aeon/dj_pipeline/analysis/block_analysis.py | 15 +++++---------- aeon/dj_pipeline/analysis/visit.py | 3 +-- aeon/dj_pipeline/analysis/visit_analysis.py | 9 +++------ aeon/dj_pipeline/populate/__init__.py | 2 +- aeon/dj_pipeline/populate/worker.py | 3 +-- aeon/dj_pipeline/report.py | 11 ++++------- .../scripts/clone_and_freeze_exp01.py | 3 +-- .../scripts/clone_and_freeze_exp02.py | 3 +-- .../scripts/update_timestamps_longblob.py | 3 +-- aeon/dj_pipeline/subject.py | 18 +++++++++--------- aeon/dj_pipeline/tracking.py | 6 ++---- aeon/dj_pipeline/utils/__init__.py | 2 +- aeon/dj_pipeline/utils/load_metadata.py | 10 +++------- aeon/dj_pipeline/utils/paths.py | 3 +-- aeon/dj_pipeline/utils/streams_maker.py | 6 ++---- aeon/io/__init__.py | 2 +- aeon/io/device.py | 2 +- aeon/io/reader.py | 10 ++++------ aeon/schema/__init__.py | 2 +- aeon/schema/dataset.py | 2 +- aeon/schema/octagon.py | 2 +- aeon/schema/schemas.py | 2 +- aeon/schema/social_01.py | 2 +- aeon/schema/social_02.py | 2 +- aeon/schema/social_03.py | 2 +- aeon/schema/streams.py | 2 +- 33 files changed, 57 insertions(+), 86 deletions(-) diff --git a/aeon/__init__.py b/aeon/__init__.py index 0dc5ee9d..48a87f97 100644 --- a/aeon/__init__.py +++ b/aeon/__init__.py @@ -1,4 +1,4 @@ -""" Top-level package for aeon. """ +"""Top-level package for aeon.""" from importlib.metadata import PackageNotFoundError, version diff --git a/aeon/analysis/__init__.py b/aeon/analysis/__init__.py index b48aecd3..52f0038f 100644 --- a/aeon/analysis/__init__.py +++ b/aeon/analysis/__init__.py @@ -1 +1 @@ -""" Utilities for analyzing data. """ +"""Utilities for analyzing data.""" diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index 5b04977d..92a2dca1 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -61,9 +61,7 @@ def gen_subject_colors_dict(subject_names): def gen_patch_style_dict(patch_names): - """ - - Based on a list of patches, generates a dictionary of: + """Based on a list of patches, generates a dictionary of: - patch_colors_dict: patch name to color - patch_markers_dict: patch name to marker diff --git a/aeon/analysis/utils.py b/aeon/analysis/utils.py index c8950151..9f0b08e6 100644 --- a/aeon/analysis/utils.py +++ b/aeon/analysis/utils.py @@ -1,4 +1,4 @@ -""" Helper functions for data analysis and visualization.""" +"""Helper functions for data analysis and visualization.""" import numpy as np import pandas as pd diff --git a/aeon/dj_pipeline/__init__.py b/aeon/dj_pipeline/__init__.py index 0cfeef23..c8b0b1c5 100644 --- a/aeon/dj_pipeline/__init__.py +++ b/aeon/dj_pipeline/__init__.py @@ -1,4 +1,4 @@ -""" DataJoint pipeline for Aeon. """ +"""DataJoint pipeline for Aeon.""" import hashlib import os diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py index 69e5a142..3d5bf577 100644 --- a/aeon/dj_pipeline/acquisition.py +++ b/aeon/dj_pipeline/acquisition.py @@ -1,4 +1,4 @@ -""" DataJoint schema for the acquisition pipeline. """ +"""DataJoint schema for the acquisition pipeline.""" import datetime import json diff --git a/aeon/dj_pipeline/analysis/__init__.py b/aeon/dj_pipeline/analysis/__init__.py index b48aecd3..52f0038f 100644 --- a/aeon/dj_pipeline/analysis/__init__.py +++ b/aeon/dj_pipeline/analysis/__init__.py @@ -1 +1 @@ -""" Utilities for analyzing data. """ +"""Utilities for analyzing data.""" diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index fe2e2f3a..068e2ddb 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -3,7 +3,7 @@ import itertools import json from collections import defaultdict -from datetime import datetime, timezone +from datetime import UTC, datetime import datajoint as dj import numpy as np @@ -137,9 +137,7 @@ class BlockAnalysis(dj.Computed): @property def key_source(self): - """ - - Ensure that the chunk ingestion has caught up with this block before processing + """Ensure that the chunk ingestion has caught up with this block before processing (there exists a chunk that ends after the block end time) """ @@ -177,9 +175,7 @@ class Subject(dj.Part): """ def make(self, key): - """ - - Restrict, fetch and aggregate data from different streams to + """Restrict, fetch and aggregate data from different streams to produce intermediate data products at a per-block level @@ -278,7 +274,7 @@ def make(self, key): # log a note and pick the first rate to move forward AnalysisNote.insert1( { - "note_timestamp": datetime.now(timezone.utc), + "note_timestamp": datetime.now(UTC), "note_type": "Multiple patch rates", "note": ( f"Found multiple patch rates for block {key} " @@ -1795,8 +1791,7 @@ class AnalysisNote(dj.Manual): def get_threshold_associated_pellets(patch_key, start, end): - """ - Retrieve the pellet delivery timestamps associated with each patch threshold update + """Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time. diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index 64275bf8..41d39060 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -121,8 +121,7 @@ def make(self, key): def ingest_environment_visits(experiment_names: list | None = None): - """ - Function to populate into `Visit` and `VisitEnd` for specified + """Function to populate into `Visit` and `VisitEnd` for specified experiments (default: 'exp0.2-r0'). diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index bd7a07a1..3028fc10 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -87,8 +87,7 @@ class TimeSlice(dj.Part): @property def key_source(self): - """ - Chunk for all visits: + """Chunk for all visits: + visit_start during this Chunk - i.e. first chunk of the visit + visit_end during this Chunk - i.e. last chunk of the visit @@ -204,8 +203,7 @@ def make(self, key): @classmethod def get_position(cls, visit_key=None, subject=None, start=None, end=None): - """ - Return a Pandas df of the subject's position data for a specified Visit given its key. + """Return a Pandas df of the subject's position data for a specified Visit given its key. Given a key to a single Visit, return a Pandas DataFrame for the position data of the subject for the specified Visit time period. @@ -566,8 +564,7 @@ def make(self, key): @schema class VisitForagingBout(dj.Computed): - """ - Time period from when the animal enters to when it leaves a food patch while moving the wheel. + """Time period from when the animal enters to when it leaves a food patch while moving the wheel. """ definition = """ diff --git a/aeon/dj_pipeline/populate/__init__.py b/aeon/dj_pipeline/populate/__init__.py index ca091e15..7178aec4 100644 --- a/aeon/dj_pipeline/populate/__init__.py +++ b/aeon/dj_pipeline/populate/__init__.py @@ -1 +1 @@ -""" Utilities for the workflow orchestration. """ +"""Utilities for the workflow orchestration.""" diff --git a/aeon/dj_pipeline/populate/worker.py b/aeon/dj_pipeline/populate/worker.py index f7232693..1a1ff302 100644 --- a/aeon/dj_pipeline/populate/worker.py +++ b/aeon/dj_pipeline/populate/worker.py @@ -1,5 +1,4 @@ -""" -This module defines the workers for the AEON pipeline. +"""This module defines the workers for the AEON pipeline. """ import datajoint as dj diff --git a/aeon/dj_pipeline/report.py b/aeon/dj_pipeline/report.py index e2c444c6..92c210e3 100644 --- a/aeon/dj_pipeline/report.py +++ b/aeon/dj_pipeline/report.py @@ -1,4 +1,4 @@ -"""DataJoint schema dedicated for tables containing figures. """ +"""DataJoint schema dedicated for tables containing figures.""" import datetime import json @@ -317,8 +317,7 @@ def make(self, key): @classmethod def delete_outdated_entries(cls): - """ - Dynamically update the plot for all sessions. + """Dynamically update the plot for all sessions. Each entry in this table correspond to one subject. However, the plot is capturing data for all sessions. @@ -371,8 +370,7 @@ def make(self, key): @classmethod def delete_outdated_entries(cls): - """ - Dynamically update the plot for all sessions. + """Dynamically update the plot for all sessions. Each entry in this table correspond to one subject. However the plot is capturing data for all sessions. @@ -422,8 +420,7 @@ def make(self, key): @classmethod def delete_outdated_entries(cls): - """ - Dynamically update the plot for all sessions. + """Dynamically update the plot for all sessions. Each entry in this table correspond to one subject. However the plot is capturing data for all sessions. diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py index 3859c51e..f1645ab3 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py @@ -1,5 +1,4 @@ -""" -March 2022 +"""March 2022 Cloning and archiving schemas and data for experiment 0.1. """ diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py index bc8494a7..87bf8f31 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py @@ -1,5 +1,4 @@ -""" -Jan 2024 +"""Jan 2024 Cloning and archiving schemas and data for experiment 0.2. diff --git a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py index 0946dff9..370929c5 100644 --- a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py +++ b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py @@ -1,5 +1,4 @@ -""" -July 2022 +"""July 2022 Upgrade all timestamps longblob fields with datajoint 0.13.7. """ diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index 6bb69214..78d4bb21 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -3,7 +3,7 @@ import json import os import time -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta import datajoint as dj import requests @@ -195,7 +195,7 @@ def get_reference_weight(cls, subject_name): "procedure_date", order_by="procedure_date DESC", limit=1 )[0] else: - ref_date = datetime.now(timezone.utc).date() + ref_date = datetime.now(UTC).date() weight_query = SubjectWeight & subj_key & f"weight_time < '{ref_date}'" ref_weight = ( @@ -207,7 +207,7 @@ def get_reference_weight(cls, subject_name): entry = { "subject": subject_name, "reference_weight": ref_weight, - "last_updated_time": datetime.now(timezone.utc), + "last_updated_time": datetime.now(UTC), } cls.update1(entry) if cls & {"subject": subject_name} else cls.insert1(entry) @@ -250,7 +250,7 @@ class PyratIngestion(dj.Imported): def _auto_schedule(self): """Automatically schedule the next task.""" - utc_now = datetime.now(timezone.utc) + utc_now = datetime.now(UTC) next_task_schedule_time = utc_now + timedelta(hours=self.schedule_interval) if ( @@ -265,7 +265,7 @@ def _auto_schedule(self): def make(self, key): """Automatically import or update entries in the Subject table.""" - execution_time = datetime.now(timezone.utc) + execution_time = datetime.now(UTC) new_eartags = [] for responsible_id in lab.User.fetch("responsible_id"): # 1 - retrieve all animals from this user @@ -301,7 +301,7 @@ def make(self, key): new_entry_count += 1 logger.info(f"Inserting {new_entry_count} new subject(s) from Pyrat") - completion_time = datetime.now(timezone.utc) + completion_time = datetime.now(UTC) self.insert1( { **key, @@ -334,7 +334,7 @@ class PyratCommentWeightProcedure(dj.Imported): def make(self, key): """Automatically import or update entries in the PyratCommentWeightProcedure table.""" - execution_time = datetime.now(timezone.utc) + execution_time = datetime.now(UTC) logger.info("Extracting weights/comments/procedures") eartag_or_id = key["subject"] @@ -377,7 +377,7 @@ def make(self, key): # compute/update reference weight SubjectReferenceWeight.get_reference_weight(eartag_or_id) finally: - completion_time = datetime.now(timezone.utc) + completion_time = datetime.now(UTC) self.insert1( { **key, @@ -398,7 +398,7 @@ class CreatePyratIngestionTask(dj.Computed): def make(self, key): """Create one new PyratIngestionTask for every newly added users.""" PyratIngestionTask.insert1( - {"pyrat_task_scheduled_time": datetime.now(timezone.utc)} + {"pyrat_task_scheduled_time": datetime.now(UTC)} ) time.sleep(1) self.insert1(key) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index de0759c7..f6d298eb 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -120,8 +120,7 @@ def insert_new_params( @schema class SLEAPTracking(dj.Imported): - """ - Tracking data from SLEAP for multi-animal experiments. + """Tracking data from SLEAP for multi-animal experiments. Tracked objects position data from a particular VideoSource for multi-animal experiment using the SLEAP tracking @@ -282,8 +281,7 @@ def is_position_in_patch( def is_position_in_nest(position_df, nest_key, xcol="x", ycol="y") -> pd.Series: - """ - Check if a position is inside the nest. + """Check if a position is inside the nest. Notes: Given the session key and the position data - arrays of x and y return an array of boolean indicating whether or not a position is inside the nest. diff --git a/aeon/dj_pipeline/utils/__init__.py b/aeon/dj_pipeline/utils/__init__.py index 82bbb4bd..0bb46925 100644 --- a/aeon/dj_pipeline/utils/__init__.py +++ b/aeon/dj_pipeline/utils/__init__.py @@ -1 +1 @@ -""" Helper functions and utilities for the Aeon project. """ +"""Helper functions and utilities for the Aeon project.""" diff --git a/aeon/dj_pipeline/utils/load_metadata.py b/aeon/dj_pipeline/utils/load_metadata.py index 2046c072..9d055fb7 100644 --- a/aeon/dj_pipeline/utils/load_metadata.py +++ b/aeon/dj_pipeline/utils/load_metadata.py @@ -1,5 +1,4 @@ -""" -Load metadata from the experiment and insert into streams schema. +"""Load metadata from the experiment and insert into streams schema. """ import datetime @@ -50,9 +49,7 @@ def insert_stream_types(): def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): - """ - - Insert device types into streams.DeviceType and streams.Device. + """Insert device types into streams.DeviceType and streams.Device. Notes: Use aeon.schema.schemas and metadata.yml to insert into streams.DeviceType and streams.Device. Only insert device types that were defined both in the device schema (e.g., exp02) and Metadata.yml. @@ -474,8 +471,7 @@ def _get_class_path(obj): def get_device_mapper(devices_schema: DotMap, metadata_yml_filepath: Path): - """ - Returns a mapping dictionary of device names to types based on the dataset schema and metadata.yml. + """Returns a mapping dictionary of device names to types based on the dataset schema and metadata.yml. Notes: Returns a mapping dictionary between device name and device type based on the dataset schema and metadata.yml from the experiment. diff --git a/aeon/dj_pipeline/utils/paths.py b/aeon/dj_pipeline/utils/paths.py index d2f9bb75..d22c00cb 100644 --- a/aeon/dj_pipeline/utils/paths.py +++ b/aeon/dj_pipeline/utils/paths.py @@ -1,5 +1,4 @@ -""" -Utility functions for working with paths in the context of the DJ pipeline. +"""Utility functions for working with paths in the context of the DJ pipeline. """ from __future__ import annotations diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index c745c7e0..792f22ca 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -23,8 +23,7 @@ class StreamType(dj.Lookup): - """ - Catalog of all steam types for the different device types used across + """Catalog of all steam types for the different device types used across Project Aeon. One StreamType corresponds to one reader class in `aeon.io.reader`. @@ -156,8 +155,7 @@ class DeviceDataStream(dj.Imported): @property def key_source(self): - """ - Only the combination of Chunk and device_type with overlapping time + """Only the combination of Chunk and device_type with overlapping time + Chunk(s) that started after device_type install time and ended before device_type remove time diff --git a/aeon/io/__init__.py b/aeon/io/__init__.py index f481ec8e..e23efc36 100644 --- a/aeon/io/__init__.py +++ b/aeon/io/__init__.py @@ -1 +1 @@ -""" Utilities for I/O operations. """ +"""Utilities for I/O operations.""" diff --git a/aeon/io/device.py b/aeon/io/device.py index 56d3dacb..8e473662 100644 --- a/aeon/io/device.py +++ b/aeon/io/device.py @@ -1,4 +1,4 @@ -""" Deprecated Device class for grouping multiple Readers into a logical device. """ +"""Deprecated Device class for grouping multiple Readers into a logical device.""" import inspect diff --git a/aeon/io/reader.py b/aeon/io/reader.py index 4d598db5..999f2acb 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -1,4 +1,4 @@ -""" Module for reading data from raw files in an Aeon dataset.""" +"""Module for reading data from raw files in an Aeon dataset.""" from __future__ import annotations @@ -159,8 +159,7 @@ def read(self, file): class JsonList(Reader): - """ - Extracts data from json list (.jsonl) files, + """Extracts data from json list (.jsonl) files, where the key "seconds" stores the Aeon timestamp, in seconds. """ @@ -173,7 +172,7 @@ def __init__(self, pattern, columns=(), root_key="value", extension="jsonl"): def read(self, file): """Reads data from the specified jsonl file.""" - with open(file, "r") as f: + with open(file) as f: df = pd.read_json(f, lines=True) df.set_index("seconds", inplace=True) for column in self.columns: @@ -348,8 +347,7 @@ def read(self, file): class Pose(Harp): - """ - Reader for Harp-binarized tracking data given a model + """Reader for Harp-binarized tracking data given a model that outputs id, parts, and likelihoods. diff --git a/aeon/schema/__init__.py b/aeon/schema/__init__.py index 3de266c2..bbce21ee 100644 --- a/aeon/schema/__init__.py +++ b/aeon/schema/__init__.py @@ -1 +1 @@ -""" Utilities for the schemas. """ +"""Utilities for the schemas.""" diff --git a/aeon/schema/dataset.py b/aeon/schema/dataset.py index 225cf335..b187eb8b 100644 --- a/aeon/schema/dataset.py +++ b/aeon/schema/dataset.py @@ -1,4 +1,4 @@ -""" Dataset schema definitions. """ +"""Dataset schema definitions.""" from dotmap import DotMap diff --git a/aeon/schema/octagon.py b/aeon/schema/octagon.py index 351bef31..ac121dbe 100644 --- a/aeon/schema/octagon.py +++ b/aeon/schema/octagon.py @@ -1,4 +1,4 @@ -""" Octagon schema definition. """ +"""Octagon schema definition.""" import aeon.io.reader as _reader from aeon.schema.streams import Stream, StreamGroup diff --git a/aeon/schema/schemas.py b/aeon/schema/schemas.py index 06f8598c..1d37b1d1 100644 --- a/aeon/schema/schemas.py +++ b/aeon/schema/schemas.py @@ -1,4 +1,4 @@ -""" Schemas for different experiments. """ +"""Schemas for different experiments.""" from dotmap import DotMap diff --git a/aeon/schema/social_01.py b/aeon/schema/social_01.py index 4edaec9f..f7f44b80 100644 --- a/aeon/schema/social_01.py +++ b/aeon/schema/social_01.py @@ -1,4 +1,4 @@ -""" This module contains the schema for the social_01 dataset. """ +"""This module contains the schema for the social_01 dataset.""" import aeon.io.reader as _reader from aeon.schema.streams import Stream diff --git a/aeon/schema/social_02.py b/aeon/schema/social_02.py index 0564599f..c3b64f8d 100644 --- a/aeon/schema/social_02.py +++ b/aeon/schema/social_02.py @@ -1,4 +1,4 @@ -""" This module defines the schema for the social_02 dataset. """ +"""This module defines the schema for the social_02 dataset.""" import aeon.io.reader as _reader from aeon.schema import core, foraging diff --git a/aeon/schema/social_03.py b/aeon/schema/social_03.py index 6206f0f9..e1f624f5 100644 --- a/aeon/schema/social_03.py +++ b/aeon/schema/social_03.py @@ -1,4 +1,4 @@ -""" This module contains the schema for the social_03 dataset. """ +"""This module contains the schema for the social_03 dataset.""" import aeon.io.reader as _reader from aeon.schema.streams import Stream diff --git a/aeon/schema/streams.py b/aeon/schema/streams.py index 2c1cb94a..0269a2f6 100644 --- a/aeon/schema/streams.py +++ b/aeon/schema/streams.py @@ -1,4 +1,4 @@ -""" Contains classes for defining data streams and devices. """ +"""Contains classes for defining data streams and devices.""" import inspect from itertools import chain From 71ac75fc0dde26bf720cae7a9440c73d3579fc35 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 01:28:14 +0000 Subject: [PATCH 042/143] fix: another round: all checks passed for `aeon/analysis` --- aeon/analysis/block_plotting.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index 92a2dca1..33c52663 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -29,13 +29,15 @@ def gen_hex_grad(hex_col, vals, min_l=0.3): """Generates an array of hex color values based on a gradient defined by unit-normalized values.""" # Convert hex to rgb to hls - h, l, s = rgb_to_hls( + h, ll, s = rgb_to_hls( *[int(hex_col.lstrip("#")[i : i + 2], 16) / 255 for i in (0, 2, 4)] - ) # noqa: E741 + ) grad = np.empty(shape=(len(vals),), dtype=" Date: Tue, 29 Oct 2024 01:53:22 +0000 Subject: [PATCH 043/143] fix: another round: all checks passed for `aeon/dj_pipeline` --- aeon/dj_pipeline/analysis/block_analysis.py | 22 +++------ aeon/dj_pipeline/analysis/visit.py | 4 +- aeon/dj_pipeline/analysis/visit_analysis.py | 20 ++++---- .../create_experiment_02.py | 9 ++-- .../create_experiments/create_octagon_1.py | 9 ++-- .../create_experiments/create_presocial.py | 12 +++-- .../create_socialexperiment.py | 12 +++-- .../create_socialexperiment_0.py | 4 +- aeon/dj_pipeline/lab.py | 2 +- aeon/dj_pipeline/populate/process.py | 6 +-- aeon/dj_pipeline/populate/worker.py | 3 +- aeon/dj_pipeline/qc.py | 4 +- aeon/dj_pipeline/report.py | 8 +-- .../scripts/clone_and_freeze_exp01.py | 7 +-- .../scripts/clone_and_freeze_exp02.py | 6 +-- .../scripts/update_timestamps_longblob.py | 9 ++-- aeon/dj_pipeline/utils/load_metadata.py | 11 +++-- aeon/dj_pipeline/utils/paths.py | 9 ++-- aeon/dj_pipeline/utils/plotting.py | 10 ++-- aeon/dj_pipeline/utils/streams_maker.py | 49 +++++++------------ 20 files changed, 94 insertions(+), 122 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 068e2ddb..bcb89e4d 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -137,10 +137,7 @@ class BlockAnalysis(dj.Computed): @property def key_source(self): - """Ensure that the chunk ingestion has caught up with this block before processing - - (there exists a chunk that ends after the block end time) - """ + """Ensure that the chunk ingestion has caught up with this block before processing (there exists a chunk that ends after the block end time).""" # noqa 501 ks = Block.aggr(acquisition.Chunk, latest_chunk_end="MAX(chunk_end)") ks = ks * Block & "latest_chunk_end >= block_end" & "block_end IS NOT NULL" return ks @@ -175,17 +172,14 @@ class Subject(dj.Part): """ def make(self, key): - """Restrict, fetch and aggregate data from different streams to - - produce intermediate data products at a per-block level - - (for different patches and different subjects). + """ + Restrict, fetch and aggregate data from different streams to produce intermediate data products at a per-block level (for different patches and different subjects). 1. Query data for all chunks within the block. 2. Fetch streams, filter by maintenance period. 3. Fetch subject position data (SLEAP). 4. Aggregate and insert into the table. - """ + """ # noqa 501 block_start, block_end = (Block & key).fetch1("block_start", "block_end") chunk_restriction = acquisition.create_chunk_restriction( @@ -1791,12 +1785,8 @@ class AnalysisNote(dj.Manual): def get_threshold_associated_pellets(patch_key, start, end): - """Retrieve the pellet delivery timestamps associated with each patch threshold update - - within the specified start-end time. + """Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time. - - Notes: 1. Get all patch state update timestamps (DepletionState): let's call these events "A" - Remove all events within 1 second of each other - Remove all events without threshold value (NaN) @@ -1821,7 +1811,7 @@ def get_threshold_associated_pellets(patch_key, start, end): - beam_break_timestamp - offset - rate - """ + """ # noqa 501 chunk_restriction = acquisition.create_chunk_restriction( patch_key["experiment_name"], start, end ) diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index 41d39060..5266c5a9 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -121,9 +121,7 @@ def make(self, key): def ingest_environment_visits(experiment_names: list | None = None): - """Function to populate into `Visit` and `VisitEnd` for specified - - experiments (default: 'exp0.2-r0'). + """Function to populate into `Visit` and `VisitEnd` for specified experiments (default: 'exp0.2-r0'). This ingestion routine handles only those "complete" visits, not ingesting any "on-going" visits using "analyze" method: diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 3028fc10..b03ac593 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -70,8 +70,7 @@ class VisitSubjectPosition(dj.Computed): """ class TimeSlice(dj.Part): - definition = """ - # A short time-slice (e.g. 10 minutes) of the recording of a given animal for a visit + definition = """ # A short time-slice (e.g. 10min) of the recording of a given animal for a visit -> master time_slice_start: datetime(6) # datetime of the start of this time slice --- @@ -87,7 +86,7 @@ class TimeSlice(dj.Part): @property def key_source(self): - """Chunk for all visits: + """Chunk for all visits as the following conditions. + visit_start during this Chunk - i.e. first chunk of the visit + visit_end during this Chunk - i.e. last chunk of the visit @@ -107,7 +106,7 @@ def key_source(self): ) def make(self, key): - """Populate VisitSubjectPosition for each visit""" + """Populate VisitSubjectPosition for each visit.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -221,9 +220,9 @@ def get_position(cls, visit_key=None, subject=None, start=None, end=None): ).fetch1("visit_start", "visit_end") subject = visit_key["subject"] elif all((subject, start, end)): - start = start - end = end - subject = subject + start = start # noqa PLW0127 + end = end # noqa PLW0127 + subject = subject # noqa PLW0127 else: raise ValueError( 'Either "visit_key" or all three "subject", "start" and "end" has to be specified' @@ -283,7 +282,7 @@ class FoodPatch(dj.Part): ) def make(self, key): - """Populate VisitTimeDistribution for each visit""" + """Populate VisitTimeDistribution for each visit.""" visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) @@ -439,7 +438,7 @@ class FoodPatch(dj.Part): ) def make(self, key): - """Populate VisitSummary for each visit""" + """Populate VisitSummary for each visit.""" visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) @@ -564,8 +563,7 @@ def make(self, key): @schema class VisitForagingBout(dj.Computed): - """Time period from when the animal enters to when it leaves a food patch while moving the wheel. - """ + """Time period from when the animal enters to when it leaves a food patch while moving the wheel.""" definition = """ -> Visit diff --git a/aeon/dj_pipeline/create_experiments/create_experiment_02.py b/aeon/dj_pipeline/create_experiments/create_experiment_02.py index c5aead5b..f14b0342 100644 --- a/aeon/dj_pipeline/create_experiments/create_experiment_02.py +++ b/aeon/dj_pipeline/create_experiments/create_experiment_02.py @@ -1,4 +1,4 @@ -"""Function to create new experiments for experiment0.2""" +"""Function to create new experiments for experiment0.2.""" from aeon.dj_pipeline import acquisition, lab, subject @@ -8,7 +8,7 @@ def create_new_experiment(): - """Create new experiment for experiment0.2""" + """Create new experiment for experiment0.2.""" # ---------------- Subject ----------------- subject_list = [ {"subject": "BAA-1100699", "sex": "U", "subject_birth_date": "2021-01-01"}, @@ -33,7 +33,10 @@ def create_new_experiment(): skip_duplicates=True, ) acquisition.Experiment.Subject.insert( - [{"experiment_name": experiment_name, "subject": s["subject"]} for s in subject_list], + [ + {"experiment_name": experiment_name, "subject": s["subject"]} + for s in subject_list + ], skip_duplicates=True, ) diff --git a/aeon/dj_pipeline/create_experiments/create_octagon_1.py b/aeon/dj_pipeline/create_experiments/create_octagon_1.py index 1d95e1d5..98b13b41 100644 --- a/aeon/dj_pipeline/create_experiments/create_octagon_1.py +++ b/aeon/dj_pipeline/create_experiments/create_octagon_1.py @@ -1,4 +1,4 @@ -"""Function to create new experiments for octagon1.0""" +"""Function to create new experiments for octagon1.0.""" from aeon.dj_pipeline import acquisition, subject @@ -8,7 +8,7 @@ def create_new_experiment(): - """Create new experiment for octagon1.0""" + """Create new experiment for octagon1.0.""" # ---------------- Subject ----------------- # This will get replaced by content from colony.csv subject_list = [ @@ -36,7 +36,10 @@ def create_new_experiment(): skip_duplicates=True, ) acquisition.Experiment.Subject.insert( - [{"experiment_name": experiment_name, "subject": s["subject"]} for s in subject_list], + [ + {"experiment_name": experiment_name, "subject": s["subject"]} + for s in subject_list + ], skip_duplicates=True, ) diff --git a/aeon/dj_pipeline/create_experiments/create_presocial.py b/aeon/dj_pipeline/create_experiments/create_presocial.py index 4df4453b..0a60e59d 100644 --- a/aeon/dj_pipeline/create_experiments/create_presocial.py +++ b/aeon/dj_pipeline/create_experiments/create_presocial.py @@ -1,4 +1,4 @@ -"""Function to create new experiments for presocial0.1""" +"""Function to create new experiments for presocial0.1.""" from aeon.dj_pipeline import acquisition, lab, subject @@ -9,10 +9,12 @@ def create_new_experiment(): - """Create new experiments for presocial0.1""" + """Create new experiments for presocial0.1.""" lab.Location.insert1({"lab": "SWC", "location": location}, skip_duplicates=True) - acquisition.ExperimentType.insert1({"experiment_type": experiment_type}, skip_duplicates=True) + acquisition.ExperimentType.insert1( + {"experiment_type": experiment_type}, skip_duplicates=True + ) acquisition.Experiment.insert( [ @@ -47,7 +49,9 @@ def create_new_experiment(): "directory_type": "raw", "directory_path": f"aeon/data/raw/{computer}/{experiment_type}", } - for experiment_name, computer in zip(experiment_names, computers, strict=False) + for experiment_name, computer in zip( + experiment_names, computers, strict=False + ) ], skip_duplicates=True, ) diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py index 57540085..4b90b018 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py @@ -1,4 +1,4 @@ -"""Function to create new social experiments""" +"""Function to create new social experiments.""" from datetime import datetime @@ -17,7 +17,7 @@ def create_new_social_experiment(experiment_name): - """Create new social experiment""" + """Create new social experiment.""" exp_name, machine_name = experiment_name.split("-") raw_dir = ceph_data_dir / "raw" / machine_name.upper() / exp_name if not raw_dir.exists(): @@ -39,7 +39,9 @@ def create_new_social_experiment(experiment_name): "experiment_name": experiment_name, "repository_name": "ceph_aeon", "directory_type": dir_type, - "directory_path": (ceph_data_dir / dir_type / machine_name.upper() / exp_name) + "directory_path": ( + ceph_data_dir / dir_type / machine_name.upper() / exp_name + ) .relative_to(ceph_dir) .as_posix(), "load_order": load_order, @@ -52,7 +54,9 @@ def create_new_social_experiment(experiment_name): new_experiment_entry, skip_duplicates=True, ) - acquisition.Experiment.Directory.insert(experiment_directories, skip_duplicates=True) + acquisition.Experiment.Directory.insert( + experiment_directories, skip_duplicates=True + ) acquisition.Experiment.DevicesSchema.insert1( { "experiment_name": experiment_name, diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py index 3623ce4f..cc69ced4 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py @@ -1,4 +1,4 @@ -"""Function to create new experiments for social0-r1""" +"""Function to create new experiments for social0-r1.""" import pathlib @@ -12,7 +12,7 @@ def create_new_experiment(): - """Create new experiments for social0-r1""" + """Create new experiments for social0-r1.""" # ---------------- Subject ----------------- subject_list = [ {"subject": "BAA-1100704", "sex": "U", "subject_birth_date": "2021-01-01"}, diff --git a/aeon/dj_pipeline/lab.py b/aeon/dj_pipeline/lab.py index 9f0966b2..b5a4c3c5 100644 --- a/aeon/dj_pipeline/lab.py +++ b/aeon/dj_pipeline/lab.py @@ -92,7 +92,7 @@ class ArenaShape(dj.Lookup): @schema class Arena(dj.Lookup): - """Coordinate frame convention: + """Coordinate frame convention as the following items. + x-dimension: x=0 is the left most point of the bounding box of the arena + y-dimension: y=0 is the top most point of the bounding box of the arena diff --git a/aeon/dj_pipeline/populate/process.py b/aeon/dj_pipeline/populate/process.py index d9dbd808..ae02eef2 100644 --- a/aeon/dj_pipeline/populate/process.py +++ b/aeon/dj_pipeline/populate/process.py @@ -1,11 +1,11 @@ """Start an Aeon ingestion process. This script defines auto-processing routines to operate the DataJoint pipeline -for the Aeon project. Three separate "process" functions are defined to call -`populate()` for different groups of tables, depending on their priority in +for the Aeon project. Three separate "process" functions are defined to call +`populate()` for different groups of tables, depending on their priority in the ingestion routines (high, mid, low). -Each process function is run in a while-loop with the total run-duration configurable +Each process function is run in a while-loop with the total run-duration configurable via command line argument '--duration' (if not set, runs perpetually) - the loop will not begin a new cycle after this period of time (in seconds) - the loop will run perpetually if duration<0 or if duration==None diff --git a/aeon/dj_pipeline/populate/worker.py b/aeon/dj_pipeline/populate/worker.py index 1a1ff302..835b6610 100644 --- a/aeon/dj_pipeline/populate/worker.py +++ b/aeon/dj_pipeline/populate/worker.py @@ -1,5 +1,4 @@ -"""This module defines the workers for the AEON pipeline. -""" +"""This module defines the workers for the AEON pipeline.""" import datajoint as dj from datajoint_utilities.dj_worker import ( diff --git a/aeon/dj_pipeline/qc.py b/aeon/dj_pipeline/qc.py index cc3b23b3..50f493fe 100644 --- a/aeon/dj_pipeline/qc.py +++ b/aeon/dj_pipeline/qc.py @@ -56,7 +56,7 @@ class CameraQC(dj.Imported): @property def key_source(self): - """Return the keys for the CameraQC table""" + """Return the keys for the CameraQC table.""" return ( acquisition.Chunk * ( @@ -70,7 +70,7 @@ def key_source(self): ) # CameraTop def make(self, key): - """Perform quality control checks on the CameraTop stream""" + """Perform quality control checks on the CameraTop stream.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) diff --git a/aeon/dj_pipeline/report.py b/aeon/dj_pipeline/report.py index 92c210e3..dc377176 100644 --- a/aeon/dj_pipeline/report.py +++ b/aeon/dj_pipeline/report.py @@ -44,7 +44,7 @@ class InArenaSummaryPlot(dj.Computed): } def make(self, key): - """Make method for InArenaSummaryPlot table""" + """Make method for InArenaSummaryPlot table.""" in_arena_start, in_arena_end = ( analysis.InArena * analysis.InArenaEnd & key ).fetch1("in_arena_start", "in_arena_end") @@ -475,7 +475,7 @@ class VisitDailySummaryPlot(dj.Computed): ) def make(self, key): - """Make method for VisitDailySummaryPlot table""" + """Make method for VisitDailySummaryPlot table.""" from aeon.dj_pipeline.utils.plotting import ( plot_foraging_bouts_count, plot_foraging_bouts_distribution, @@ -575,7 +575,7 @@ def make(self, key): def _make_path(in_arena_key): - """Make path for saving figures""" + """Make path for saving figures.""" store_stage = pathlib.Path(dj.config["stores"]["djstore"]["stage"]) experiment_name, subject, in_arena_start = (analysis.InArena & in_arena_key).fetch1( "experiment_name", "subject", "in_arena_start" @@ -591,7 +591,7 @@ def _make_path(in_arena_key): def _save_figs(figs, fig_names, save_dir, prefix, extension=".png"): - """Save figures and return a dictionary with figure names and file paths""" + """Save figures and return a dictionary with figure names and file paths.""" fig_dict = {} for fig, figname in zip(figs, fig_names, strict=False): fig_fp = save_dir / (prefix + "_" + figname + extension) diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py index f1645ab3..34ee1878 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py @@ -1,7 +1,4 @@ -"""March 2022 - -Cloning and archiving schemas and data for experiment 0.1. -""" +"""March 2022. Cloning and archiving schemas and data for experiment 0.1.""" import os @@ -35,7 +32,7 @@ def clone_pipeline(): - """Clone the pipeline for experiment 0.1""" + """Clone the pipeline for experiment 0.1.""" diagram = None for orig_schema_name in schema_name_mapper: virtual_module = dj.create_virtual_module(orig_schema_name, orig_schema_name) diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py index 87bf8f31..5ba5e5ce 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py @@ -1,6 +1,4 @@ -"""Jan 2024 - -Cloning and archiving schemas and data for experiment 0.2. +"""Jan 2024. Cloning and archiving schemas and data for experiment 0.2. The pipeline code associated with this archived data pipeline is here https://github.com/SainsburyWellcomeCentre/aeon_mecha/releases/tag/dj_exp02_stable @@ -40,7 +38,7 @@ def clone_pipeline(): - """Clone the pipeline for experiment 0.2""" + """Clone the pipeline for experiment 0.2.""" diagram = None for orig_schema_name in schema_name_mapper: virtual_module = dj.create_virtual_module(orig_schema_name, orig_schema_name) diff --git a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py index 370929c5..47aa6134 100644 --- a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py +++ b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py @@ -1,16 +1,13 @@ -"""July 2022 - -Upgrade all timestamps longblob fields with datajoint 0.13.7. -""" +"""July 2022. Upgrade all timestamps longblob fields with datajoint 0.13.7.""" from datetime import datetime import datajoint as dj +import numpy as np +from tqdm import tqdm logger = dj.logger -import numpy as np -from tqdm import tqdm if dj.__version__ < "0.13.7": raise ImportError( diff --git a/aeon/dj_pipeline/utils/load_metadata.py b/aeon/dj_pipeline/utils/load_metadata.py index 9d055fb7..1d392c80 100644 --- a/aeon/dj_pipeline/utils/load_metadata.py +++ b/aeon/dj_pipeline/utils/load_metadata.py @@ -1,5 +1,4 @@ -"""Load metadata from the experiment and insert into streams schema. -""" +"""Load metadata from the experiment and insert into streams schema.""" import datetime import inspect @@ -265,9 +264,11 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath } ) - """Check if this device is currently installed. If the same - device serial number is currently installed check for any - changes in configuration. If not, skip this""" + """ + Check if this device is currently installed. + If the same device serial number is currently installed check for changes in configuration. + If not, skip this. + """ current_device_query = ( table - table.RemovalTime & experiment_key & device_key ) diff --git a/aeon/dj_pipeline/utils/paths.py b/aeon/dj_pipeline/utils/paths.py index d22c00cb..1677d5f3 100644 --- a/aeon/dj_pipeline/utils/paths.py +++ b/aeon/dj_pipeline/utils/paths.py @@ -1,5 +1,4 @@ -"""Utility functions for working with paths in the context of the DJ pipeline. -""" +"""Utility functions for working with paths in the context of the DJ pipeline.""" from __future__ import annotations @@ -35,9 +34,7 @@ def get_repository_path(repository_name: str) -> pathlib.Path: def find_root_directory( root_directories: str | pathlib.Path, full_path: str | pathlib.Path ) -> pathlib.Path: - """Given multiple potential root directories and a full-path, - - search and return one directory that is the parent of the given path. + """Given multiple potential root directories and a full-path, search and return one directory that is the parent of the given path. Args: root_directories (str | pathlib.Path): A list of potential root directories. @@ -49,7 +46,7 @@ def find_root_directory( Returns: pathlib.Path: The full path to the discovered root directory. - """ + """ # noqa E501 full_path = pathlib.Path(full_path) if not full_path.exists(): diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index af1c475a..c7273135 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -25,9 +25,7 @@ def plot_reward_rate_differences(subject_keys): - """Plotting the reward rate differences between food patches - - (Patch 2 - Patch 1) for all sessions from all subjects specified in "subject_keys". + """Plotting the reward rate differences between food patches (Patch 2 - Patch 1) for all sessions from all subjects specified in "subject_keys". Examples: ``` @@ -36,7 +34,7 @@ def plot_reward_rate_differences(subject_keys): fig = plot_reward_rate_differences(subject_keys) ``` - """ + """ # noqa E501 subj_names, sess_starts, rate_timestamps, rate_diffs = ( analysis.InArenaRewardRate & subject_keys ).fetch( @@ -83,9 +81,7 @@ def plot_reward_rate_differences(subject_keys): def plot_wheel_travelled_distance(session_keys): - """Plotting the wheel travelled distance for different patches - - for all sessions specified in "session_keys". + """Plot wheel-travelled-distance for different patches for all sessions specified in session_keys. Examples: ``` diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index 792f22ca..c8b57ce4 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -23,25 +23,23 @@ class StreamType(dj.Lookup): - """Catalog of all steam types for the different device types used across + """Catalog of all stream types used across Project Aeon. - Project Aeon. One StreamType corresponds to one reader class in `aeon.io.reader`. - - The combination of `stream_reader` and `stream_reader_kwargs` should fully - - specify the data loading routine for a particular device, using the `aeon.io.utils`. + Catalog of all steam types for the different device types used across Project Aeon. + One StreamType corresponds to one reader class in `aeon.io.reader`.The + combination of `stream_reader` and `stream_reader_kwargs` should fully specify the data + loading routine for a particular device, using the `aeon.io.utils`. """ - definition = """ # Catalog of all stream types used across Project Aeon + definition = """ # Catalog of all stream types used across Project Aeon stream_type : varchar(20) --- - stream_reader : varchar(256) # name of the reader class found in `aeon_mecha` - # package (e.g. aeon.io.reader.Video) + stream_reader : varchar(256) # name of the reader class found in `aeon_mecha` package (e.g. aeon.io.reader.Video) stream_reader_kwargs : longblob # keyword arguments to instantiate the reader class stream_description='': varchar(256) stream_hash : uuid # hash of dict(stream_reader_kwargs, stream_reader=stream_reader) unique index (stream_hash) - """ + """ # noqa: E501 class DeviceType(dj.Lookup): @@ -77,26 +75,22 @@ def get_device_template(device_type: str): device_type = dj.utils.from_camel_case(device_type) class ExperimentDevice(dj.Manual): - definition = f""" - # {device_title} placement and operation for a particular time period, - # at a certain location, for a given experiment (auto-generated with - # aeon_mecha-{aeon.__version__}) + definition = f""" # {device_title} placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-{aeon.__version__}) -> acquisition.Experiment -> Device {device_type}_install_time : datetime(6) # time of the {device_type} placed # and started operation at this position --- {device_type}_name : varchar(36) - """ + """ # noqa: E501 class Attribute(dj.Part): - definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) - # associated with this experimental device + definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device -> master attribute_name : varchar(32) --- attribute_value=null : longblob - """ + """ # noqa: E501 class RemovalTime(dj.Part): definition = f""" @@ -124,7 +118,7 @@ def get_device_stream_template(device_type: str, stream_type: str, streams_modul ).fetch1() for i, n in enumerate(stream_detail["stream_reader"].split(".")): - reader = aeon if i == 0 else getattr(reader, n) + reader = aeon if i == 0 else getattr(reader, n) # noqa: F821 if reader is aeon.io.reader.Pose: logger.warning( @@ -134,15 +128,13 @@ def get_device_stream_template(device_type: str, stream_type: str, streams_modul stream = reader(**stream_detail["stream_reader_kwargs"]) - table_definition = f""" - # Raw per-chunk {stream_type} data stream from {device_type} - # (auto-generated with aeon_mecha-{aeon.__version__}) + table_definition = f""" # Raw per-chunk {stream_type} data stream from {device_type} (auto-generated with aeon_mecha-{aeon.__version__}) -> {device_type} -> acquisition.Chunk --- sample_count: int # number of data points acquired from this stream for a given chunk timestamps: longblob # (datetime) timestamps of {stream_type} data - """ + """ # noqa: E501 for col in stream.columns: if col.startswith("_"): @@ -155,10 +147,9 @@ class DeviceDataStream(dj.Imported): @property def key_source(self): - """Only the combination of Chunk and device_type with overlapping time + """Only the combination of Chunk and device_type with overlapping time. + Chunk(s) that started after device_type install time and ended before device_type remove time - + Chunk(s) that started after device_type install time for device_type that are not yet removed """ key_source_query = ( @@ -304,12 +295,8 @@ def main(create_tables=True): 'f"chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time"': ( f"'chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time'" ), - """f'chunk_start < - IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, - "2200-01-01")'""": ( - f"""'chunk_start < - IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, - "2200-01-01")'""" + """f'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, "2200-01-01")'""": ( # noqa: E501 + f"""'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time,"2200-01-01")'""" # noqa: W291, E501 ), 'f"{dj.utils.from_camel_case(device_type)}_name"': ( f"'{dj.utils.from_camel_case(device_type)}_name'" From ae9bc98ef33b572d921d3e8f39c6b570772d04aa Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 01:59:35 +0000 Subject: [PATCH 044/143] fix: another round: all checks passed for `aeon/io` - code refactored --- aeon/io/reader.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/aeon/io/reader.py b/aeon/io/reader.py index 999f2acb..9cb08c01 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -159,10 +159,7 @@ def read(self, file): class JsonList(Reader): - """Extracts data from json list (.jsonl) files, - - where the key "seconds" stores the Aeon timestamp, in seconds. - """ + """Extracts data from .jsonl files, where the key "seconds" stores the Aeon timestamp (s).""" def __init__(self, pattern, columns=(), root_key="value", extension="jsonl"): """Initialize the object with the specified pattern, columns, and root key.""" @@ -176,7 +173,7 @@ def read(self, file): df = pd.read_json(f, lines=True) df.set_index("seconds", inplace=True) for column in self.columns: - df[column] = df[self.root_key].apply(lambda x: x[column]) + df[column] = df[self.root_key].apply(lambda x, col=column: x[col]) return df @@ -347,9 +344,7 @@ def read(self, file): class Pose(Harp): - """Reader for Harp-binarized tracking data given a model - - that outputs id, parts, and likelihoods. + """Reader for Harp-binarized tracking data given a model that outputs id, parts, and likelihoods. Columns: class (int): Int ID of a subject in the environment. From d0cdee994975a4a5564ba01ead0ce00186f0caa5 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 02:08:02 +0000 Subject: [PATCH 045/143] fix: another round: all checks passed for `aeon/tests` --- tests/dj_pipeline/conftest.py | 11 +++--- tests/dj_pipeline/test_acquisition.py | 31 ++++++++++++----- .../test_pipeline_instantiation.py | 34 ++++++++++++------- tests/dj_pipeline/test_qc.py | 2 +- tests/dj_pipeline/test_tracking.py | 14 ++++++-- tests/io/test_api.py | 22 ++++++++---- 6 files changed, 79 insertions(+), 35 deletions(-) diff --git a/tests/dj_pipeline/conftest.py b/tests/dj_pipeline/conftest.py index 1cff0de1..da43d891 100644 --- a/tests/dj_pipeline/conftest.py +++ b/tests/dj_pipeline/conftest.py @@ -48,8 +48,7 @@ def test_params(): @pytest.fixture(autouse=True, scope="session") def dj_config(): - """ - Configures DataJoint connection and loads custom settings. + """Configures DataJoint connection and loads custom settings. This fixture sets up the DataJoint configuration using the 'dj_local_conf.json' file. It raises FileNotFoundError if the file @@ -58,12 +57,16 @@ def dj_config(): """ dj_config_fp = pathlib.Path("dj_local_conf.json") if not dj_config_fp.exists(): - raise FileNotFoundError(f"DataJoint configuration file not found: {dj_config_fp}") + raise FileNotFoundError( + f"DataJoint configuration file not found: {dj_config_fp}" + ) dj.config.load(dj_config_fp) dj.config["safemode"] = False if "custom" not in dj.config: raise KeyError("'custom' not found in DataJoint configuration.") - dj.config["custom"]["database.prefix"] = f"u_{dj.config['database.user']}_testsuite_" + dj.config["custom"][ + "database.prefix" + ] = f"u_{dj.config['database.user']}_testsuite_" def load_pipeline(): diff --git a/tests/dj_pipeline/test_acquisition.py b/tests/dj_pipeline/test_acquisition.py index ce956b3e..fce754ad 100644 --- a/tests/dj_pipeline/test_acquisition.py +++ b/tests/dj_pipeline/test_acquisition.py @@ -1,4 +1,4 @@ -""" Tests for the acquisition pipeline. """ +"""Tests for the acquisition pipeline.""" import datajoint as dj import pytest @@ -9,21 +9,32 @@ @pytest.mark.ingestion def test_epoch_chunk_ingestion(test_params, pipeline, epoch_chunk_ingestion): acquisition = pipeline["acquisition"] - epoch_count = len(acquisition.Epoch & {"experiment_name": test_params["experiment_name"]}) - chunk_count = len(acquisition.Chunk & {"experiment_name": test_params["experiment_name"]}) + epoch_count = len( + acquisition.Epoch & {"experiment_name": test_params["experiment_name"]} + ) + chunk_count = len( + acquisition.Chunk & {"experiment_name": test_params["experiment_name"]} + ) if epoch_count != test_params["epoch_count"]: - raise AssertionError(f"Expected {test_params['epoch_count']} epochs, but got {epoch_count}.") + raise AssertionError( + f"Expected {test_params['epoch_count']} epochs, but got {epoch_count}." + ) if chunk_count != test_params["chunk_count"]: - raise AssertionError(f"Expected {test_params['chunk_count']} chunks, but got {chunk_count}.") + raise AssertionError( + f"Expected {test_params['chunk_count']} chunks, but got {chunk_count}." + ) @pytest.mark.ingestion -def test_experimentlog_ingestion(test_params, pipeline, epoch_chunk_ingestion, experimentlog_ingestion): +def test_experimentlog_ingestion( + test_params, pipeline, epoch_chunk_ingestion, experimentlog_ingestion +): acquisition = pipeline["acquisition"] exp_log_message_count = len( - acquisition.ExperimentLog.Message & {"experiment_name": test_params["experiment_name"]} + acquisition.ExperimentLog.Message + & {"experiment_name": test_params["experiment_name"]} ) if exp_log_message_count != test_params["experiment_log_message_count"]: raise AssertionError( @@ -32,7 +43,8 @@ def test_experimentlog_ingestion(test_params, pipeline, epoch_chunk_ingestion, e ) subject_enter_exit_count = len( - acquisition.SubjectEnterExit.Time & {"experiment_name": test_params["experiment_name"]} + acquisition.SubjectEnterExit.Time + & {"experiment_name": test_params["experiment_name"]} ) if subject_enter_exit_count != test_params["subject_enter_exit_count"]: raise AssertionError( @@ -41,7 +53,8 @@ def test_experimentlog_ingestion(test_params, pipeline, epoch_chunk_ingestion, e ) subject_weight_time_count = len( - acquisition.SubjectWeight.WeightTime & {"experiment_name": test_params["experiment_name"]} + acquisition.SubjectWeight.WeightTime + & {"experiment_name": test_params["experiment_name"]} ) if subject_weight_time_count != test_params["subject_weight_time_count"]: raise AssertionError( diff --git a/tests/dj_pipeline/test_pipeline_instantiation.py b/tests/dj_pipeline/test_pipeline_instantiation.py index 087d0c05..fdd9313e 100644 --- a/tests/dj_pipeline/test_pipeline_instantiation.py +++ b/tests/dj_pipeline/test_pipeline_instantiation.py @@ -1,16 +1,17 @@ -""" Tests for pipeline instantiation and experiment creation """ +"""Tests for pipeline instantiation and experiment creation.""" import datajoint as dj +import pytest logger = dj.logger -import pytest - @pytest.mark.instantiation def test_pipeline_instantiation(pipeline): if not hasattr(pipeline["acquisition"], "FoodPatchEvent"): - raise AssertionError("Pipeline acquisition does not have 'FoodPatchEvent' attribute.") + raise AssertionError( + "Pipeline acquisition does not have 'FoodPatchEvent' attribute." + ) if not hasattr(pipeline["lab"], "Arena"): raise AssertionError("Pipeline lab does not have 'Arena' attribute.") @@ -19,13 +20,17 @@ def test_pipeline_instantiation(pipeline): raise AssertionError("Pipeline qc does not have 'CameraQC' attribute.") if not hasattr(pipeline["report"], "InArenaSummaryPlot"): - raise AssertionError("Pipeline report does not have 'InArenaSummaryPlot' attribute.") + raise AssertionError( + "Pipeline report does not have 'InArenaSummaryPlot' attribute." + ) if not hasattr(pipeline["subject"], "Subject"): raise AssertionError("Pipeline subject does not have 'Subject' attribute.") if not hasattr(pipeline["tracking"], "CameraTracking"): - raise AssertionError("Pipeline tracking does not have 'CameraTracking' attribute.") + raise AssertionError( + "Pipeline tracking does not have 'CameraTracking' attribute." + ) @pytest.mark.instantiation @@ -40,18 +45,23 @@ def test_experiment_creation(test_params, pipeline, experiment_creation): ) raw_dir = ( - acquisition.Experiment.Directory & {"experiment_name": experiment_name, "directory_type": "raw"} + acquisition.Experiment.Directory + & {"experiment_name": experiment_name, "directory_type": "raw"} ).fetch1("directory_path") if raw_dir != test_params["raw_dir"]: - raise AssertionError(f"Expected raw directory '{test_params['raw_dir']}', but got '{raw_dir}'.") + raise AssertionError( + f"Expected raw directory '{test_params['raw_dir']}', but got '{raw_dir}'." + ) - exp_subjects = (acquisition.Experiment.Subject & {"experiment_name": experiment_name}).fetch( - "subject" - ) + exp_subjects = ( + acquisition.Experiment.Subject & {"experiment_name": experiment_name} + ).fetch("subject") if len(exp_subjects) != test_params["subject_count"]: raise AssertionError( f"Expected subject count {test_params['subject_count']}, but got {len(exp_subjects)}." ) if "BAA-1100701" not in exp_subjects: - raise AssertionError("Expected subject 'BAA-1100701' not found in experiment subjects.") + raise AssertionError( + "Expected subject 'BAA-1100701' not found in experiment subjects." + ) diff --git a/tests/dj_pipeline/test_qc.py b/tests/dj_pipeline/test_qc.py index c2750c99..64008eb1 100644 --- a/tests/dj_pipeline/test_qc.py +++ b/tests/dj_pipeline/test_qc.py @@ -1,4 +1,4 @@ -""" Tests for the QC pipeline. """ +"""Tests for the QC pipeline.""" import datajoint as dj import pytest diff --git a/tests/dj_pipeline/test_tracking.py b/tests/dj_pipeline/test_tracking.py index 733fe5f6..860d2392 100644 --- a/tests/dj_pipeline/test_tracking.py +++ b/tests/dj_pipeline/test_tracking.py @@ -1,4 +1,4 @@ -""" Test tracking pipeline. """ +"""Test tracking pipeline.""" import datetime import pathlib @@ -24,7 +24,11 @@ def save_test_data(pipeline, test_params): file_name = ( "-".join( [ - (v.strftime("%Y%m%d%H%M%S") if isinstance(v, datetime.datetime) else str(v)) + ( + v.strftime("%Y%m%d%H%M%S") + if isinstance(v, datetime.datetime) + else str(v) + ) for v in key.values() ] ) @@ -54,7 +58,11 @@ def test_camera_tracking_ingestion(test_params, pipeline, camera_tracking_ingest file_name = ( "-".join( [ - (v.strftime("%Y%m%d%H%M%S") if isinstance(v, datetime.datetime) else str(v)) + ( + v.strftime("%Y%m%d%H%M%S") + if isinstance(v, datetime.datetime) + else str(v) + ) for v in key.values() ] ) diff --git a/tests/io/test_api.py b/tests/io/test_api.py index 015b9a58..320f9476 100644 --- a/tests/io/test_api.py +++ b/tests/io/test_api.py @@ -1,4 +1,4 @@ -""" Tests for the aeon API """ +"""Tests for the aeon API.""" from pathlib import Path @@ -38,7 +38,9 @@ def test_load_end_only(): @pytest.mark.api def test_load_filter_nonchunked(): - data = aeon.load(nonmonotonic_path, exp02.Metadata, start=pd.Timestamp("2022-06-06T09:00:00")) + data = aeon.load( + nonmonotonic_path, exp02.Metadata, start=pd.Timestamp("2022-06-06T09:00:00") + ) if len(data) <= 0: raise AssertionError("Loaded data is empty. Expected non-empty data.") @@ -57,7 +59,9 @@ def test_load_monotonic(): def test_load_nonmonotonic(): data = aeon.load(nonmonotonic_path, exp02.Patch2.Encoder, downsample=None) if data.index.is_monotonic_increasing: - raise AssertionError("Data index is monotonic increasing, but it should not be.") + raise AssertionError( + "Data index is monotonic increasing, but it should not be." + ) @pytest.mark.api @@ -68,11 +72,15 @@ def test_load_encoder_with_downsampling(): # Check that the length of the downsampled data is less than the raw data if len(data) >= len(raw_data): - raise AssertionError("Downsampled data length should be less than raw data length.") + raise AssertionError( + "Downsampled data length should be less than raw data length." + ) # Check that the first timestamp of the downsampled data is within 20ms of the raw data if abs(data.index[0] - raw_data.index[0]).total_seconds() > DOWNSAMPLE_PERIOD: - raise AssertionError("The first timestamp of downsampled data is not within 20ms of raw data.") + raise AssertionError( + "The first timestamp of downsampled data is not within 20ms of raw data." + ) # Check that the last timestamp of the downsampled data is within 20ms of the raw data if abs(data.index[-1] - raw_data.index[-1]).total_seconds() > DOWNSAMPLE_PERIOD: @@ -90,7 +98,9 @@ def test_load_encoder_with_downsampling(): # Check that the timestamps in the downsampled data are strictly increasing if not data.index.is_monotonic_increasing: - raise AssertionError("Timestamps in downsampled data are not strictly increasing.") + raise AssertionError( + "Timestamps in downsampled data are not strictly increasing." + ) if __name__ == "__main__": From 31b7fb35b02b93b6fd5a6a76b73b7da63908370b Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 02:08:26 +0000 Subject: [PATCH 046/143] fix: another round: all checks passed for `.` --- aeon/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/README.md b/aeon/README.md index 792d6005..5a608ffb 100644 --- a/aeon/README.md +++ b/aeon/README.md @@ -1 +1 @@ -# +# README # noqa D100 From cc7e759625e0b1851032d4f686f6ace397ea66b2 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 02:23:22 +0000 Subject: [PATCH 047/143] fix: fix datetime deprecation with timezone --- aeon/dj_pipeline/analysis/block_analysis.py | 4 ++-- aeon/dj_pipeline/analysis/visit.py | 4 ++-- aeon/dj_pipeline/subject.py | 18 +++++++++--------- 3 files changed, 13 insertions(+), 13 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index bcb89e4d..dd60ad7f 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -3,7 +3,7 @@ import itertools import json from collections import defaultdict -from datetime import UTC, datetime +from datetime import datetime, timezone import datajoint as dj import numpy as np @@ -268,7 +268,7 @@ def make(self, key): # log a note and pick the first rate to move forward AnalysisNote.insert1( { - "note_timestamp": datetime.now(UTC), + "note_timestamp": datetime.now(timezone.utc), "note_type": "Multiple patch rates", "note": ( f"Found multiple patch rates for block {key} " diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index 5266c5a9..b46a9c76 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -1,6 +1,6 @@ """Module for visit-related tables in the analysis schema.""" -import datetime +from datetime import datetime, timezone from collections import deque import datajoint as dj @@ -143,7 +143,7 @@ def ingest_environment_visits(experiment_names: list | None = None): .fetch("last_visit") ) start = min(subjects_last_visits) if len(subjects_last_visits) else "1900-01-01" - end = datetime.datetime.now() if start else "2200-01-01" + end = datetime.now(timezone.utc) if start else "2200-01-01" enter_exit_query = ( acquisition.SubjectEnterExit.Time * acquisition.EventType diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index 78d4bb21..6bb69214 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -3,7 +3,7 @@ import json import os import time -from datetime import UTC, datetime, timedelta +from datetime import datetime, timedelta, timezone import datajoint as dj import requests @@ -195,7 +195,7 @@ def get_reference_weight(cls, subject_name): "procedure_date", order_by="procedure_date DESC", limit=1 )[0] else: - ref_date = datetime.now(UTC).date() + ref_date = datetime.now(timezone.utc).date() weight_query = SubjectWeight & subj_key & f"weight_time < '{ref_date}'" ref_weight = ( @@ -207,7 +207,7 @@ def get_reference_weight(cls, subject_name): entry = { "subject": subject_name, "reference_weight": ref_weight, - "last_updated_time": datetime.now(UTC), + "last_updated_time": datetime.now(timezone.utc), } cls.update1(entry) if cls & {"subject": subject_name} else cls.insert1(entry) @@ -250,7 +250,7 @@ class PyratIngestion(dj.Imported): def _auto_schedule(self): """Automatically schedule the next task.""" - utc_now = datetime.now(UTC) + utc_now = datetime.now(timezone.utc) next_task_schedule_time = utc_now + timedelta(hours=self.schedule_interval) if ( @@ -265,7 +265,7 @@ def _auto_schedule(self): def make(self, key): """Automatically import or update entries in the Subject table.""" - execution_time = datetime.now(UTC) + execution_time = datetime.now(timezone.utc) new_eartags = [] for responsible_id in lab.User.fetch("responsible_id"): # 1 - retrieve all animals from this user @@ -301,7 +301,7 @@ def make(self, key): new_entry_count += 1 logger.info(f"Inserting {new_entry_count} new subject(s) from Pyrat") - completion_time = datetime.now(UTC) + completion_time = datetime.now(timezone.utc) self.insert1( { **key, @@ -334,7 +334,7 @@ class PyratCommentWeightProcedure(dj.Imported): def make(self, key): """Automatically import or update entries in the PyratCommentWeightProcedure table.""" - execution_time = datetime.now(UTC) + execution_time = datetime.now(timezone.utc) logger.info("Extracting weights/comments/procedures") eartag_or_id = key["subject"] @@ -377,7 +377,7 @@ def make(self, key): # compute/update reference weight SubjectReferenceWeight.get_reference_weight(eartag_or_id) finally: - completion_time = datetime.now(UTC) + completion_time = datetime.now(timezone.utc) self.insert1( { **key, @@ -398,7 +398,7 @@ class CreatePyratIngestionTask(dj.Computed): def make(self, key): """Create one new PyratIngestionTask for every newly added users.""" PyratIngestionTask.insert1( - {"pyrat_task_scheduled_time": datetime.now(UTC)} + {"pyrat_task_scheduled_time": datetime.now(timezone.utc)} ) time.sleep(1) self.insert1(key) From 269cf16aeef5f289c09f22a469a0434d83c6fdc0 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 11:30:02 +0000 Subject: [PATCH 048/143] fix: add missing comment in table definition --- aeon/dj_pipeline/analysis/visit_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index b03ac593..025167db 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -565,7 +565,7 @@ def make(self, key): class VisitForagingBout(dj.Computed): """Time period from when the animal enters to when it leaves a food patch while moving the wheel.""" - definition = """ + definition = """ # Time from animal's entry to exit of a food patch while moving the wheel. -> Visit -> acquisition.ExperimentFoodPatch bout_start: datetime(6) # start time of bout From 5c7637ef2769ea201dbf8b309f79f0ba9bf3ba23 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 12:04:39 +0000 Subject: [PATCH 049/143] fix: fix one I001 error --- aeon/dj_pipeline/analysis/visit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index b46a9c76..697a7a22 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -1,7 +1,7 @@ """Module for visit-related tables in the analysis schema.""" -from datetime import datetime, timezone from collections import deque +from datetime import datetime, timezone import datajoint as dj import numpy as np From cb7226a830eb75e59586500925a85b8e6fbd058e Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 12:05:24 +0000 Subject: [PATCH 050/143] fix: add `UP017:Use `datetime.UTC` alias` as ignored --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fd85d178..cd1dfa3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,6 +102,7 @@ lint.ignore = [ "PLR0912", "PLR0913", "PLR0915", + "UP017" ] extend-exclude = [ ".git", @@ -113,7 +114,7 @@ extend-exclude = [ ] [tool.ruff.lint.per-file-ignores] "tests/*" = [ - "D103", # skip adding docstrings for public functions + "D103", # skip adding docstrings for public functions ] "aeon/schema/*" = [ "D101", # skip adding docstrings for schema classes From 2cc6f29bb4aacfc47424a76e476ab83fe3d870cb Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 17:57:13 +0000 Subject: [PATCH 051/143] fix: E741 substitute with meaniniful variables --- aeon/analysis/block_plotting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index 33c52663..dfe14728 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -29,16 +29,16 @@ def gen_hex_grad(hex_col, vals, min_l=0.3): """Generates an array of hex color values based on a gradient defined by unit-normalized values.""" # Convert hex to rgb to hls - h, ll, s = rgb_to_hls( + hue, lightness, saturation = rgb_to_hls( *[int(hex_col.lstrip("#")[i : i + 2], 16) / 255 for i in (0, 2, 4)] ) grad = np.empty(shape=(len(vals),), dtype=" Date: Tue, 29 Oct 2024 17:58:45 +0000 Subject: [PATCH 052/143] fix: update variable names --- aeon/analysis/block_plotting.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index dfe14728..d7bbc213 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -26,7 +26,7 @@ patch_markers_linestyles = ["solid", "dash", "dot", "dashdot", "longdashdot"] -def gen_hex_grad(hex_col, vals, min_l=0.3): +def gen_hex_grad(hex_col, vals, min_lightness=0.3): """Generates an array of hex color values based on a gradient defined by unit-normalized values.""" # Convert hex to rgb to hls hue, lightness, saturation = rgb_to_hls( @@ -34,11 +34,13 @@ def gen_hex_grad(hex_col, vals, min_l=0.3): ) grad = np.empty(shape=(len(vals),), dtype=" Date: Tue, 29 Oct 2024 18:09:10 +0000 Subject: [PATCH 053/143] fix: review --- aeon/dj_pipeline/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aeon/dj_pipeline/__init__.py b/aeon/dj_pipeline/__init__.py index c8b0b1c5..df34769b 100644 --- a/aeon/dj_pipeline/__init__.py +++ b/aeon/dj_pipeline/__init__.py @@ -27,11 +27,11 @@ def get_schema_name(name) -> str: def dict_to_uuid(key) -> uuid.UUID: """Given a dictionary `key`, returns a hash string as UUID.""" - hashed = hashlib.sha256() + hashed = hashlib.md5() for k, v in sorted(key.items()): hashed.update(str(k).encode()) hashed.update(str(v).encode()) - return uuid.UUID(hex=hashed.hexdigest()[:32]) + return uuid.UUID(hex=hashed.hexdigest()) def fetch_stream(query, drop_pk=True): @@ -68,5 +68,5 @@ def fetch_stream(query, drop_pk=True): from .utils import streams_maker streams = dj.VirtualModule("streams", streams_maker.schema_name) - except ImportError: + except Exception: pass From 6ad3b0af76fd070cc228701eb5ab1a26ae222c27 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 18:11:14 +0000 Subject: [PATCH 054/143] fix: review --- aeon/dj_pipeline/__init__.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/__init__.py b/aeon/dj_pipeline/__init__.py index df34769b..13f45d7e 100644 --- a/aeon/dj_pipeline/__init__.py +++ b/aeon/dj_pipeline/__init__.py @@ -6,6 +6,8 @@ import datajoint as dj +logger = dj.logger + _default_database_prefix = os.getenv("DJ_DB_PREFIX") or "aeon_" _default_repository_config = {"ceph_aeon": "/ceph/aeon"} @@ -68,5 +70,5 @@ def fetch_stream(query, drop_pk=True): from .utils import streams_maker streams = dj.VirtualModule("streams", streams_maker.schema_name) - except Exception: - pass + except Exception as e: + logger.debug(f"Could not import streams module: {e}") From 52625703ccbf20c4ac9be13a23a04b6a6e81acef Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 18:12:27 +0000 Subject: [PATCH 055/143] fix: review - exception for md5 --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index cd1dfa3d..78623e3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,6 +123,7 @@ extend-exclude = [ "aeon/dj_pipeline/*" = [ "D101", # skip adding docstrings for schema classes "D106", # skip adding docstrings for nested streams + "S324", ] [tool.ruff.lint.pydocstyle] From 18b3aa59925a8a47d1c877f31debfbf29731960b Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 18:16:50 +0000 Subject: [PATCH 056/143] fix: review --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 78623e3d..fa9692c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -124,6 +124,7 @@ extend-exclude = [ "D101", # skip adding docstrings for schema classes "D106", # skip adding docstrings for nested streams "S324", + "F401", # skip incorrectly detecting `aeon.dj_pipeline` dependencies as unused ] [tool.ruff.lint.pydocstyle] From 115ea45603600be2a8abb7e8cb4d3a2f6c5b9ab6 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 18:17:02 +0000 Subject: [PATCH 057/143] fix: review --- aeon/dj_pipeline/acquisition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py index 3d5bf577..cd00de76 100644 --- a/aeon/dj_pipeline/acquisition.py +++ b/aeon/dj_pipeline/acquisition.py @@ -8,11 +8,11 @@ import datajoint as dj import pandas as pd -from aeon.dj_pipeline import get_schema_name from aeon.dj_pipeline.utils import paths from aeon.io import api as io_api from aeon.io import reader as io_reader from aeon.schema import schemas as aeon_schemas +from aeon.dj_pipeline import get_schema_name, lab, subject logger = dj.logger schema = dj.schema(get_schema_name("acquisition")) From ec008b9a6141b4337c244bb6d2cacb00bd2a3696 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 18:21:18 +0000 Subject: [PATCH 058/143] fix: review --- aeon/dj_pipeline/acquisition.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py index cd00de76..1b371409 100644 --- a/aeon/dj_pipeline/acquisition.py +++ b/aeon/dj_pipeline/acquisition.py @@ -8,15 +8,17 @@ import datajoint as dj import pandas as pd +from aeon.dj_pipeline import get_schema_name, lab, subject from aeon.dj_pipeline.utils import paths from aeon.io import api as io_api from aeon.io import reader as io_reader from aeon.schema import schemas as aeon_schemas -from aeon.dj_pipeline import get_schema_name, lab, subject -logger = dj.logger schema = dj.schema(get_schema_name("acquisition")) +logger = dj.logger + + # ------------------- Some Constants -------------------------- _ref_device_mapping = { @@ -147,7 +149,7 @@ def get_data_directory(cls, experiment_key, directory_type="raw", as_posix=False dir_path = pathlib.Path(dir_path) if dir_path.exists(): if not dir_path.is_relative_to(paths.get_repository_path(repo_name)): - raise ValueError(f"f{dir_path} is not relative to the repository path.") + raise ValueError(f"{dir_path} is not relative to the repository path.") data_directory = dir_path else: data_directory = paths.get_repository_path(repo_name) / dir_path From 1c91dadb0baecf6d230951e3309a928e8d9dd788 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 18:25:13 +0000 Subject: [PATCH 059/143] fix: review --- aeon/dj_pipeline/analysis/visit.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index 697a7a22..a92182d5 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -8,7 +8,14 @@ import pandas as pd from aeon.analysis import utils as analysis_utils -from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name +from aeon.dj_pipeline import ( + acquisition, + fetch_stream, + get_schema_name, + lab, + qc, + tracking, +) schema = dj.schema(get_schema_name("analysis")) From c19d92f8a942a3b8be2632cc33694e75b71bdb33 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 18:35:26 +0000 Subject: [PATCH 060/143] fix: review --- aeon/dj_pipeline/tracking.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index f6d298eb..1bfd2781 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -256,11 +256,10 @@ def make(self, key): # ---------- HELPER ------------------ -TARGET_LENGTH = 2 - def compute_distance(position_df, target, xcol="x", ycol="y"): """Compute the distance of the position data from a target point.""" + TARGET_LENGTH = 2 if len(target) != TARGET_LENGTH: raise ValueError(f"Target must be a list of tuple of length {TARGET_LENGTH}.") return np.sqrt(np.square(position_df[[xcol, ycol]] - target).sum(axis=1)) From afa6ff9394776d4319768c869190b2a03e20ef3e Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 18:39:08 +0000 Subject: [PATCH 061/143] fix: review --- aeon/dj_pipeline/tracking.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 1bfd2781..9fca07e1 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -258,10 +258,9 @@ def make(self, key): def compute_distance(position_df, target, xcol="x", ycol="y"): - """Compute the distance of the position data from a target point.""" - TARGET_LENGTH = 2 - if len(target) != TARGET_LENGTH: - raise ValueError(f"Target must be a list of tuple of length {TARGET_LENGTH}.") + """Compute the distance of the position data from a target coordinate (X,Y).""" + if len(target) != 2: # noqa PLR2004 + raise ValueError("Target must be a list of tuple of length 2.") return np.sqrt(np.square(position_df[[xcol, ycol]] - target).sum(axis=1)) From 842f9337a48e4f9aeaed408060a4958e5c698968 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 18:44:46 +0000 Subject: [PATCH 062/143] fix: review --- aeon/dj_pipeline/utils/streams_maker.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index c8b57ce4..79c62ffd 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -78,8 +78,7 @@ class ExperimentDevice(dj.Manual): definition = f""" # {device_title} placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-{aeon.__version__}) -> acquisition.Experiment -> Device - {device_type}_install_time : datetime(6) # time of the {device_type} placed - # and started operation at this position + {device_type}_install_time : datetime(6) # time of the {device_type} placed and started operation at this position --- {device_type}_name : varchar(36) """ # noqa: E501 From 8aa55a90128350aa5ffda83f589fd956691e384e Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 18:49:13 +0000 Subject: [PATCH 063/143] fix: review --- aeon/dj_pipeline/utils/streams_maker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index 79c62ffd..26f1ded7 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -116,8 +116,9 @@ def get_device_stream_template(device_type: str, stream_type: str, streams_modul ) ).fetch1() - for i, n in enumerate(stream_detail["stream_reader"].split(".")): - reader = aeon if i == 0 else getattr(reader, n) # noqa: F821 + reader = aeon + for n in stream_detail["stream_reader"].split(".")[1:]: + reader = getattr(reader, n) if reader is aeon.io.reader.Pose: logger.warning( From 88ae62e2e1a12d6397aed2f8f08678208085ac29 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 18:53:43 +0000 Subject: [PATCH 064/143] fix: review --- aeon/dj_pipeline/utils/streams_maker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index 26f1ded7..115ce73f 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -147,11 +147,11 @@ class DeviceDataStream(dj.Imported): @property def key_source(self): - """Only the combination of Chunk and device_type with overlapping time. + f"""Only the combination of Chunk and {device_type} with overlapping time. - + Chunk(s) that started after device_type install time and ended before device_type remove time - + Chunk(s) that started after device_type install time for device_type that are not yet removed - """ + + Chunk(s) that started after {device_type} install time and ended before {device_type} remove time + + Chunk(s) that started after {device_type} install time for {device_type} that are not yet removed + """ # noqa B021 key_source_query = ( acquisition.Chunk * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) From 4c69b8204f29fdd11fecfafe23b2358baeb0ce6a Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 18:55:44 +0000 Subject: [PATCH 065/143] fix: review --- aeon/dj_pipeline/utils/streams_maker.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index 115ce73f..83fee7cc 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -236,9 +236,7 @@ def main(create_tables=True): full_def = "@schema \n" + device_table_def + "\n\n" f.write(full_def) else: - raise FileExistsError( - f"File {_STREAMS_MODULE_FILE} already exists. Please remove it and try again." - ) + pass streams = importlib.import_module("aeon.dj_pipeline.streams") From 2dda4f0e13511b7e235b29a760362d8dacb57c82 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 18:56:48 +0000 Subject: [PATCH 066/143] fix: review --- aeon/dj_pipeline/utils/streams_maker.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index 83fee7cc..cfca110c 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -235,8 +235,6 @@ def main(create_tables=True): device_table_def = inspect.getsource(table_class).lstrip() full_def = "@schema \n" + device_table_def + "\n\n" f.write(full_def) - else: - pass streams = importlib.import_module("aeon.dj_pipeline.streams") From abab11225b1f0c21ace6d0396c02b58a9fdcdbba Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 19:01:07 +0000 Subject: [PATCH 067/143] fix: review --- aeon/dj_pipeline/utils/streams_maker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index cfca110c..dc5e1688 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -291,8 +291,8 @@ def main(create_tables=True): 'f"chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time"': ( f"'chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time'" ), - """f'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, "2200-01-01")'""": ( # noqa: E501 - f"""'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time,"2200-01-01")'""" # noqa: W291, E501 + """f'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, "2200-01-01")'""": ( # noqa E501 + f"""'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time,"2200-01-01")'""" # noqa E501 ), 'f"{dj.utils.from_camel_case(device_type)}_name"': ( f"'{dj.utils.from_camel_case(device_type)}_name'" From 2b96423a80e813e5da4cd3d5fb42b2abfcf2ebce Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 19:07:31 +0000 Subject: [PATCH 068/143] fix: update pyproject.toml --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fa9692c8..6df94f0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,11 +98,11 @@ lint.select = [ line-length = 108 lint.ignore = [ "E731", - "PT004", # Rule `PT004` is deprecated and will be removed in a future release. + "PT004", # Deprecated and will be removed in a future release. "PLR0912", "PLR0913", "PLR0915", - "UP017" + "UP017" # skip `datetime.UTC` alias ] extend-exclude = [ ".git", @@ -123,7 +123,7 @@ extend-exclude = [ "aeon/dj_pipeline/*" = [ "D101", # skip adding docstrings for schema classes "D106", # skip adding docstrings for nested streams - "S324", + "S324", # skip hashlib insecure hash function (md5) warning "F401", # skip incorrectly detecting `aeon.dj_pipeline` dependencies as unused ] From a438f836ed78f36433ccabef2ca8e18ae985f46a Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 22:12:39 +0000 Subject: [PATCH 069/143] fix(internal team review): --- aeon/io/reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/io/reader.py b/aeon/io/reader.py index 9cb08c01..cda78869 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -173,7 +173,7 @@ def read(self, file): df = pd.read_json(f, lines=True) df.set_index("seconds", inplace=True) for column in self.columns: - df[column] = df[self.root_key].apply(lambda x, col=column: x[col]) + df[column] = df[self.root_key].apply(lambda x: x[column]) # noqa B023 return df From 8c8317a39bd70ca0e87f65dbecb6d8e736cd41ec Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 22:16:16 +0000 Subject: [PATCH 070/143] fix(internal team review): --- aeon/dj_pipeline/utils/streams_maker.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index dc5e1688..a0fe9aab 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -152,16 +152,14 @@ def key_source(self): + Chunk(s) that started after {device_type} install time and ended before {device_type} remove time + Chunk(s) that started after {device_type} install time for {device_type} that are not yet removed """ # noqa B021 - key_source_query = ( + device_type_name = dj.utils.from_camel_case(device_type) + return ( acquisition.Chunk * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) - & f"chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time" - & f'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time,\ - "2200-01-01")' + & f"chunk_start >= {device_type_name}_install_time" + & f'chunk_start < IFNULL({device_type_name}_removal_time,"2200-01-01")' ) - return key_source_query - def make(self, key): """Load and insert the data for the DeviceDataStream table.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( From bbec40af2e207ab8dbb9ba7c5a6bd6c77c00f500 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 22:24:17 +0000 Subject: [PATCH 071/143] fix(internal team review): --- tests/dj_pipeline/test_acquisition.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/dj_pipeline/test_acquisition.py b/tests/dj_pipeline/test_acquisition.py index fce754ad..748b03b7 100644 --- a/tests/dj_pipeline/test_acquisition.py +++ b/tests/dj_pipeline/test_acquisition.py @@ -1,10 +1,7 @@ """Tests for the acquisition pipeline.""" -import datajoint as dj import pytest -logger = dj.logger - @pytest.mark.ingestion def test_epoch_chunk_ingestion(test_params, pipeline, epoch_chunk_ingestion): From 0796793868996211b25664c347faf9ae46dc5e85 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 22:29:43 +0000 Subject: [PATCH 072/143] fix(internal team review): --- aeon/dj_pipeline/analysis/block_analysis.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index dd60ad7f..cc862722 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -137,7 +137,7 @@ class BlockAnalysis(dj.Computed): @property def key_source(self): - """Ensure that the chunk ingestion has caught up with this block before processing (there exists a chunk that ends after the block end time).""" # noqa 501 + """Ensure that the chunk ingestion has caught up with this block before processing (there exists a chunk that ends after the block end time).""" # noqa 501 ks = Block.aggr(acquisition.Chunk, latest_chunk_end="MAX(chunk_end)") ks = ks * Block & "latest_chunk_end >= block_end" & "block_end IS NOT NULL" return ks @@ -624,12 +624,14 @@ def make(self, key): for p in patch_names ] ) + + CUM_PREF_DIST_MIN = 1e-3 + for patch_name in patch_names: cum_pref_dist = ( all_subj_patch_pref_dict[patch_name][subject_name]["cum_dist"] / all_cum_dist ) - CUM_PREF_DIST_MIN = 1e-3 cum_pref_dist = np.where( cum_pref_dist < CUM_PREF_DIST_MIN, 0, cum_pref_dist ) From d3c388179f46df985fbaf5908529b69317f13a45 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 22:32:41 +0000 Subject: [PATCH 073/143] fix(internal team review): --- aeon/dj_pipeline/analysis/block_analysis.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index cc862722..5cb82c0c 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -1843,11 +1843,15 @@ def get_threshold_associated_pellets(patch_key, start, end): ) # Step 2 - Remove invalid rows (back-to-back events) - BTB_TIME_DIFF = ( - 1.2 # pellet delivery trigger - time difference is less than 1.2 seconds + BTB_MIN_TIME_DIFF = ( + 1.2 # pellet delivery trigger - time diff is less than 1.2 seconds ) + BB_MIN_TIME_DIFF = 1.0 # beambreak - time difference is less than 1 seconds + PT_MIN_TIME_DIFF = 1.0 # patch threshold - time difference is less than 1 seconds + invalid_rows = ( - delivered_pellet_df.index.to_series().diff().dt.total_seconds() < BTB_TIME_DIFF + delivered_pellet_df.index.to_series().diff().dt.total_seconds() + < BTB_MIN_TIME_DIFF ) delivered_pellet_df = delivered_pellet_df[~invalid_rows] # exclude manual deliveries @@ -1855,16 +1859,15 @@ def get_threshold_associated_pellets(patch_key, start, end): delivered_pellet_df.index.difference(manual_delivery_df.index) ] - BB_TIME_DIFF = 1.0 # beambreak - time difference is less than 1 seconds invalid_rows = ( - beambreak_df.index.to_series().diff().dt.total_seconds() < BB_TIME_DIFF + beambreak_df.index.to_series().diff().dt.total_seconds() < BB_MIN_TIME_DIFF ) beambreak_df = beambreak_df[~invalid_rows] - PT_TIME_DIFF = 1.0 # patch threshold - time difference is less than 1 seconds depletion_state_df = depletion_state_df.dropna(subset=["threshold"]) invalid_rows = ( - depletion_state_df.index.to_series().diff().dt.total_seconds() < PT_TIME_DIFF + depletion_state_df.index.to_series().diff().dt.total_seconds() + < PT_MIN_TIME_DIFF ) depletion_state_df = depletion_state_df[~invalid_rows] @@ -1882,7 +1885,7 @@ def get_threshold_associated_pellets(patch_key, start, end): beambreak_df.reset_index().rename(columns={"time": "beam_break_timestamp"}), left_on="time", right_on="beam_break_timestamp", - tolerance=pd.Timedelta("{BTB_TIME_DIFF}s"), + tolerance=pd.Timedelta("{BTB_MIN_TIME_DIFF}s"), direction="forward", ) .set_index("time") From 4600e99f3b6d7c66a06ed75a3251fed4b7f2c70f Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 29 Oct 2024 22:34:46 +0000 Subject: [PATCH 074/143] fix(internal team review): Revert to previous version since it is autogenerated --- aeon/dj_pipeline/streams.py | 145 +++++++++++++++++++++++------------- 1 file changed, 93 insertions(+), 52 deletions(-) diff --git a/aeon/dj_pipeline/streams.py b/aeon/dj_pipeline/streams.py index 97cbc10d..225e7198 100644 --- a/aeon/dj_pipeline/streams.py +++ b/aeon/dj_pipeline/streams.py @@ -190,7 +190,9 @@ def key_source(self): def make(self, key): """Load and insert RfidEvents data stream for a given chunk and RfidReader.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) @@ -198,9 +200,10 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(getattr(devices_schema, device_name), "RfidEvents") @@ -246,14 +249,17 @@ def key_source(self): + Chunk(s) that started after SpinnakerVideoSource install time for SpinnakerVideoSource that are not yet removed """ return ( - acquisition.Chunk * SpinnakerVideoSource.join(SpinnakerVideoSource.RemovalTime, left=True) + acquisition.Chunk + * SpinnakerVideoSource.join(SpinnakerVideoSource.RemovalTime, left=True) & "chunk_start >= spinnaker_video_source_install_time" & 'chunk_start < IFNULL(spinnaker_video_source_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert Video data stream for a given chunk and SpinnakerVideoSource.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) @@ -261,9 +267,10 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(getattr(devices_schema, device_name), "Video") @@ -308,14 +315,17 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert BeamBreak data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) @@ -323,9 +333,10 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(getattr(devices_schema, device_name), "BeamBreak") @@ -370,14 +381,17 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert DeliverPellet data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) @@ -385,9 +399,10 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(getattr(devices_schema, device_name), "DeliverPellet") @@ -434,14 +449,17 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert DepletionState data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) @@ -449,9 +467,10 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(getattr(devices_schema, device_name), "DepletionState") @@ -497,14 +516,17 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert Encoder data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) @@ -512,9 +534,10 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(getattr(devices_schema, device_name), "Encoder") @@ -559,14 +582,17 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert ManualDelivery data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) @@ -574,9 +600,10 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(getattr(devices_schema, device_name), "ManualDelivery") @@ -621,14 +648,17 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert MissedPellet data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) @@ -636,9 +666,10 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(getattr(devices_schema, device_name), "MissedPellet") @@ -683,14 +714,17 @@ def key_source(self): + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ return ( - acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + acquisition.Chunk + * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) & "chunk_start >= underground_feeder_install_time" & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' ) def make(self, key): """Load and insert RetriedDelivery data stream for a given chunk and UndergroundFeeder.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) @@ -698,9 +732,10 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(getattr(devices_schema, device_name), "RetriedDelivery") @@ -753,7 +788,9 @@ def key_source(self): def make(self, key): """Load and insert WeightFiltered data stream for a given chunk and WeightScale.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) @@ -761,9 +798,10 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(getattr(devices_schema, device_name), "WeightFiltered") @@ -816,7 +854,9 @@ def key_source(self): def make(self, key): """Load and insert WeightRaw data stream for a given chunk and WeightScale.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( + "chunk_start", "chunk_end" + ) data_dirs = acquisition.Experiment.get_data_directories(key) @@ -824,9 +864,10 @@ def make(self, key): devices_schema = getattr( aeon_schemas, - (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( - "devices_schema_name" - ), + ( + acquisition.Experiment.DevicesSchema + & {"experiment_name": key["experiment_name"]} + ).fetch1("devices_schema_name"), ) stream_reader = getattr(getattr(devices_schema, device_name), "WeightRaw") From e3dfdf54a484323f81842dee7ec9736b5b85ed39 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 30 Oct 2024 15:19:34 +0000 Subject: [PATCH 075/143] fix(internal team review): revert replacement of assert in `tests` and add rule in `pyproject` --- pyproject.toml | 1 + tests/dj_pipeline/conftest.py | 9 +-- tests/dj_pipeline/test_acquisition.py | 60 +++++++------------ .../test_pipeline_instantiation.py | 56 ++++------------- tests/dj_pipeline/test_qc.py | 11 +--- tests/dj_pipeline/test_tracking.py | 24 ++++---- tests/io/test_api.py | 47 ++++----------- 7 files changed, 58 insertions(+), 150 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6df94f0a..96c4a98c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,6 +115,7 @@ extend-exclude = [ [tool.ruff.lint.per-file-ignores] "tests/*" = [ "D103", # skip adding docstrings for public functions + "S101", # skip using assert ] "aeon/schema/*" = [ "D101", # skip adding docstrings for schema classes diff --git a/tests/dj_pipeline/conftest.py b/tests/dj_pipeline/conftest.py index da43d891..d1383fe7 100644 --- a/tests/dj_pipeline/conftest.py +++ b/tests/dj_pipeline/conftest.py @@ -18,7 +18,6 @@ _tear_down = True # always set to True since most fixtures are session-scoped _populate_settings = {"suppress_errors": True} -logger = dj.logger def data_dir(): @@ -56,14 +55,10 @@ def dj_config(): DataJoint configuration. """ dj_config_fp = pathlib.Path("dj_local_conf.json") - if not dj_config_fp.exists(): - raise FileNotFoundError( - f"DataJoint configuration file not found: {dj_config_fp}" - ) + assert dj_config_fp.exists() dj.config.load(dj_config_fp) dj.config["safemode"] = False - if "custom" not in dj.config: - raise KeyError("'custom' not found in DataJoint configuration.") + assert "custom" in dj.config dj.config["custom"][ "database.prefix" ] = f"u_{dj.config['database.user']}_testsuite_" diff --git a/tests/dj_pipeline/test_acquisition.py b/tests/dj_pipeline/test_acquisition.py index 748b03b7..51cd1e77 100644 --- a/tests/dj_pipeline/test_acquisition.py +++ b/tests/dj_pipeline/test_acquisition.py @@ -6,21 +6,15 @@ @pytest.mark.ingestion def test_epoch_chunk_ingestion(test_params, pipeline, epoch_chunk_ingestion): acquisition = pipeline["acquisition"] - epoch_count = len( - acquisition.Epoch & {"experiment_name": test_params["experiment_name"]} + + assert ( + len(acquisition.Epoch & {"experiment_name": test_params["experiment_name"]}) + == test_params["epoch_count"] ) - chunk_count = len( - acquisition.Chunk & {"experiment_name": test_params["experiment_name"]} + assert ( + len(acquisition.Chunk & {"experiment_name": test_params["experiment_name"]}) + == test_params["chunk_count"] ) - if epoch_count != test_params["epoch_count"]: - raise AssertionError( - f"Expected {test_params['epoch_count']} epochs, but got {epoch_count}." - ) - - if chunk_count != test_params["chunk_count"]: - raise AssertionError( - f"Expected {test_params['chunk_count']} chunks, but got {chunk_count}." - ) @pytest.mark.ingestion @@ -29,32 +23,24 @@ def test_experimentlog_ingestion( ): acquisition = pipeline["acquisition"] - exp_log_message_count = len( - acquisition.ExperimentLog.Message - & {"experiment_name": test_params["experiment_name"]} - ) - if exp_log_message_count != test_params["experiment_log_message_count"]: - raise AssertionError( - f"Expected {test_params['experiment_log_message_count']} log messages," - f"but got {exp_log_message_count}." + assert ( + len( + acquisition.ExperimentLog.Message + & {"experiment_name": test_params["experiment_name"]} ) - - subject_enter_exit_count = len( - acquisition.SubjectEnterExit.Time - & {"experiment_name": test_params["experiment_name"]} + == test_params["experiment_log_message_count"] ) - if subject_enter_exit_count != test_params["subject_enter_exit_count"]: - raise AssertionError( - f"Expected {test_params['subject_enter_exit_count']} subject enter/exit events," - f"but got {subject_enter_exit_count}." + assert ( + len( + acquisition.SubjectEnterExit.Time + & {"experiment_name": test_params["experiment_name"]} ) - - subject_weight_time_count = len( - acquisition.SubjectWeight.WeightTime - & {"experiment_name": test_params["experiment_name"]} + == test_params["subject_enter_exit_count"] ) - if subject_weight_time_count != test_params["subject_weight_time_count"]: - raise AssertionError( - f"Expected {test_params['subject_weight_time_count']} subject weight events," - f"but got {subject_weight_time_count}." + assert ( + len( + acquisition.SubjectWeight.WeightTime + & {"experiment_name": test_params["experiment_name"]} ) + == test_params["subject_weight_time_count"] + ) diff --git a/tests/dj_pipeline/test_pipeline_instantiation.py b/tests/dj_pipeline/test_pipeline_instantiation.py index fdd9313e..52da5625 100644 --- a/tests/dj_pipeline/test_pipeline_instantiation.py +++ b/tests/dj_pipeline/test_pipeline_instantiation.py @@ -1,36 +1,16 @@ """Tests for pipeline instantiation and experiment creation.""" -import datajoint as dj import pytest -logger = dj.logger - @pytest.mark.instantiation def test_pipeline_instantiation(pipeline): - if not hasattr(pipeline["acquisition"], "FoodPatchEvent"): - raise AssertionError( - "Pipeline acquisition does not have 'FoodPatchEvent' attribute." - ) - - if not hasattr(pipeline["lab"], "Arena"): - raise AssertionError("Pipeline lab does not have 'Arena' attribute.") - - if not hasattr(pipeline["qc"], "CameraQC"): - raise AssertionError("Pipeline qc does not have 'CameraQC' attribute.") - - if not hasattr(pipeline["report"], "InArenaSummaryPlot"): - raise AssertionError( - "Pipeline report does not have 'InArenaSummaryPlot' attribute." - ) - - if not hasattr(pipeline["subject"], "Subject"): - raise AssertionError("Pipeline subject does not have 'Subject' attribute.") - - if not hasattr(pipeline["tracking"], "CameraTracking"): - raise AssertionError( - "Pipeline tracking does not have 'CameraTracking' attribute." - ) + assert hasattr(pipeline["acquisition"], "FoodPatchEvent") + assert hasattr(pipeline["lab"], "Arena") + assert hasattr(pipeline["qc"], "CameraQC") + assert hasattr(pipeline["report"], "InArenaSummaryPlot") + assert hasattr(pipeline["subject"], "Subject") + assert hasattr(pipeline["tracking"], "CameraTracking") @pytest.mark.instantiation @@ -38,30 +18,14 @@ def test_experiment_creation(test_params, pipeline, experiment_creation): acquisition = pipeline["acquisition"] experiment_name = test_params["experiment_name"] - fetched_experiment_name = acquisition.Experiment.fetch1("experiment_name") - if fetched_experiment_name != experiment_name: - raise AssertionError( - f"Expected experiment name '{experiment_name}', but got '{fetched_experiment_name}'." - ) - + assert acquisition.Experiment.fetch1("experiment_name") == experiment_name raw_dir = ( acquisition.Experiment.Directory & {"experiment_name": experiment_name, "directory_type": "raw"} ).fetch1("directory_path") - if raw_dir != test_params["raw_dir"]: - raise AssertionError( - f"Expected raw directory '{test_params['raw_dir']}', but got '{raw_dir}'." - ) - + assert raw_dir == test_params["raw_dir"] exp_subjects = ( acquisition.Experiment.Subject & {"experiment_name": experiment_name} ).fetch("subject") - if len(exp_subjects) != test_params["subject_count"]: - raise AssertionError( - f"Expected subject count {test_params['subject_count']}, but got {len(exp_subjects)}." - ) - - if "BAA-1100701" not in exp_subjects: - raise AssertionError( - "Expected subject 'BAA-1100701' not found in experiment subjects." - ) + assert len(exp_subjects) == test_params["subject_count"] + assert "BAA-1100701" in exp_subjects diff --git a/tests/dj_pipeline/test_qc.py b/tests/dj_pipeline/test_qc.py index 64008eb1..31e6baf9 100644 --- a/tests/dj_pipeline/test_qc.py +++ b/tests/dj_pipeline/test_qc.py @@ -1,19 +1,10 @@ """Tests for the QC pipeline.""" -import datajoint as dj import pytest -logger = dj.logger - @pytest.mark.qc def test_camera_qc_ingestion(test_params, pipeline, camera_qc_ingestion): qc = pipeline["qc"] - camera_qc_count = len(qc.CameraQC()) - expected_camera_qc_count = test_params["camera_qc_count"] - - if camera_qc_count != expected_camera_qc_count: - raise AssertionError( - f"Expected camera QC count {expected_camera_qc_count}, but got {camera_qc_count}." - ) + assert len(qc.CameraQC()) == test_params["camera_qc_count"] diff --git a/tests/dj_pipeline/test_tracking.py b/tests/dj_pipeline/test_tracking.py index 860d2392..1227adb2 100644 --- a/tests/dj_pipeline/test_tracking.py +++ b/tests/dj_pipeline/test_tracking.py @@ -47,12 +47,10 @@ def save_test_data(pipeline, test_params): def test_camera_tracking_ingestion(test_params, pipeline, camera_tracking_ingestion): tracking = pipeline["tracking"] - camera_tracking_object_count = len(tracking.CameraTracking.Object()) - if camera_tracking_object_count != test_params["camera_tracking_object_count"]: - raise AssertionError( - f"Expected camera tracking object count {test_params['camera_tracking_object_count']}," - f"but got {camera_tracking_object_count}." - ) + assert ( + len(tracking.CameraTracking.Object()) + == test_params["camera_tracking_object_count"] + ) key = tracking.CameraTracking.Object().fetch("KEY")[index] file_name = ( @@ -70,15 +68,13 @@ def test_camera_tracking_ingestion(test_params, pipeline, camera_tracking_ingest ) test_file = pathlib.Path(test_params["test_dir"] + "/" + file_name) - if not test_file.exists(): - raise AssertionError(f"Test file '{test_file}' does not exist.") + assert test_file.exists() print(f"\nTesting {file_name}") data = np.load(test_file) - expected_data = (tracking.CameraTracking.Object() & key).fetch(column_name)[0] - - if not np.allclose(data, expected_data, equal_nan=True): - raise AssertionError( - f"Loaded data does not match the expected data.nExpected: {expected_data}, but got: {data}." - ) + assert np.allclose( + data, + (tracking.CameraTracking.Object() & key).fetch(column_name)[0], + equal_nan=True, + ) diff --git a/tests/io/test_api.py b/tests/io/test_api.py index 320f9476..cf8cbfda 100644 --- a/tests/io/test_api.py +++ b/tests/io/test_api.py @@ -20,8 +20,7 @@ def test_load_start_only(): start=pd.Timestamp("2022-06-06T13:00:49"), downsample=None, ) - if len(data) <= 0: - raise AssertionError("Loaded data is empty. Expected non-empty data.") + assert len(data) > 0 @pytest.mark.api @@ -32,8 +31,7 @@ def test_load_end_only(): end=pd.Timestamp("2022-06-06T13:00:49"), downsample=None, ) - if len(data) <= 0: - raise AssertionError("Loaded data is empty. Expected non-empty data.") + assert len(data) > 0 @pytest.mark.api @@ -41,27 +39,20 @@ def test_load_filter_nonchunked(): data = aeon.load( nonmonotonic_path, exp02.Metadata, start=pd.Timestamp("2022-06-06T09:00:00") ) - if len(data) <= 0: - raise AssertionError("Loaded data is empty. Expected non-empty data.") + assert len(data) > 0 @pytest.mark.api def test_load_monotonic(): data = aeon.load(monotonic_path, exp02.Patch2.Encoder, downsample=None) - if len(data) <= 0: - raise AssertionError("Loaded data is empty. Expected non-empty data.") - - if not data.index.is_monotonic_increasing: - raise AssertionError("Data index is not monotonic increasing.") + assert len(data) > 0 + assert data.index.is_monotonic_increasing @pytest.mark.api def test_load_nonmonotonic(): data = aeon.load(nonmonotonic_path, exp02.Patch2.Encoder, downsample=None) - if data.index.is_monotonic_increasing: - raise AssertionError( - "Data index is monotonic increasing, but it should not be." - ) + assert not data.index.is_monotonic_increasing @pytest.mark.api @@ -71,36 +62,20 @@ def test_load_encoder_with_downsampling(): raw_data = aeon.load(monotonic_path, exp02.Patch2.Encoder, downsample=None) # Check that the length of the downsampled data is less than the raw data - if len(data) >= len(raw_data): - raise AssertionError( - "Downsampled data length should be less than raw data length." - ) + assert len(data) < len(raw_data) # Check that the first timestamp of the downsampled data is within 20ms of the raw data - if abs(data.index[0] - raw_data.index[0]).total_seconds() > DOWNSAMPLE_PERIOD: - raise AssertionError( - "The first timestamp of downsampled data is not within 20ms of raw data." - ) + assert abs(data.index[0] - raw_data.index[0]).total_seconds() <= DOWNSAMPLE_PERIOD # Check that the last timestamp of the downsampled data is within 20ms of the raw data - if abs(data.index[-1] - raw_data.index[-1]).total_seconds() > DOWNSAMPLE_PERIOD: - raise AssertionError( - f"The last timestamp of downsampled data is not within {DOWNSAMPLE_PERIOD*1000} ms of raw data." - ) + assert abs(data.index[-1] - raw_data.index[-1]).total_seconds() <= DOWNSAMPLE_PERIOD # Check that the minimum difference between consecutive timestamps in the downsampled data # is at least 20ms (50Hz) - min_diff = data.index.to_series().diff().dt.total_seconds().min() - if min_diff < DOWNSAMPLE_PERIOD: - raise AssertionError( - f"Minimum difference between consecutive timestamps is less than {DOWNSAMPLE_PERIOD} seconds." - ) + assert data.index.to_series().diff().dt.total_seconds().min() >= DOWNSAMPLE_PERIOD # Check that the timestamps in the downsampled data are strictly increasing - if not data.index.is_monotonic_increasing: - raise AssertionError( - "Timestamps in downsampled data are not strictly increasing." - ) + assert data.index.is_monotonic_increasing if __name__ == "__main__": From 24582dec07cdf354c78b9576cbb65fbec63efb1d Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 31 Oct 2024 15:06:36 +0000 Subject: [PATCH 076/143] fix: revert changes in `streams.py` --- aeon/dj_pipeline/streams.py | 112 ++++++++++++++++-------------------- 1 file changed, 50 insertions(+), 62 deletions(-) diff --git a/aeon/dj_pipeline/streams.py b/aeon/dj_pipeline/streams.py index 225e7198..8a639f5e 100644 --- a/aeon/dj_pipeline/streams.py +++ b/aeon/dj_pipeline/streams.py @@ -12,7 +12,6 @@ from aeon.schema import schemas as aeon_schemas schema = dj.Schema(get_schema_name("streams")) -logger = dj.logger @schema @@ -189,7 +188,6 @@ def key_source(self): ) def make(self, key): - """Load and insert RfidEvents data stream for a given chunk and RfidReader.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -243,11 +241,11 @@ class SpinnakerVideoSourceVideo(dj.Imported): @property def key_source(self): - """ - Only the combination of Chunk and SpinnakerVideoSource with overlapping time - + Chunk(s) that started after SpinnakerVideoSource install time and ended before SpinnakerVideoSource remove time - + Chunk(s) that started after SpinnakerVideoSource install time for SpinnakerVideoSource that are not yet removed - """ + f""" + Only the combination of Chunk and SpinnakerVideoSource with overlapping time + + Chunk(s) that started after SpinnakerVideoSource install time and ended before SpinnakerVideoSource remove time + + Chunk(s) that started after SpinnakerVideoSource install time for SpinnakerVideoSource that are not yet removed + """ return ( acquisition.Chunk * SpinnakerVideoSource.join(SpinnakerVideoSource.RemovalTime, left=True) @@ -256,7 +254,6 @@ def key_source(self): ) def make(self, key): - """Load and insert Video data stream for a given chunk and SpinnakerVideoSource.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -309,11 +306,11 @@ class UndergroundFeederBeamBreak(dj.Imported): @property def key_source(self): - """ - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + f""" + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -322,7 +319,6 @@ def key_source(self): ) def make(self, key): - """Load and insert BeamBreak data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -375,11 +371,11 @@ class UndergroundFeederDeliverPellet(dj.Imported): @property def key_source(self): - """ - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + f""" + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -388,7 +384,6 @@ def key_source(self): ) def make(self, key): - """Load and insert DeliverPellet data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -443,11 +438,11 @@ class UndergroundFeederDepletionState(dj.Imported): @property def key_source(self): - """ - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + f""" + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -456,7 +451,6 @@ def key_source(self): ) def make(self, key): - """Load and insert DepletionState data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -510,11 +504,11 @@ class UndergroundFeederEncoder(dj.Imported): @property def key_source(self): - """ - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + f""" + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -523,7 +517,6 @@ def key_source(self): ) def make(self, key): - """Load and insert Encoder data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -576,11 +569,11 @@ class UndergroundFeederManualDelivery(dj.Imported): @property def key_source(self): - """ - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + f""" + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -589,7 +582,6 @@ def key_source(self): ) def make(self, key): - """Load and insert ManualDelivery data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -642,11 +634,11 @@ class UndergroundFeederMissedPellet(dj.Imported): @property def key_source(self): - """ - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + f""" + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -655,7 +647,6 @@ def key_source(self): ) def make(self, key): - """Load and insert MissedPellet data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -708,11 +699,11 @@ class UndergroundFeederRetriedDelivery(dj.Imported): @property def key_source(self): - """ - Only the combination of Chunk and UndergroundFeeder with overlapping time - + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time - + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed - """ + f""" + Only the combination of Chunk and UndergroundFeeder with overlapping time + + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed + """ return ( acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) @@ -721,7 +712,6 @@ def key_source(self): ) def make(self, key): - """Load and insert RetriedDelivery data stream for a given chunk and UndergroundFeeder.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -775,11 +765,11 @@ class WeightScaleWeightFiltered(dj.Imported): @property def key_source(self): - """ - Only the combination of Chunk and WeightScale with overlapping time - + Chunk(s) that started after WeightScale install time and ended before WeightScale remove time - + Chunk(s) that started after WeightScale install time for WeightScale that are not yet removed - """ + f""" + Only the combination of Chunk and WeightScale with overlapping time + + Chunk(s) that started after WeightScale install time and ended before WeightScale remove time + + Chunk(s) that started after WeightScale install time for WeightScale that are not yet removed + """ return ( acquisition.Chunk * WeightScale.join(WeightScale.RemovalTime, left=True) & "chunk_start >= weight_scale_install_time" @@ -787,7 +777,6 @@ def key_source(self): ) def make(self, key): - """Load and insert WeightFiltered data stream for a given chunk and WeightScale.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) @@ -841,11 +830,11 @@ class WeightScaleWeightRaw(dj.Imported): @property def key_source(self): - """ - Only the combination of Chunk and WeightScale with overlapping time - + Chunk(s) that started after WeightScale install time and ended before WeightScale remove time - + Chunk(s) that started after WeightScale install time for WeightScale that are not yet removed - """ + f""" + Only the combination of Chunk and WeightScale with overlapping time + + Chunk(s) that started after WeightScale install time and ended before WeightScale remove time + + Chunk(s) that started after WeightScale install time for WeightScale that are not yet removed + """ return ( acquisition.Chunk * WeightScale.join(WeightScale.RemovalTime, left=True) & "chunk_start >= weight_scale_install_time" @@ -853,7 +842,6 @@ def key_source(self): ) def make(self, key): - """Load and insert WeightRaw data stream for a given chunk and WeightScale.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( "chunk_start", "chunk_end" ) From 22d64529a3f8f0843fdd631bb8bf5177281f81bb Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 31 Oct 2024 15:59:27 +0000 Subject: [PATCH 077/143] fix: revert black formatting -> max line length from 88 to 105 --- aeon/analysis/block_plotting.py | 4 +- aeon/dj_pipeline/__init__.py | 8 +- aeon/dj_pipeline/acquisition.py | 84 +-- aeon/dj_pipeline/analysis/block_analysis.py | 552 +++++------------- aeon/dj_pipeline/analysis/visit.py | 45 +- aeon/dj_pipeline/analysis/visit_analysis.py | 134 ++--- .../create_experiment_01.py | 4 +- .../create_experiment_02.py | 5 +- .../create_experiments/create_octagon_1.py | 5 +- .../create_experiments/create_presocial.py | 8 +- .../create_socialexperiment.py | 8 +- .../create_socialexperiment_0.py | 13 +- aeon/dj_pipeline/lab.py | 4 +- aeon/dj_pipeline/populate/process.py | 4 +- aeon/dj_pipeline/populate/worker.py | 4 +- aeon/dj_pipeline/qc.py | 25 +- aeon/dj_pipeline/report.py | 63 +- .../scripts/clone_and_freeze_exp02.py | 3 +- .../scripts/update_timestamps_longblob.py | 12 +- aeon/dj_pipeline/streams.py | 4 +- aeon/dj_pipeline/subject.py | 50 +- aeon/dj_pipeline/tracking.py | 47 +- aeon/dj_pipeline/utils/load_metadata.py | 114 +--- aeon/dj_pipeline/utils/paths.py | 3 +- aeon/dj_pipeline/utils/plotting.py | 133 ++--- aeon/dj_pipeline/utils/streams_maker.py | 31 +- .../dj_pipeline/webapps/sciviz/specsheet.yaml | 2 +- aeon/io/reader.py | 76 +-- aeon/schema/social_03.py | 1 - 29 files changed, 392 insertions(+), 1054 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index d7bbc213..c71cb629 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -37,9 +37,7 @@ def gen_hex_grad(hex_col, vals, min_lightness=0.3): curl_lightness = (lightness * val) + ( min_lightness * (1 - val) ) # get cur lightness relative to `hex_col` - curl_lightness = max( - min(curl_lightness, lightness), min_lightness - ) # set min, max bounds + curl_lightness = max(min(curl_lightness, lightness), min_lightness) # set min, max bounds cur_rgb_col = hls_to_rgb(hue, curl_lightness, saturation) # convert to rgb cur_hex_col = "#{:02x}{:02x}{:02x}".format( *tuple(int(c * 255) for c in cur_rgb_col) diff --git a/aeon/dj_pipeline/__init__.py b/aeon/dj_pipeline/__init__.py index 13f45d7e..f225ab2b 100644 --- a/aeon/dj_pipeline/__init__.py +++ b/aeon/dj_pipeline/__init__.py @@ -17,9 +17,7 @@ db_prefix = dj.config["custom"].get("database.prefix", _default_database_prefix) -repository_config = dj.config["custom"].get( - "repository_config", _default_repository_config -) +repository_config = dj.config["custom"].get("repository_config", _default_repository_config) def get_schema_name(name) -> str: @@ -44,9 +42,7 @@ def fetch_stream(query, drop_pk=True): """ df = (query & "sample_count > 0").fetch(format="frame").reset_index() cols2explode = [ - c - for c in query.heading.secondary_attributes - if query.heading.attributes[c].type == "longblob" + c for c in query.heading.secondary_attributes if query.heading.attributes[c].type == "longblob" ] df = df.explode(column=cols2explode) cols2drop = ["sample_count"] + (query.primary_key if drop_pk else []) diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py index 1b371409..8c5056bc 100644 --- a/aeon/dj_pipeline/acquisition.py +++ b/aeon/dj_pipeline/acquisition.py @@ -167,10 +167,7 @@ def get_data_directories(cls, experiment_key, directory_types=None, as_posix=Fal return [ d for dir_type in directory_types - if ( - d := cls.get_data_directory(experiment_key, dir_type, as_posix=as_posix) - ) - is not None + if (d := cls.get_data_directory(experiment_key, dir_type, as_posix=as_posix)) is not None ] @@ -198,9 +195,7 @@ def ingest_epochs(cls, experiment_name): for i, (_, chunk) in enumerate(all_chunks.iterrows()): chunk_rep_file = pathlib.Path(chunk.path) epoch_dir = pathlib.Path(chunk_rep_file.as_posix().split(device_name)[0]) - epoch_start = datetime.datetime.strptime( - epoch_dir.name, "%Y-%m-%dT%H-%M-%S" - ) + epoch_start = datetime.datetime.strptime(epoch_dir.name, "%Y-%m-%dT%H-%M-%S") # --- insert to Epoch --- epoch_key = {"experiment_name": experiment_name, "epoch_start": epoch_start} @@ -219,15 +214,11 @@ def ingest_epochs(cls, experiment_name): if i > 0: previous_chunk = all_chunks.iloc[i - 1] previous_chunk_path = pathlib.Path(previous_chunk.path) - previous_epoch_dir = pathlib.Path( - previous_chunk_path.as_posix().split(device_name)[0] - ) + previous_epoch_dir = pathlib.Path(previous_chunk_path.as_posix().split(device_name)[0]) previous_epoch_start = datetime.datetime.strptime( previous_epoch_dir.name, "%Y-%m-%dT%H-%M-%S" ) - previous_chunk_end = previous_chunk.name + datetime.timedelta( - hours=io_api.CHUNK_DURATION - ) + previous_chunk_end = previous_chunk.name + datetime.timedelta(hours=io_api.CHUNK_DURATION) previous_epoch_end = min(previous_chunk_end, epoch_start) previous_epoch_key = { "experiment_name": experiment_name, @@ -256,9 +247,7 @@ def ingest_epochs(cls, experiment_name): { **previous_epoch_key, "epoch_end": previous_epoch_end, - "epoch_duration": ( - previous_epoch_end - previous_epoch_start - ).total_seconds() + "epoch_duration": (previous_epoch_end - previous_epoch_start).total_seconds() / 3600, } ) @@ -331,23 +320,17 @@ def make(self, key): experiment_name = key["experiment_name"] devices_schema = getattr( aeon_schemas, - (Experiment.DevicesSchema & {"experiment_name": experiment_name}).fetch1( - "devices_schema_name" - ), + (Experiment.DevicesSchema & {"experiment_name": experiment_name}).fetch1("devices_schema_name"), ) dir_type, epoch_dir = (Epoch & key).fetch1("directory_type", "epoch_dir") data_dir = Experiment.get_data_directory(key, dir_type) metadata_yml_filepath = data_dir / epoch_dir / "Metadata.yml" - epoch_config = extract_epoch_config( - experiment_name, devices_schema, metadata_yml_filepath - ) + epoch_config = extract_epoch_config(experiment_name, devices_schema, metadata_yml_filepath) epoch_config = { **epoch_config, - "metadata_file_path": metadata_yml_filepath.relative_to( - data_dir - ).as_posix(), + "metadata_file_path": metadata_yml_filepath.relative_to(data_dir).as_posix(), } # Insert new entries for streams.DeviceType, streams.Device. @@ -358,20 +341,15 @@ def make(self, key): # Define and instantiate new devices/stream tables under `streams` schema streams_maker.main() # Insert devices' installation/removal/settings - epoch_device_types = ingest_epoch_metadata( - experiment_name, devices_schema, metadata_yml_filepath - ) + epoch_device_types = ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath) self.insert1(key) self.Meta.insert1(epoch_config) - self.DeviceType.insert( - key | {"device_type": n} for n in epoch_device_types or {} - ) + self.DeviceType.insert(key | {"device_type": n} for n in epoch_device_types or {}) with metadata_yml_filepath.open("r") as f: metadata = json.load(f) self.ActiveRegion.insert( - {**key, "region_name": k, "region_data": v} - for k, v in metadata["ActiveRegion"].items() + {**key, "region_name": k, "region_data": v} for k, v in metadata["ActiveRegion"].items() ) @@ -410,9 +388,7 @@ def ingest_chunks(cls, experiment_name): for _, chunk in all_chunks.iterrows(): chunk_rep_file = pathlib.Path(chunk.path) epoch_dir = pathlib.Path(chunk_rep_file.as_posix().split(device_name)[0]) - epoch_start = datetime.datetime.strptime( - epoch_dir.name, "%Y-%m-%dT%H-%M-%S" - ) + epoch_start = datetime.datetime.strptime(epoch_dir.name, "%Y-%m-%dT%H-%M-%S") epoch_key = {"experiment_name": experiment_name, "epoch_start": epoch_start} if not (Epoch & epoch_key): @@ -420,9 +396,7 @@ def ingest_chunks(cls, experiment_name): continue chunk_start = chunk.name - chunk_start = max( - chunk_start, epoch_start - ) # first chunk of the epoch starts at epoch_start + chunk_start = max(chunk_start, epoch_start) # first chunk of the epoch starts at epoch_start chunk_end = chunk_start + datetime.timedelta(hours=io_api.CHUNK_DURATION) if EpochEnd & epoch_key: @@ -442,12 +416,8 @@ def ingest_chunks(cls, experiment_name): ) chunk_starts.append(chunk_key["chunk_start"]) - chunk_list.append( - {**chunk_key, **directory, "chunk_end": chunk_end, **epoch_key} - ) - file_name_list.append( - chunk_rep_file.name - ) # handle duplicated files in different folders + chunk_list.append({**chunk_key, **directory, "chunk_end": chunk_end, **epoch_key}) + file_name_list.append(chunk_rep_file.name) # handle duplicated files in different folders # -- files -- file_datetime_str = chunk_rep_file.stem.replace(f"{device_name}_", "") @@ -564,9 +534,9 @@ def make(self, key): data_dirs = Experiment.get_data_directories(key) devices_schema = getattr( aeon_schemas, - ( - Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) device = devices_schema.Environment @@ -626,14 +596,12 @@ def make(self, key): data_dirs = Experiment.get_data_directories(key) devices_schema = getattr( aeon_schemas, - ( - Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) device = devices_schema.Environment - stream_reader = ( - device.EnvironmentActiveConfiguration - ) # expecting columns: time, name, value + stream_reader = device.EnvironmentActiveConfiguration # expecting columns: time, name, value stream_data = io_api.load( root=data_dirs, reader=stream_reader, @@ -666,9 +634,7 @@ def _get_all_chunks(experiment_name, device_name): raw_data_dirs = {k: v for k, v in raw_data_dirs.items() if v} if not raw_data_dirs: - raise ValueError( - f"No raw data directory found for experiment: {experiment_name}" - ) + raise ValueError(f"No raw data directory found for experiment: {experiment_name}") chunkdata = io_api.load( root=list(raw_data_dirs.values()), @@ -690,9 +656,7 @@ def _match_experiment_directory(experiment_name, path, directories): repo_path = paths.get_repository_path(directory.pop("repository_name")) break else: - raise FileNotFoundError( - f"Unable to identify the directory" f" where this chunk is from: {path}" - ) + raise FileNotFoundError(f"Unable to identify the directory" f" where this chunk is from: {path}") return raw_data_dir, directory, repo_path diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 5cb82c0c..e491bc9e 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -67,18 +67,14 @@ def make(self, key): # find the 0s in `pellet_ct` (these are times when the pellet count reset - i.e. new block) # that would mark the start of a new block - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") exp_key = {"experiment_name": key["experiment_name"]} chunk_restriction = acquisition.create_chunk_restriction( key["experiment_name"], chunk_start, chunk_end ) - block_state_query = ( - acquisition.Environment.BlockState & exp_key & chunk_restriction - ) + block_state_query = acquisition.Environment.BlockState & exp_key & chunk_restriction block_state_df = fetch_stream(block_state_query) if block_state_df.empty: self.insert1(key) @@ -101,12 +97,8 @@ def make(self, key): block_entries = [] if not blocks_df.empty: # calculate block end_times (use due_time) and durations - blocks_df["end_time"] = blocks_df["due_time"].apply( - lambda x: io_api.aeon(x) - ) - blocks_df["duration"] = ( - blocks_df["end_time"] - blocks_df.index - ).dt.total_seconds() / 3600 + blocks_df["end_time"] = blocks_df["due_time"].apply(lambda x: io_api.aeon(x)) + blocks_df["duration"] = (blocks_df["end_time"] - blocks_df.index).dt.total_seconds() / 3600 for _, row in blocks_df.iterrows(): block_entries.append( @@ -195,9 +187,7 @@ def make(self, key): tracking.SLEAPTracking, ) for streams_table in streams_tables: - if len(streams_table & chunk_keys) < len( - streams_table.key_source & chunk_keys - ): + if len(streams_table & chunk_keys) < len(streams_table.key_source & chunk_keys): raise ValueError( f"BlockAnalysis Not Ready - {streams_table.__name__}" f"not yet fully ingested for block: {key}." @@ -208,14 +198,10 @@ def make(self, key): # For wheel data, downsample to 10Hz final_encoder_fs = 10 - maintenance_period = get_maintenance_periods( - key["experiment_name"], block_start, block_end - ) + maintenance_period = get_maintenance_periods(key["experiment_name"], block_start, block_end) patch_query = ( - streams.UndergroundFeeder.join( - streams.UndergroundFeeder.RemovalTime, left=True - ) + streams.UndergroundFeeder.join(streams.UndergroundFeeder.RemovalTime, left=True) & key & f'"{block_start}" >= underground_feeder_install_time' & f'"{block_end}" < IFNULL(underground_feeder_removal_time, "2200-01-01")' @@ -229,14 +215,12 @@ def make(self, key): streams.UndergroundFeederDepletionState & patch_key & chunk_restriction )[block_start:block_end] - pellet_ts_threshold_df = get_threshold_associated_pellets( - patch_key, block_start, block_end - ) + pellet_ts_threshold_df = get_threshold_associated_pellets(patch_key, block_start, block_end) # wheel encoder data - encoder_df = fetch_stream( - streams.UndergroundFeederEncoder & patch_key & chunk_restriction - )[block_start:block_end] + encoder_df = fetch_stream(streams.UndergroundFeederEncoder & patch_key & chunk_restriction)[ + block_start:block_end + ] # filter out maintenance period based on logs pellet_ts_threshold_df = filter_out_maintenance_periods( pellet_ts_threshold_df, @@ -255,13 +239,9 @@ def make(self, key): ) if depletion_state_df.empty: - raise ValueError( - f"No depletion state data found for block {key} - patch: {patch_name}" - ) + raise ValueError(f"No depletion state data found for block {key} - patch: {patch_name}") - encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled( - encoder_df.angle - ) + encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled(encoder_df.angle) if len(depletion_state_df.rate.unique()) > 1: # multiple patch rates per block is unexpected @@ -297,9 +277,7 @@ def make(self, key): "wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values[ ::wheel_downsampling_factor ], - "wheel_timestamps": encoder_df.index.values[ - ::wheel_downsampling_factor - ], + "wheel_timestamps": encoder_df.index.values[::wheel_downsampling_factor], "patch_threshold": pellet_ts_threshold_df.threshold.values, "patch_threshold_timestamps": pellet_ts_threshold_df.index.values, "patch_rate": patch_rate, @@ -331,9 +309,7 @@ def make(self, key): # positions - query for CameraTop, identity_name matches subject_name, pos_query = ( streams.SpinnakerVideoSource - * tracking.SLEAPTracking.PoseIdentity.proj( - "identity_name", part_name="anchor_part" - ) + * tracking.SLEAPTracking.PoseIdentity.proj("identity_name", part_name="anchor_part") * tracking.SLEAPTracking.Part & key & { @@ -343,23 +319,18 @@ def make(self, key): & chunk_restriction ) pos_df = fetch_stream(pos_query)[block_start:block_end] - pos_df = filter_out_maintenance_periods( - pos_df, maintenance_period, block_end - ) + pos_df = filter_out_maintenance_periods(pos_df, maintenance_period, block_end) if pos_df.empty: continue position_diff = np.sqrt( - np.square(np.diff(pos_df.x.astype(float))) - + np.square(np.diff(pos_df.y.astype(float))) + np.square(np.diff(pos_df.x.astype(float))) + np.square(np.diff(pos_df.y.astype(float))) ) cumsum_distance_travelled = np.concatenate([[0], np.cumsum(position_diff)]) # weights - weight_query = ( - acquisition.Environment.SubjectWeight & key & chunk_restriction - ) + weight_query = acquisition.Environment.SubjectWeight & key & chunk_restriction weight_df = fetch_stream(weight_query)[block_start:block_end] weight_df.query(f"subject_id == '{subject_name}'", inplace=True) @@ -447,10 +418,7 @@ def make(self, key): subjects_positions_df = pd.concat( [ pd.DataFrame( - { - "subject_name": [s["subject_name"]] - * len(s["position_timestamps"]) - } + {"subject_name": [s["subject_name"]] * len(s["position_timestamps"])} | { k: s[k] for k in ( @@ -478,8 +446,7 @@ def make(self, key): "cum_pref_time", ] all_subj_patch_pref_dict = { - p: {s: {a: pd.Series() for a in pref_attrs} for s in subject_names} - for p in patch_names + p: {s: {a: pd.Series() for a in pref_attrs} for s in subject_names} for p in patch_names } for patch in block_patches: @@ -502,15 +469,11 @@ def make(self, key): ).fetch1("attribute_value") patch_center = (int(patch_center["X"]), int(patch_center["Y"])) subjects_xy = subjects_positions_df[["position_x", "position_y"]].values - dist_to_patch = np.sqrt( - np.sum((subjects_xy - patch_center) ** 2, axis=1).astype(float) - ) + dist_to_patch = np.sqrt(np.sum((subjects_xy - patch_center) ** 2, axis=1).astype(float)) dist_to_patch_df = subjects_positions_df[["subject_name"]].copy() dist_to_patch_df["dist_to_patch"] = dist_to_patch - dist_to_patch_wheel_ts_id_df = pd.DataFrame( - index=cum_wheel_dist.index, columns=subject_names - ) + dist_to_patch_wheel_ts_id_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subject_names) dist_to_patch_pel_ts_id_df = pd.DataFrame( index=patch["pellet_timestamps"], columns=subject_names ) @@ -518,12 +481,10 @@ def make(self, key): # Find closest match between pose_df indices and wheel indices if not dist_to_patch_wheel_ts_id_df.empty: dist_to_patch_wheel_ts_subj = pd.merge_asof( - left=pd.DataFrame( - dist_to_patch_wheel_ts_id_df[subject_name].copy() - ).reset_index(names="time"), - right=dist_to_patch_df[ - dist_to_patch_df["subject_name"] == subject_name - ] + left=pd.DataFrame(dist_to_patch_wheel_ts_id_df[subject_name].copy()).reset_index( + names="time" + ), + right=dist_to_patch_df[dist_to_patch_df["subject_name"] == subject_name] .copy() .reset_index(names="time"), on="time", @@ -532,18 +493,16 @@ def make(self, key): direction="nearest", tolerance=pd.Timedelta("100ms"), ) - dist_to_patch_wheel_ts_id_df[subject_name] = ( - dist_to_patch_wheel_ts_subj["dist_to_patch"].values - ) + dist_to_patch_wheel_ts_id_df[subject_name] = dist_to_patch_wheel_ts_subj[ + "dist_to_patch" + ].values # Find closest match between pose_df indices and pel indices if not dist_to_patch_pel_ts_id_df.empty: dist_to_patch_pel_ts_subj = pd.merge_asof( - left=pd.DataFrame( - dist_to_patch_pel_ts_id_df[subject_name].copy() - ).reset_index(names="time"), - right=dist_to_patch_df[ - dist_to_patch_df["subject_name"] == subject_name - ] + left=pd.DataFrame(dist_to_patch_pel_ts_id_df[subject_name].copy()).reset_index( + names="time" + ), + right=dist_to_patch_df[dist_to_patch_df["subject_name"] == subject_name] .copy() .reset_index(names="time"), on="time", @@ -552,9 +511,9 @@ def make(self, key): direction="nearest", tolerance=pd.Timedelta("200ms"), ) - dist_to_patch_pel_ts_id_df[subject_name] = ( - dist_to_patch_pel_ts_subj["dist_to_patch"].values - ) + dist_to_patch_pel_ts_id_df[subject_name] = dist_to_patch_pel_ts_subj[ + "dist_to_patch" + ].values # Get closest subject to patch at each pellet timestep closest_subjects_pellet_ts = dist_to_patch_pel_ts_id_df.idxmin(axis=1) @@ -566,12 +525,8 @@ def make(self, key): wheel_dist = cum_wheel_dist.diff().fillna(cum_wheel_dist.iloc[0]) # Assign wheel dist to closest subject for each wheel timestep for subject_name in subject_names: - subj_idxs = cum_wheel_dist_subj_df[ - closest_subjects_wheel_ts == subject_name - ].index - cum_wheel_dist_subj_df.loc[subj_idxs, subject_name] = wheel_dist[ - subj_idxs - ] + subj_idxs = cum_wheel_dist_subj_df[closest_subjects_wheel_ts == subject_name].index + cum_wheel_dist_subj_df.loc[subj_idxs, subject_name] = wheel_dist[subj_idxs] cum_wheel_dist_subj_df = cum_wheel_dist_subj_df.cumsum(axis=0) # In patch time @@ -579,14 +534,14 @@ def make(self, key): dt = np.median(np.diff(cum_wheel_dist.index)).astype(int) / 1e9 # s # Fill in `all_subj_patch_pref` for subject_name in subject_names: - all_subj_patch_pref_dict[patch["patch_name"]][subject_name][ - "cum_dist" - ] = cum_wheel_dist_subj_df[subject_name].values + all_subj_patch_pref_dict[patch["patch_name"]][subject_name]["cum_dist"] = ( + cum_wheel_dist_subj_df[subject_name].values + ) subject_in_patch = in_patch[subject_name] subject_in_patch_cum_time = subject_in_patch.cumsum().values * dt - all_subj_patch_pref_dict[patch["patch_name"]][subject_name][ - "cum_time" - ] = subject_in_patch_cum_time + all_subj_patch_pref_dict[patch["patch_name"]][subject_name]["cum_time"] = ( + subject_in_patch_cum_time + ) closest_subj_mask = closest_subjects_pellet_ts == subject_name subj_pellets = closest_subjects_pellet_ts[closest_subj_mask] @@ -602,9 +557,7 @@ def make(self, key): "pellet_count": len(subj_pellets), "pellet_timestamps": subj_pellets.index.values, "patch_threshold": subj_patch_thresh, - "wheel_cumsum_distance_travelled": cum_wheel_dist_subj_df[ - subject_name - ].values, + "wheel_cumsum_distance_travelled": cum_wheel_dist_subj_df[subject_name].values, } ) @@ -613,77 +566,49 @@ def make(self, key): for subject_name in subject_names: # Get sum of subj cum wheel dists and cum in patch time all_cum_dist = np.sum( - [ - all_subj_patch_pref_dict[p][subject_name]["cum_dist"][-1] - for p in patch_names - ] + [all_subj_patch_pref_dict[p][subject_name]["cum_dist"][-1] for p in patch_names] ) all_cum_time = np.sum( - [ - all_subj_patch_pref_dict[p][subject_name]["cum_time"][-1] - for p in patch_names - ] + [all_subj_patch_pref_dict[p][subject_name]["cum_time"][-1] for p in patch_names] ) CUM_PREF_DIST_MIN = 1e-3 for patch_name in patch_names: cum_pref_dist = ( - all_subj_patch_pref_dict[patch_name][subject_name]["cum_dist"] - / all_cum_dist - ) - cum_pref_dist = np.where( - cum_pref_dist < CUM_PREF_DIST_MIN, 0, cum_pref_dist + all_subj_patch_pref_dict[patch_name][subject_name]["cum_dist"] / all_cum_dist ) - all_subj_patch_pref_dict[patch_name][subject_name][ - "cum_pref_dist" - ] = cum_pref_dist + cum_pref_dist = np.where(cum_pref_dist < CUM_PREF_DIST_MIN, 0, cum_pref_dist) + all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_dist"] = cum_pref_dist cum_pref_time = ( - all_subj_patch_pref_dict[patch_name][subject_name]["cum_time"] - / all_cum_time + all_subj_patch_pref_dict[patch_name][subject_name]["cum_time"] / all_cum_time ) - all_subj_patch_pref_dict[patch_name][subject_name][ - "cum_pref_time" - ] = cum_pref_time + all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_time"] = cum_pref_time # sum pref at each ts across patches for each subject total_dist_pref = np.sum( np.vstack( - [ - all_subj_patch_pref_dict[p][subject_name]["cum_pref_dist"] - for p in patch_names - ] + [all_subj_patch_pref_dict[p][subject_name]["cum_pref_dist"] for p in patch_names] ), axis=0, ) total_time_pref = np.sum( np.vstack( - [ - all_subj_patch_pref_dict[p][subject_name]["cum_pref_time"] - for p in patch_names - ] + [all_subj_patch_pref_dict[p][subject_name]["cum_pref_time"] for p in patch_names] ), axis=0, ) for patch_name in patch_names: - cum_pref_dist = all_subj_patch_pref_dict[patch_name][subject_name][ - "cum_pref_dist" - ] - all_subj_patch_pref_dict[patch_name][subject_name][ - "running_dist_pref" - ] = np.divide( + cum_pref_dist = all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_dist"] + all_subj_patch_pref_dict[patch_name][subject_name]["running_dist_pref"] = np.divide( cum_pref_dist, total_dist_pref, out=np.zeros_like(cum_pref_dist), where=total_dist_pref != 0, ) - cum_pref_time = all_subj_patch_pref_dict[patch_name][subject_name][ - "cum_pref_time" - ] - all_subj_patch_pref_dict[patch_name][subject_name][ - "running_time_pref" - ] = np.divide( + cum_pref_time = all_subj_patch_pref_dict[patch_name][subject_name]["cum_pref_time"] + all_subj_patch_pref_dict[patch_name][subject_name]["running_time_pref"] = np.divide( cum_pref_time, total_time_pref, out=np.zeros_like(cum_pref_time), @@ -695,24 +620,12 @@ def make(self, key): | { "patch_name": p, "subject_name": s, - "cumulative_preference_by_time": all_subj_patch_pref_dict[p][s][ - "cum_pref_time" - ], - "cumulative_preference_by_wheel": all_subj_patch_pref_dict[p][s][ - "cum_pref_dist" - ], - "running_preference_by_time": all_subj_patch_pref_dict[p][s][ - "running_time_pref" - ], - "running_preference_by_wheel": all_subj_patch_pref_dict[p][s][ - "running_dist_pref" - ], - "final_preference_by_time": all_subj_patch_pref_dict[p][s][ - "cum_pref_time" - ][-1], - "final_preference_by_wheel": all_subj_patch_pref_dict[p][s][ - "cum_pref_dist" - ][-1], + "cumulative_preference_by_time": all_subj_patch_pref_dict[p][s]["cum_pref_time"], + "cumulative_preference_by_wheel": all_subj_patch_pref_dict[p][s]["cum_pref_dist"], + "running_preference_by_time": all_subj_patch_pref_dict[p][s]["running_time_pref"], + "running_preference_by_wheel": all_subj_patch_pref_dict[p][s]["running_dist_pref"], + "final_preference_by_time": all_subj_patch_pref_dict[p][s]["cum_pref_time"][-1], + "final_preference_by_wheel": all_subj_patch_pref_dict[p][s]["cum_pref_dist"][-1], } for p, s in itertools.product(patch_names, subject_names) ) @@ -736,9 +649,7 @@ class BlockPatchPlots(dj.Computed): def make(self, key): """Compute and plot various block-level statistics and visualizations.""" # Define subject colors and patch styling for plotting - exp_subject_names = (acquisition.Experiment.Subject & key).fetch( - "subject", order_by="subject" - ) + exp_subject_names = (acquisition.Experiment.Subject & key).fetch("subject", order_by="subject") if not len(exp_subject_names): raise ValueError( "No subjects found in the `acquisition.Experiment.Subject`, missing a manual insert step?." @@ -757,10 +668,7 @@ def make(self, key): # Figure 1 - Patch stats: patch means and pellet threshold boxplots # --- subj_patch_info = ( - ( - BlockSubjectAnalysis.Patch.proj("pellet_timestamps", "patch_threshold") - & key - ) + (BlockSubjectAnalysis.Patch.proj("pellet_timestamps", "patch_threshold") & key) .fetch(format="frame") .reset_index() ) @@ -774,46 +682,28 @@ def make(self, key): ["patch_name", "subject_name", "pellet_timestamps", "patch_threshold"] ] min_subj_patch_info = ( - min_subj_patch_info.explode( - ["pellet_timestamps", "patch_threshold"], ignore_index=True - ) + min_subj_patch_info.explode(["pellet_timestamps", "patch_threshold"], ignore_index=True) .dropna() .reset_index(drop=True) ) # Rename and reindex columns min_subj_patch_info.columns = ["patch", "subject", "time", "threshold"] - min_subj_patch_info = min_subj_patch_info.reindex( - columns=["time", "patch", "threshold", "subject"] - ) + min_subj_patch_info = min_subj_patch_info.reindex(columns=["time", "patch", "threshold", "subject"]) # Add patch mean values and block-normalized delivery times to pellet info n_patches = len(patch_info) - patch_mean_info = pd.DataFrame( - index=np.arange(n_patches), columns=min_subj_patch_info.columns - ) + patch_mean_info = pd.DataFrame(index=np.arange(n_patches), columns=min_subj_patch_info.columns) patch_mean_info["subject"] = "mean" patch_mean_info["patch"] = [d["patch_name"] for d in patch_info] - patch_mean_info["threshold"] = [ - ((1 / d["patch_rate"]) + d["patch_offset"]) for d in patch_info - ] + patch_mean_info["threshold"] = [((1 / d["patch_rate"]) + d["patch_offset"]) for d in patch_info] patch_mean_info["time"] = subj_patch_info["block_start"][0] - min_subj_patch_info_plus = pd.concat( - (patch_mean_info, min_subj_patch_info) - ).reset_index(drop=True) + min_subj_patch_info_plus = pd.concat((patch_mean_info, min_subj_patch_info)).reset_index(drop=True) min_subj_patch_info_plus["norm_time"] = ( - ( - min_subj_patch_info_plus["time"] - - min_subj_patch_info_plus["time"].iloc[0] - ) - / ( - min_subj_patch_info_plus["time"].iloc[-1] - - min_subj_patch_info_plus["time"].iloc[0] - ) + (min_subj_patch_info_plus["time"] - min_subj_patch_info_plus["time"].iloc[0]) + / (min_subj_patch_info_plus["time"].iloc[-1] - min_subj_patch_info_plus["time"].iloc[0]) ).round(3) # Plot it - box_colors = ["#0A0A0A"] + list( - subject_colors_dict.values() - ) # subject colors + mean color + box_colors = ["#0A0A0A"] + list(subject_colors_dict.values()) # subject colors + mean color patch_stats_fig = px.box( min_subj_patch_info_plus.sort_values("patch"), x="patch", @@ -843,9 +733,7 @@ def make(self, key): .dropna() .reset_index(drop=True) ) - weights_block.drop( - columns=["experiment_name", "block_start"], inplace=True, errors="ignore" - ) + weights_block.drop(columns=["experiment_name", "block_start"], inplace=True, errors="ignore") weights_block.rename(columns={"weight_timestamps": "time"}, inplace=True) weights_block.set_index("time", inplace=True) weights_block.sort_index(inplace=True) @@ -869,17 +757,13 @@ def make(self, key): # Figure 3 - Cumulative pellet count: over time, per subject, markered by patch # --- # Create dataframe with cumulative pellet count per subject - cum_pel_ct = ( - min_subj_patch_info_plus.sort_values("time").copy().reset_index(drop=True) - ) + cum_pel_ct = min_subj_patch_info_plus.sort_values("time").copy().reset_index(drop=True) patch_means = cum_pel_ct.loc[0:3][["patch", "threshold"]].rename( columns={"threshold": "mean_thresh"} ) patch_means["mean_thresh"] = patch_means["mean_thresh"].astype(float).round(1) cum_pel_ct = cum_pel_ct.merge(patch_means, on="patch", how="left") - cum_pel_ct = cum_pel_ct[ - ~cum_pel_ct["subject"].str.contains("mean") - ].reset_index(drop=True) + cum_pel_ct = cum_pel_ct[~cum_pel_ct["subject"].str.contains("mean")].reset_index(drop=True) cum_pel_ct = ( cum_pel_ct.groupby("subject", group_keys=False) .apply(lambda group: group.assign(counter=np.arange(len(group)) + 1)) @@ -889,9 +773,7 @@ def make(self, key): make_float_cols = ["threshold", "mean_thresh", "norm_time"] cum_pel_ct[make_float_cols] = cum_pel_ct[make_float_cols].astype(float) cum_pel_ct["patch_label"] = ( - cum_pel_ct["patch"] - + " μ: " - + cum_pel_ct["mean_thresh"].astype(float).round(1).astype(str) + cum_pel_ct["patch"] + " μ: " + cum_pel_ct["mean_thresh"].astype(float).round(1).astype(str) ) cum_pel_ct["norm_thresh_val"] = ( (cum_pel_ct["threshold"] - cum_pel_ct["threshold"].min()) @@ -921,9 +803,7 @@ def make(self, key): mode="markers", marker={ "symbol": patch_markers_dict[patch_grp["patch"].iloc[0]], - "color": gen_hex_grad( - pel_mrkr_col, patch_grp["norm_thresh_val"] - ), + "color": gen_hex_grad(pel_mrkr_col, patch_grp["norm_thresh_val"]), "size": 8, }, name=patch_val, @@ -943,9 +823,7 @@ def make(self, key): cum_pel_per_subject_fig = go.Figure() for id_val, id_grp in cum_pel_ct.groupby("subject"): for patch_val, patch_grp in id_grp.groupby("patch"): - cur_p_mean = patch_means[patch_means["patch"] == patch_val][ - "mean_thresh" - ].values[0] + cur_p_mean = patch_means[patch_means["patch"] == patch_val]["mean_thresh"].values[0] cur_p = patch_val.replace("Patch", "P") cum_pel_per_subject_fig.add_trace( go.Scatter( @@ -960,9 +838,7 @@ def make(self, key): # line=dict(width=2, color=subject_colors_dict[id_val]), marker={ "symbol": patch_markers_dict[patch_val], - "color": gen_hex_grad( - pel_mrkr_col, patch_grp["norm_thresh_val"] - ), + "color": gen_hex_grad(pel_mrkr_col, patch_grp["norm_thresh_val"]), "size": 8, }, name=f"{id_val} - {cur_p} - μ: {cur_p_mean}", @@ -979,9 +855,7 @@ def make(self, key): # Figure 5 - Cumulative wheel distance: over time, per subject-patch # --- # Get wheel timestamps for each patch - wheel_ts = (BlockAnalysis.Patch & key).fetch( - "patch_name", "wheel_timestamps", as_dict=True - ) + wheel_ts = (BlockAnalysis.Patch & key).fetch("patch_name", "wheel_timestamps", as_dict=True) wheel_ts = {d["patch_name"]: d["wheel_timestamps"] for d in wheel_ts} # Get subject patch data subj_wheel_cumsum_dist = (BlockSubjectAnalysis.Patch & key).fetch( @@ -1001,9 +875,7 @@ def make(self, key): for subj in subject_names: for patch_name in patch_names: cur_cum_wheel_dist = subj_wheel_cumsum_dist[(subj, patch_name)] - cur_p_mean = patch_means[patch_means["patch"] == patch_name][ - "mean_thresh" - ].values[0] + cur_p_mean = patch_means[patch_means["patch"] == patch_name]["mean_thresh"].values[0] cur_p = patch_name.replace("Patch", "P") cum_wheel_dist_fig.add_trace( go.Scatter( @@ -1020,10 +892,7 @@ def make(self, key): ) # Add markers for each pellet cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[ - (cum_pel_ct["subject"] == subj) - & (cum_pel_ct["patch"] == patch_name) - ], + cum_pel_ct[(cum_pel_ct["subject"] == subj) & (cum_pel_ct["patch"] == patch_name)], pd.DataFrame( { "time": wheel_ts[patch_name], @@ -1042,15 +911,11 @@ def make(self, key): mode="markers", marker={ "symbol": patch_markers_dict[patch_name], - "color": gen_hex_grad( - pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"] - ), + "color": gen_hex_grad(pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"]), "size": 8, }, name=f"{subj} - {cur_p} pellets", - customdata=np.stack( - (cur_cum_pel_ct["threshold"],), axis=-1 - ), + customdata=np.stack((cur_cum_pel_ct["threshold"],), axis=-1), hovertemplate="Threshold: %{customdata[0]:.2f} cm", ) ) @@ -1064,14 +929,10 @@ def make(self, key): # --- # Get and format a dataframe with preference data patch_pref = (BlockSubjectAnalysis.Preference & key).fetch(format="frame") - patch_pref.reset_index( - level=["experiment_name", "block_start"], drop=True, inplace=True - ) + patch_pref.reset_index(level=["experiment_name", "block_start"], drop=True, inplace=True) # Replace small vals with 0 small_pref_thresh = 1e-3 - patch_pref["cumulative_preference_by_wheel"] = patch_pref[ - "cumulative_preference_by_wheel" - ].apply( + patch_pref["cumulative_preference_by_wheel"] = patch_pref["cumulative_preference_by_wheel"].apply( lambda arr: np.where(np.array(arr) < small_pref_thresh, 0, np.array(arr)) ) @@ -1079,9 +940,7 @@ def calculate_running_preference(group, pref_col, out_col): # Sum pref at each ts total_pref = np.sum(np.vstack(group[pref_col].values), axis=0) # Calculate running pref - group[out_col] = group[pref_col].apply( - lambda x: np.nan_to_num(x / total_pref, 0.0) - ) + group[out_col] = group[pref_col].apply(lambda x: np.nan_to_num(x / total_pref, 0.0)) return group patch_pref = ( @@ -1110,12 +969,8 @@ def calculate_running_preference(group, pref_col, out_col): # Add trace for each subject-patch combo for subj in subject_names: for patch_name in patch_names: - cur_run_wheel_pref = patch_pref.loc[patch_name].loc[subj][ - "running_preference_by_wheel" - ] - cur_p_mean = patch_means[patch_means["patch"] == patch_name][ - "mean_thresh" - ].values[0] + cur_run_wheel_pref = patch_pref.loc[patch_name].loc[subj]["running_preference_by_wheel"] + cur_p_mean = patch_means[patch_means["patch"] == patch_name]["mean_thresh"].values[0] cur_p = patch_name.replace("Patch", "P") running_pref_by_wheel_plot.add_trace( go.Scatter( @@ -1132,10 +987,7 @@ def calculate_running_preference(group, pref_col, out_col): ) # Add markers for each pellet cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[ - (cum_pel_ct["subject"] == subj) - & (cum_pel_ct["patch"] == patch_name) - ], + cum_pel_ct[(cum_pel_ct["subject"] == subj) & (cum_pel_ct["patch"] == patch_name)], pd.DataFrame( { "time": wheel_ts[patch_name], @@ -1154,15 +1006,11 @@ def calculate_running_preference(group, pref_col, out_col): mode="markers", marker={ "symbol": patch_markers_dict[patch_name], - "color": gen_hex_grad( - pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"] - ), + "color": gen_hex_grad(pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"]), "size": 8, }, name=f"{subj} - {cur_p} pellets", - customdata=np.stack( - (cur_cum_pel_ct["threshold"],), axis=-1 - ), + customdata=np.stack((cur_cum_pel_ct["threshold"],), axis=-1), hovertemplate="Threshold: %{customdata[0]:.2f} cm", ) ) @@ -1178,12 +1026,8 @@ def calculate_running_preference(group, pref_col, out_col): # Add trace for each subject-patch combo for subj in subject_names: for patch_name in patch_names: - cur_run_time_pref = patch_pref.loc[patch_name].loc[subj][ - "running_preference_by_time" - ] - cur_p_mean = patch_means[patch_means["patch"] == patch_name][ - "mean_thresh" - ].values[0] + cur_run_time_pref = patch_pref.loc[patch_name].loc[subj]["running_preference_by_time"] + cur_p_mean = patch_means[patch_means["patch"] == patch_name]["mean_thresh"].values[0] cur_p = patch_name.replace("Patch", "P") running_pref_by_patch_fig.add_trace( go.Scatter( @@ -1200,10 +1044,7 @@ def calculate_running_preference(group, pref_col, out_col): ) # Add markers for each pellet cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[ - (cum_pel_ct["subject"] == subj) - & (cum_pel_ct["patch"] == patch_name) - ], + cum_pel_ct[(cum_pel_ct["subject"] == subj) & (cum_pel_ct["patch"] == patch_name)], pd.DataFrame( { "time": wheel_ts[patch_name], @@ -1222,15 +1063,11 @@ def calculate_running_preference(group, pref_col, out_col): mode="markers", marker={ "symbol": patch_markers_dict[patch_name], - "color": gen_hex_grad( - pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"] - ), + "color": gen_hex_grad(pel_mrkr_col, cur_cum_pel_ct["norm_thresh_val"]), "size": 8, }, name=f"{subj} - {cur_p} pellets", - customdata=np.stack( - (cur_cum_pel_ct["threshold"],), axis=-1 - ), + customdata=np.stack((cur_cum_pel_ct["threshold"],), axis=-1), hovertemplate="Threshold: %{customdata[0]:.2f} cm", ) ) @@ -1244,9 +1081,7 @@ def calculate_running_preference(group, pref_col, out_col): # Figure 8 - Weighted patch preference: weighted by 'wheel_dist_spun : pel_ct' ratio # --- # Create multi-indexed dataframe with weighted distance for each subject-patch pair - pel_patches = [ - p for p in patch_names if "dummy" not in p.lower() - ] # exclude dummy patches + pel_patches = [p for p in patch_names if "dummy" not in p.lower()] # exclude dummy patches data = [] for patch in pel_patches: for subject in subject_names: @@ -1259,16 +1094,12 @@ def calculate_running_preference(group, pref_col, out_col): } ) subj_wheel_pel_weighted_dist = pd.DataFrame(data) - subj_wheel_pel_weighted_dist.set_index( - ["patch_name", "subject_name"], inplace=True - ) + subj_wheel_pel_weighted_dist.set_index(["patch_name", "subject_name"], inplace=True) subj_wheel_pel_weighted_dist["weighted_dist"] = np.nan # Calculate weighted distance subject_patch_data = (BlockSubjectAnalysis.Patch() & key).fetch(format="frame") - subject_patch_data.reset_index( - level=["experiment_name", "block_start"], drop=True, inplace=True - ) + subject_patch_data.reset_index(level=["experiment_name", "block_start"], drop=True, inplace=True) subj_wheel_pel_weighted_dist = defaultdict(lambda: defaultdict(dict)) for s in subject_names: for p in pel_patches: @@ -1276,14 +1107,11 @@ def calculate_running_preference(group, pref_col, out_col): cur_wheel_cum_dist_df = pd.DataFrame(columns=["time", "cum_wheel_dist"]) cur_wheel_cum_dist_df["time"] = wheel_ts[p] cur_wheel_cum_dist_df["cum_wheel_dist"] = ( - subject_patch_data.loc[p].loc[s]["wheel_cumsum_distance_travelled"] - + 1 + subject_patch_data.loc[p].loc[s]["wheel_cumsum_distance_travelled"] + 1 ) # Get cumulative pellet count cur_cum_pel_ct = pd.merge_asof( - cum_pel_ct[ - (cum_pel_ct["subject"] == s) & (cum_pel_ct["patch"] == p) - ], + cum_pel_ct[(cum_pel_ct["subject"] == s) & (cum_pel_ct["patch"] == p)], cur_wheel_cum_dist_df.sort_values("time"), on="time", direction="forward", @@ -1302,9 +1130,7 @@ def calculate_running_preference(group, pref_col, out_col): on="time", direction="forward", ) - max_weight = ( - cur_cum_pel_ct.iloc[-1]["counter"] + 1 - ) # for values after last pellet + max_weight = cur_cum_pel_ct.iloc[-1]["counter"] + 1 # for values after last pellet merged_df["counter"] = merged_df["counter"].fillna(max_weight) merged_df["weighted_cum_wheel_dist"] = ( merged_df.groupby("counter") @@ -1315,9 +1141,7 @@ def calculate_running_preference(group, pref_col, out_col): else: weighted_dist = cur_wheel_cum_dist_df["cum_wheel_dist"].values # Assign to dict - subj_wheel_pel_weighted_dist[p][s]["time"] = cur_wheel_cum_dist_df[ - "time" - ].values + subj_wheel_pel_weighted_dist[p][s]["time"] = cur_wheel_cum_dist_df["time"].values subj_wheel_pel_weighted_dist[p][s]["weighted_dist"] = weighted_dist # Convert back to dataframe data = [] @@ -1328,15 +1152,11 @@ def calculate_running_preference(group, pref_col, out_col): "patch_name": p, "subject_name": s, "time": subj_wheel_pel_weighted_dist[p][s]["time"], - "weighted_dist": subj_wheel_pel_weighted_dist[p][s][ - "weighted_dist" - ], + "weighted_dist": subj_wheel_pel_weighted_dist[p][s]["weighted_dist"], } ) subj_wheel_pel_weighted_dist = pd.DataFrame(data) - subj_wheel_pel_weighted_dist.set_index( - ["patch_name", "subject_name"], inplace=True - ) + subj_wheel_pel_weighted_dist.set_index(["patch_name", "subject_name"], inplace=True) # Calculate normalized weighted value def norm_inv_norm(group): @@ -1345,28 +1165,20 @@ def norm_inv_norm(group): inv_norm_dist = 1 / norm_dist inv_norm_dist = inv_norm_dist / (np.sum(inv_norm_dist, axis=0)) # Map each inv_norm_dist back to patch name. - return pd.Series( - inv_norm_dist.tolist(), index=group.index, name="norm_value" - ) + return pd.Series(inv_norm_dist.tolist(), index=group.index, name="norm_value") subj_wheel_pel_weighted_dist["norm_value"] = ( subj_wheel_pel_weighted_dist.groupby("subject_name") .apply(norm_inv_norm) .reset_index(level=0, drop=True) ) - subj_wheel_pel_weighted_dist["wheel_pref"] = patch_pref[ - "running_preference_by_wheel" - ] + subj_wheel_pel_weighted_dist["wheel_pref"] = patch_pref["running_preference_by_wheel"] # Plot it weighted_patch_pref_fig = make_subplots( rows=len(pel_patches), cols=len(subject_names), - subplot_titles=[ - f"{patch} - {subject}" - for patch in pel_patches - for subject in subject_names - ], + subplot_titles=[f"{patch} - {subject}" for patch in pel_patches for subject in subject_names], specs=[[{"secondary_y": True}] * len(subject_names)] * len(pel_patches), shared_xaxes=True, vertical_spacing=0.1, @@ -1548,9 +1360,7 @@ def make(self, key): for id_val, id_grp in centroid_df.groupby("identity_name"): # Add counts of x,y points to a grid that will be used for heatmap img_grid = np.zeros((max_x + 1, max_y + 1)) - points, counts = np.unique( - id_grp[["x", "y"]].values, return_counts=True, axis=0 - ) + points, counts = np.unique(id_grp[["x", "y"]].values, return_counts=True, axis=0) for point, count in zip(points, counts, strict=True): img_grid[point[0], point[1]] = count img_grid /= img_grid.max() # normalize @@ -1559,9 +1369,7 @@ def make(self, key): # so 45 cm/frame ~= 9 px/frame win_sz = 9 # in pixels (ensure odd for centering) kernel = np.ones((win_sz, win_sz)) / win_sz**2 # moving avg kernel - img_grid_p = np.pad( - img_grid, win_sz // 2, mode="edge" - ) # pad for full output from convolution + img_grid_p = np.pad(img_grid, win_sz // 2, mode="edge") # pad for full output from convolution img_grid_smooth = conv2d(img_grid_p, kernel) heatmaps.append((id_val, img_grid_smooth)) @@ -1590,17 +1398,11 @@ def make(self, key): # Figure 3 - Position ethogram # --- # Get Active Region (ROI) locations - epoch_query = acquisition.Epoch & ( - acquisition.Chunk & key & chunk_restriction - ).proj("epoch_start") + epoch_query = acquisition.Epoch & (acquisition.Chunk & key & chunk_restriction).proj("epoch_start") active_region_query = acquisition.EpochConfig.ActiveRegion & epoch_query - roi_locs = dict( - zip(*active_region_query.fetch("region_name", "region_data"), strict=True) - ) + roi_locs = dict(zip(*active_region_query.fetch("region_name", "region_data"), strict=True)) # get RFID reader locations - recent_rfid_query = ( - acquisition.Experiment.proj() * streams.Device.proj() & key - ).aggr( + recent_rfid_query = (acquisition.Experiment.proj() * streams.Device.proj() & key).aggr( streams.RfidReader & f"rfid_reader_install_time <= '{block_start}'", rfid_reader_install_time="max(rfid_reader_install_time)", ) @@ -1638,30 +1440,18 @@ def make(self, key): # For each ROI, compute if within ROI for roi in rois: - if ( - roi == "Corridor" - ): # special case for corridor, based on between inner and outer radius + if roi == "Corridor": # special case for corridor, based on between inner and outer radius dist = np.linalg.norm( (np.vstack((centroid_df["x"], centroid_df["y"])).T) - arena_center, axis=1, ) - pos_eth_df[roi] = (dist >= arena_inner_radius) & ( - dist <= arena_outer_radius - ) + pos_eth_df[roi] = (dist >= arena_inner_radius) & (dist <= arena_outer_radius) elif roi == "Nest": # special case for nest, based on 4 corners nest_corners = roi_locs["NestRegion"]["ArrayOfPoint"] - nest_br_x, nest_br_y = int(nest_corners[0]["X"]), int( - nest_corners[0]["Y"] - ) - nest_bl_x, nest_bl_y = int(nest_corners[1]["X"]), int( - nest_corners[1]["Y"] - ) - nest_tl_x, nest_tl_y = int(nest_corners[2]["X"]), int( - nest_corners[2]["Y"] - ) - nest_tr_x, nest_tr_y = int(nest_corners[3]["X"]), int( - nest_corners[3]["Y"] - ) + nest_br_x, nest_br_y = int(nest_corners[0]["X"]), int(nest_corners[0]["Y"]) + nest_bl_x, nest_bl_y = int(nest_corners[1]["X"]), int(nest_corners[1]["Y"]) + nest_tl_x, nest_tl_y = int(nest_corners[2]["X"]), int(nest_corners[2]["Y"]) + nest_tr_x, nest_tr_y = int(nest_corners[3]["X"]), int(nest_corners[3]["Y"]) pos_eth_df[roi] = ( (centroid_df["x"] <= nest_br_x) & (centroid_df["y"] >= nest_br_y) @@ -1675,13 +1465,10 @@ def make(self, key): else: roi_radius = gate_radius if roi == "Gate" else patch_radius # Get ROI coords - roi_x, roi_y = int(rfid_locs[roi + "Rfid"]["X"]), int( - rfid_locs[roi + "Rfid"]["Y"] - ) + roi_x, roi_y = int(rfid_locs[roi + "Rfid"]["X"]), int(rfid_locs[roi + "Rfid"]["Y"]) # Check if in ROI dist = np.linalg.norm( - (np.vstack((centroid_df["x"], centroid_df["y"])).T) - - (roi_x, roi_y), + (np.vstack((centroid_df["x"], centroid_df["y"])).T) - (roi_x, roi_y), axis=1, ) pos_eth_df[roi] = dist < roi_radius @@ -1814,9 +1601,7 @@ def get_threshold_associated_pellets(patch_key, start, end): - offset - rate """ # noqa 501 - chunk_restriction = acquisition.create_chunk_restriction( - patch_key["experiment_name"], start, end - ) + chunk_restriction = acquisition.create_chunk_restriction(patch_key["experiment_name"], start, end) # Step 1 - fetch data # pellet delivery trigger @@ -1824,9 +1609,9 @@ def get_threshold_associated_pellets(patch_key, start, end): streams.UndergroundFeederDeliverPellet & patch_key & chunk_restriction )[start:end] # beambreak - beambreak_df = fetch_stream( - streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction - )[start:end] + beambreak_df = fetch_stream(streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction)[ + start:end + ] # patch threshold depletion_state_df = fetch_stream( streams.UndergroundFeederDepletionState & patch_key & chunk_restriction @@ -1843,32 +1628,22 @@ def get_threshold_associated_pellets(patch_key, start, end): ) # Step 2 - Remove invalid rows (back-to-back events) - BTB_MIN_TIME_DIFF = ( - 1.2 # pellet delivery trigger - time diff is less than 1.2 seconds - ) + BTB_MIN_TIME_DIFF = 1.2 # pellet delivery trigger - time diff is less than 1.2 seconds BB_MIN_TIME_DIFF = 1.0 # beambreak - time difference is less than 1 seconds PT_MIN_TIME_DIFF = 1.0 # patch threshold - time difference is less than 1 seconds - invalid_rows = ( - delivered_pellet_df.index.to_series().diff().dt.total_seconds() - < BTB_MIN_TIME_DIFF - ) + invalid_rows = delivered_pellet_df.index.to_series().diff().dt.total_seconds() < BTB_MIN_TIME_DIFF delivered_pellet_df = delivered_pellet_df[~invalid_rows] # exclude manual deliveries delivered_pellet_df = delivered_pellet_df.loc[ delivered_pellet_df.index.difference(manual_delivery_df.index) ] - invalid_rows = ( - beambreak_df.index.to_series().diff().dt.total_seconds() < BB_MIN_TIME_DIFF - ) + invalid_rows = beambreak_df.index.to_series().diff().dt.total_seconds() < BB_MIN_TIME_DIFF beambreak_df = beambreak_df[~invalid_rows] depletion_state_df = depletion_state_df.dropna(subset=["threshold"]) - invalid_rows = ( - depletion_state_df.index.to_series().diff().dt.total_seconds() - < PT_MIN_TIME_DIFF - ) + invalid_rows = depletion_state_df.index.to_series().diff().dt.total_seconds() < PT_MIN_TIME_DIFF depletion_state_df = depletion_state_df[~invalid_rows] # Return empty if no data @@ -1891,18 +1666,14 @@ def get_threshold_associated_pellets(patch_key, start, end): .set_index("time") .dropna(subset=["beam_break_timestamp"]) ) - pellet_beam_break_df.drop_duplicates( - subset="beam_break_timestamp", keep="last", inplace=True - ) + pellet_beam_break_df.drop_duplicates(subset="beam_break_timestamp", keep="last", inplace=True) # Find pellet delivery triggers that approximately coincide with each threshold update # i.e. nearest pellet delivery within 100ms before or after threshold update pellet_ts_threshold_df = ( pd.merge_asof( depletion_state_df.reset_index(), - pellet_beam_break_df.reset_index().rename( - columns={"time": "pellet_timestamp"} - ), + pellet_beam_break_df.reset_index().rename(columns={"time": "pellet_timestamp"}), left_on="time", right_on="pellet_timestamp", tolerance=pd.Timedelta("100ms"), @@ -1915,12 +1686,8 @@ def get_threshold_associated_pellets(patch_key, start, end): # Clean up the df pellet_ts_threshold_df = pellet_ts_threshold_df.drop(columns=["event_x", "event_y"]) # Shift back the pellet_timestamp values by 1 to match with the previous threshold update - pellet_ts_threshold_df.pellet_timestamp = ( - pellet_ts_threshold_df.pellet_timestamp.shift(-1) - ) - pellet_ts_threshold_df.beam_break_timestamp = ( - pellet_ts_threshold_df.beam_break_timestamp.shift(-1) - ) + pellet_ts_threshold_df.pellet_timestamp = pellet_ts_threshold_df.pellet_timestamp.shift(-1) + pellet_ts_threshold_df.beam_break_timestamp = pellet_ts_threshold_df.beam_break_timestamp.shift(-1) pellet_ts_threshold_df = pellet_ts_threshold_df.dropna( subset=["pellet_timestamp", "beam_break_timestamp"] ) @@ -1947,12 +1714,8 @@ def get_foraging_bouts( Returns: DataFrame containing foraging bouts. Columns: duration, n_pellets, cum_wheel_dist, subject. """ - max_inactive_time = ( - pd.Timedelta(seconds=60) if max_inactive_time is None else max_inactive_time - ) - bout_data = pd.DataFrame( - columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"] - ) + max_inactive_time = pd.Timedelta(seconds=60) if max_inactive_time is None else max_inactive_time + bout_data = pd.DataFrame(columns=["start", "end", "n_pellets", "cum_wheel_dist", "subject"]) subject_patch_data = (BlockSubjectAnalysis.Patch() & key).fetch(format="frame") if subject_patch_data.empty: return bout_data @@ -1996,52 +1759,34 @@ def get_foraging_bouts( wheel_s_r = pd.Timedelta(wheel_ts[1] - wheel_ts[0], unit="ns") max_inactive_win_len = int(max_inactive_time / wheel_s_r) # Find times when foraging - max_windowed_wheel_vals = ( - patch_spun_df["cum_wheel_dist"].shift(-(max_inactive_win_len - 1)).ffill() - ) - foraging_mask = max_windowed_wheel_vals > ( - patch_spun_df["cum_wheel_dist"] + min_wheel_movement - ) + max_windowed_wheel_vals = patch_spun_df["cum_wheel_dist"].shift(-(max_inactive_win_len - 1)).ffill() + foraging_mask = max_windowed_wheel_vals > (patch_spun_df["cum_wheel_dist"] + min_wheel_movement) # Discretize into foraging bouts - bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + ( - max_inactive_win_len - 1 - ) + bout_start_indxs = np.where(np.diff(foraging_mask, prepend=0) == 1)[0] + (max_inactive_win_len - 1) n_samples_in_1s = int(1 / wheel_s_r.total_seconds()) bout_end_indxs = ( np.where(np.diff(foraging_mask, prepend=0) == -1)[0] + (max_inactive_win_len - 1) + n_samples_in_1s ) - bout_end_indxs[-1] = min( - bout_end_indxs[-1], len(wheel_ts) - 1 - ) # ensure last bout ends in block + bout_end_indxs[-1] = min(bout_end_indxs[-1], len(wheel_ts) - 1) # ensure last bout ends in block # Remove bout that starts at block end if bout_start_indxs[-1] >= len(wheel_ts): bout_start_indxs = bout_start_indxs[:-1] bout_end_indxs = bout_end_indxs[:-1] if len(bout_start_indxs) != len(bout_end_indxs): - raise ValueError( - "Mismatch between the lengths of bout_start_indxs and bout_end_indxs." - ) - bout_durations = ( - wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs] - ).astype( # in seconds + raise ValueError("Mismatch between the lengths of bout_start_indxs and bout_end_indxs.") + bout_durations = (wheel_ts[bout_end_indxs] - wheel_ts[bout_start_indxs]).astype( # in seconds "timedelta64[ns]" - ).astype( - float - ) / 1e9 + ).astype(float) / 1e9 bout_starts_ends = np.array( [ (wheel_ts[start_idx], wheel_ts[end_idx]) - for start_idx, end_idx in zip( - bout_start_indxs, bout_end_indxs, strict=True - ) + for start_idx, end_idx in zip(bout_start_indxs, bout_end_indxs, strict=True) ] ) all_pel_ts = np.sort( - np.concatenate( - [arr for arr in cur_subject_data["pellet_timestamps"] if len(arr) > 0] - ) + np.concatenate([arr for arr in cur_subject_data["pellet_timestamps"] if len(arr) > 0]) ) bout_pellets = np.array( [ @@ -2055,8 +1800,7 @@ def get_foraging_bouts( bout_pellets = bout_pellets[bout_pellets >= min_pellets] bout_cum_wheel_dist = np.array( [ - patch_spun_df.loc[end, "cum_wheel_dist"] - - patch_spun_df.loc[start, "cum_wheel_dist"] + patch_spun_df.loc[end, "cum_wheel_dist"] - patch_spun_df.loc[start, "cum_wheel_dist"] for start, end in bout_starts_ends ] ) diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index a92182d5..6942c5f4 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -76,15 +76,15 @@ class Visit(dj.Part): @property def key_source(self): """Key source for OverlapVisit.""" - return dj.U("experiment_name", "place", "overlap_start") & ( - Visit & VisitEnd - ).proj(overlap_start="visit_start") + return dj.U("experiment_name", "place", "overlap_start") & (Visit & VisitEnd).proj( + overlap_start="visit_start" + ) def make(self, key): """Populate OverlapVisit table with overlapping visits.""" - visit_starts, visit_ends = ( - Visit * VisitEnd & key & {"visit_start": key["overlap_start"]} - ).fetch("visit_start", "visit_end") + visit_starts, visit_ends = (Visit * VisitEnd & key & {"visit_start": key["overlap_start"]}).fetch( + "visit_start", "visit_end" + ) visit_start = min(visit_starts) visit_end = max(visit_ends) @@ -98,9 +98,7 @@ def make(self, key): if len(overlap_query) <= 1: break overlap_visits.extend( - overlap_query.proj(overlap_start=f'"{key["overlap_start"]}"').fetch( - as_dict=True - ) + overlap_query.proj(overlap_start=f'"{key["overlap_start"]}"').fetch(as_dict=True) ) visit_starts, visit_ends = overlap_query.fetch("visit_start", "visit_end") if visit_start == max(visit_starts) and visit_end == max(visit_ends): @@ -114,10 +112,7 @@ def make(self, key): { **key, "overlap_end": visit_end, - "overlap_duration": ( - visit_end - key["overlap_start"] - ).total_seconds() - / 3600, + "overlap_duration": (visit_end - key["overlap_start"]).total_seconds() / 3600, "subject_count": len({v["subject"] for v in overlap_visits}), } ) @@ -209,22 +204,16 @@ def ingest_environment_visits(experiment_names: list | None = None): def get_maintenance_periods(experiment_name, start, end): """Get maintenance periods for the specified experiment and time range.""" # get states from acquisition.Environment.EnvironmentState - chunk_restriction = acquisition.create_chunk_restriction( - experiment_name, start, end - ) + chunk_restriction = acquisition.create_chunk_restriction(experiment_name, start, end) state_query = ( - acquisition.Environment.EnvironmentState - & {"experiment_name": experiment_name} - & chunk_restriction + acquisition.Environment.EnvironmentState & {"experiment_name": experiment_name} & chunk_restriction ) env_state_df = fetch_stream(state_query)[start:end] if env_state_df.empty: return deque([]) env_state_df.reset_index(inplace=True) - env_state_df = env_state_df[ - env_state_df["state"].shift() != env_state_df["state"] - ].reset_index( + env_state_df = env_state_df[env_state_df["state"].shift() != env_state_df["state"]].reset_index( drop=True ) # remove duplicates and keep the first one # An experiment starts with visit start (anything before the first maintenance is experiment) @@ -240,12 +229,8 @@ def get_maintenance_periods(experiment_name, start, end): env_state_df = pd.concat([env_state_df, log_df_end]) env_state_df.reset_index(drop=True, inplace=True) - maintenance_starts = env_state_df.loc[ - env_state_df["state"] == "Maintenance", "time" - ].values - maintenance_ends = env_state_df.loc[ - env_state_df["state"] != "Maintenance", "time" - ].values + maintenance_starts = env_state_df.loc[env_state_df["state"] == "Maintenance", "time"].values + maintenance_ends = env_state_df.loc[env_state_df["state"] != "Maintenance", "time"].values return deque( [ @@ -262,9 +247,7 @@ def filter_out_maintenance_periods(data_df, maintenance_period, end_time, dropna (maintenance_start, maintenance_end) = maint_period[0] if end_time < maintenance_start: # no more maintenance for this date break - maintenance_filter = (data_df.index >= maintenance_start) & ( - data_df.index <= maintenance_end - ) + maintenance_filter = (data_df.index >= maintenance_start) & (data_df.index <= maintenance_end) data_df[maintenance_filter] = np.nan if end_time >= maintenance_end: # remove this range maint_period.popleft() diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 025167db..9149af01 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -93,8 +93,7 @@ def key_source(self): + chunk starts after visit_start and ends before visit_end (or NOW() - i.e. ongoing visits). """ return ( - Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") - * acquisition.Chunk + Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") * acquisition.Chunk & acquisition.SubjectEnterExit & [ "visit_start BETWEEN chunk_start AND chunk_end", @@ -107,9 +106,7 @@ def key_source(self): def make(self, key): """Populate VisitSubjectPosition for each visit.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") # -- Determine the time to start time_slicing in this chunk start_time = ( @@ -177,12 +174,8 @@ def make(self, key): end_time = np.array(end_time, dtype="datetime64[ns]") while time_slice_start < end_time: - time_slice_end = time_slice_start + min( - self._time_slice_duration, end_time - time_slice_start - ) - in_time_slice = np.logical_and( - timestamps >= time_slice_start, timestamps < time_slice_end - ) + time_slice_end = time_slice_start + min(self._time_slice_duration, end_time - time_slice_start) + in_time_slice = np.logical_and(timestamps >= time_slice_start, timestamps < time_slice_end) chunk_time_slices.append( { **key, @@ -209,14 +202,9 @@ def get_position(cls, visit_key=None, subject=None, start=None, end=None): """ if visit_key is not None: if len(Visit & visit_key) != 1: - raise ValueError( - "The `visit_key` must correspond to exactly one Visit." - ) + raise ValueError("The `visit_key` must correspond to exactly one Visit.") start, end = ( - Visit.join(VisitEnd, left=True).proj( - visit_end="IFNULL(visit_end, NOW())" - ) - & visit_key + Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") & visit_key ).fetch1("visit_start", "visit_end") subject = visit_key["subject"] elif all((subject, start, end)): @@ -277,9 +265,7 @@ class FoodPatch(dj.Part): """ # Work on finished visits - key_source = Visit & ( - VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end" - ) + key_source = Visit & (VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end") def make(self, key): """Populate VisitTimeDistribution for each visit.""" @@ -287,9 +273,7 @@ def make(self, key): visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) ) - maintenance_period = get_maintenance_periods( - key["experiment_name"], visit_start, visit_end - ) + maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) for visit_date in visit_dates: day_start = datetime.datetime.combine(visit_date.date(), time.min) @@ -309,16 +293,12 @@ def make(self, key): subject=key["subject"], start=day_start, end=day_end ) # filter out maintenance period based on logs - position = filter_out_maintenance_periods( - position, maintenance_period, day_end - ) + position = filter_out_maintenance_periods(position, maintenance_period, day_end) # filter for objects of the correct size valid_position = (position.area > MIN_AREA) & (position.area < MAX_AREA) position[~valid_position] = np.nan - position.rename( - columns={"position_x": "x", "position_y": "y"}, inplace=True - ) + position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) # in corridor distance_from_center = tracking.compute_distance( position[["x", "y"]], @@ -362,9 +342,9 @@ def make(self, key): in_food_patch_times = [] for food_patch_key in food_patch_keys: # wheel data - food_patch_description = ( - acquisition.ExperimentFoodPatch & food_patch_key - ).fetch1("food_patch_description") + food_patch_description = (acquisition.ExperimentFoodPatch & food_patch_key).fetch1( + "food_patch_description" + ) wheel_data = acquisition.FoodPatchWheel.get_wheel_data( experiment_name=key["experiment_name"], start=pd.Timestamp(day_start), @@ -373,12 +353,10 @@ def make(self, key): using_aeon_io=True, ) # filter out maintenance period based on logs - wheel_data = filter_out_maintenance_periods( - wheel_data, maintenance_period, day_end + wheel_data = filter_out_maintenance_periods(wheel_data, maintenance_period, day_end) + patch_position = (acquisition.ExperimentFoodPatch.Position & food_patch_key).fetch1( + "food_patch_position_x", "food_patch_position_y" ) - patch_position = ( - acquisition.ExperimentFoodPatch.Position & food_patch_key - ).fetch1("food_patch_position_x", "food_patch_position_y") in_patch = tracking.is_position_in_patch( position, patch_position, @@ -433,9 +411,7 @@ class FoodPatch(dj.Part): """ # Work on finished visits - key_source = Visit & ( - VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end" - ) + key_source = Visit & (VisitEnd * VisitSubjectPosition.TimeSlice & "time_slice_end = visit_end") def make(self, key): """Populate VisitSummary for each visit.""" @@ -443,9 +419,7 @@ def make(self, key): visit_dates = pd.date_range( start=pd.Timestamp(visit_start.date()), end=pd.Timestamp(visit_end.date()) ) - maintenance_period = get_maintenance_periods( - key["experiment_name"], visit_start, visit_end - ) + maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) for visit_date in visit_dates: day_start = datetime.datetime.combine(visit_date.date(), time.min) @@ -466,18 +440,12 @@ def make(self, key): subject=key["subject"], start=day_start, end=day_end ) # filter out maintenance period based on logs - position = filter_out_maintenance_periods( - position, maintenance_period, day_end - ) + position = filter_out_maintenance_periods(position, maintenance_period, day_end) # filter for objects of the correct size valid_position = (position.area > MIN_AREA) & (position.area < MAX_AREA) position[~valid_position] = np.nan - position.rename( - columns={"position_x": "x", "position_y": "y"}, inplace=True - ) - position_diff = np.sqrt( - np.square(np.diff(position.x)) + np.square(np.diff(position.y)) - ) + position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) + position_diff = np.sqrt(np.square(np.diff(position.x)) + np.square(np.diff(position.y))) total_distance_travelled = np.nansum(position_diff) # in food patches - loop through all in-use patches during this visit @@ -513,9 +481,9 @@ def make(self, key): dropna=True, ).index.values # wheel data - food_patch_description = ( - acquisition.ExperimentFoodPatch & food_patch_key - ).fetch1("food_patch_description") + food_patch_description = (acquisition.ExperimentFoodPatch & food_patch_key).fetch1( + "food_patch_description" + ) wheel_data = acquisition.FoodPatchWheel.get_wheel_data( experiment_name=key["experiment_name"], start=pd.Timestamp(day_start), @@ -524,9 +492,7 @@ def make(self, key): using_aeon_io=True, ) # filter out maintenance period based on logs - wheel_data = filter_out_maintenance_periods( - wheel_data, maintenance_period, day_end - ) + wheel_data = filter_out_maintenance_periods(wheel_data, maintenance_period, day_end) food_patch_statistics.append( { @@ -534,15 +500,11 @@ def make(self, key): **food_patch_key, "visit_date": visit_date.date(), "pellet_count": len(pellet_events), - "wheel_distance_travelled": wheel_data.distance_travelled.values[ - -1 - ], + "wheel_distance_travelled": wheel_data.distance_travelled.values[-1], } ) - total_pellet_count = np.sum( - [p["pellet_count"] for p in food_patch_statistics] - ) + total_pellet_count = np.sum([p["pellet_count"] for p in food_patch_statistics]) total_wheel_distance_travelled = np.sum( [p["wheel_distance_travelled"] for p in food_patch_statistics] ) @@ -578,10 +540,7 @@ class VisitForagingBout(dj.Computed): # Work on 24/7 experiments key_source = ( - Visit - & VisitSummary - & (VisitEnd & "visit_duration > 24") - & "experiment_name= 'exp0.2-r0'" + Visit & VisitSummary & (VisitEnd & "visit_duration > 24") & "experiment_name= 'exp0.2-r0'" ) * acquisition.ExperimentFoodPatch def make(self, key): @@ -589,17 +548,13 @@ def make(self, key): visit_start, visit_end = (VisitEnd & key).fetch1("visit_start", "visit_end") # get in_patch timestamps - food_patch_description = (acquisition.ExperimentFoodPatch & key).fetch1( - "food_patch_description" - ) + food_patch_description = (acquisition.ExperimentFoodPatch & key).fetch1("food_patch_description") in_patch_times = np.concatenate( - ( - VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & key - ).fetch("in_patch", order_by="visit_date") - ) - maintenance_period = get_maintenance_periods( - key["experiment_name"], visit_start, visit_end + (VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & key).fetch( + "in_patch", order_by="visit_date" + ) ) + maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) in_patch_times = filter_out_maintenance_periods( pd.DataFrame( [[food_patch_description]] * len(in_patch_times), @@ -627,12 +582,8 @@ def make(self, key): .set_index("event_time") ) # TODO: handle multiple retries of pellet delivery - maintenance_period = get_maintenance_periods( - key["experiment_name"], visit_start, visit_end - ) - patch = filter_out_maintenance_periods( - patch, maintenance_period, visit_end, True - ) + maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) + patch = filter_out_maintenance_periods(patch, maintenance_period, visit_end, True) if len(in_patch_times): change_ind = ( @@ -648,9 +599,7 @@ def make(self, key): ts_array = in_patch_times[change_ind[i - 1] : change_ind[i]] wheel_start, wheel_end = ts_array[0], ts_array[-1] - if ( - wheel_start >= wheel_end - ): # skip if timestamps were misaligned or a single timestamp + if wheel_start >= wheel_end: # skip if timestamps were misaligned or a single timestamp continue wheel_data = acquisition.FoodPatchWheel.get_wheel_data( @@ -660,19 +609,14 @@ def make(self, key): patch_name=food_patch_description, using_aeon_io=True, ) - maintenance_period = get_maintenance_periods( - key["experiment_name"], visit_start, visit_end - ) - wheel_data = filter_out_maintenance_periods( - wheel_data, maintenance_period, visit_end, True - ) + maintenance_period = get_maintenance_periods(key["experiment_name"], visit_start, visit_end) + wheel_data = filter_out_maintenance_periods(wheel_data, maintenance_period, visit_end, True) self.insert1( { **key, "bout_start": ts_array[0], "bout_end": ts_array[-1], - "bout_duration": (ts_array[-1] - ts_array[0]) - / np.timedelta64(1, "s"), + "bout_duration": (ts_array[-1] - ts_array[0]) / np.timedelta64(1, "s"), "wheel_distance_travelled": wheel_data.distance_travelled[-1], "pellet_count": len(patch.loc[wheel_start:wheel_end]), } diff --git a/aeon/dj_pipeline/create_experiments/create_experiment_01.py b/aeon/dj_pipeline/create_experiments/create_experiment_01.py index cb66455d..18edb4c3 100644 --- a/aeon/dj_pipeline/create_experiments/create_experiment_01.py +++ b/aeon/dj_pipeline/create_experiments/create_experiment_01.py @@ -255,9 +255,7 @@ def add_arena_setup(): # manually update coordinates of foodpatch and nest patch_coordinates = {"Patch1": (1.13, 1.59, 0), "Patch2": (1.19, 0.50, 0)} - for patch_key in (acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name}).fetch( - "KEY" - ): + for patch_key in (acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name}).fetch("KEY"): patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1("food_patch_description") x, y, z = patch_coordinates[patch] acquisition.ExperimentFoodPatch.Position.update1( diff --git a/aeon/dj_pipeline/create_experiments/create_experiment_02.py b/aeon/dj_pipeline/create_experiments/create_experiment_02.py index f14b0342..10877546 100644 --- a/aeon/dj_pipeline/create_experiments/create_experiment_02.py +++ b/aeon/dj_pipeline/create_experiments/create_experiment_02.py @@ -33,10 +33,7 @@ def create_new_experiment(): skip_duplicates=True, ) acquisition.Experiment.Subject.insert( - [ - {"experiment_name": experiment_name, "subject": s["subject"]} - for s in subject_list - ], + [{"experiment_name": experiment_name, "subject": s["subject"]} for s in subject_list], skip_duplicates=True, ) diff --git a/aeon/dj_pipeline/create_experiments/create_octagon_1.py b/aeon/dj_pipeline/create_experiments/create_octagon_1.py index 98b13b41..edbfdd64 100644 --- a/aeon/dj_pipeline/create_experiments/create_octagon_1.py +++ b/aeon/dj_pipeline/create_experiments/create_octagon_1.py @@ -36,10 +36,7 @@ def create_new_experiment(): skip_duplicates=True, ) acquisition.Experiment.Subject.insert( - [ - {"experiment_name": experiment_name, "subject": s["subject"]} - for s in subject_list - ], + [{"experiment_name": experiment_name, "subject": s["subject"]} for s in subject_list], skip_duplicates=True, ) diff --git a/aeon/dj_pipeline/create_experiments/create_presocial.py b/aeon/dj_pipeline/create_experiments/create_presocial.py index 0a60e59d..7a1ce9a3 100644 --- a/aeon/dj_pipeline/create_experiments/create_presocial.py +++ b/aeon/dj_pipeline/create_experiments/create_presocial.py @@ -12,9 +12,7 @@ def create_new_experiment(): """Create new experiments for presocial0.1.""" lab.Location.insert1({"lab": "SWC", "location": location}, skip_duplicates=True) - acquisition.ExperimentType.insert1( - {"experiment_type": experiment_type}, skip_duplicates=True - ) + acquisition.ExperimentType.insert1({"experiment_type": experiment_type}, skip_duplicates=True) acquisition.Experiment.insert( [ @@ -49,9 +47,7 @@ def create_new_experiment(): "directory_type": "raw", "directory_path": f"aeon/data/raw/{computer}/{experiment_type}", } - for experiment_name, computer in zip( - experiment_names, computers, strict=False - ) + for experiment_name, computer in zip(experiment_names, computers, strict=False) ], skip_duplicates=True, ) diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py index 4b90b018..757166e2 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py @@ -39,9 +39,7 @@ def create_new_social_experiment(experiment_name): "experiment_name": experiment_name, "repository_name": "ceph_aeon", "directory_type": dir_type, - "directory_path": ( - ceph_data_dir / dir_type / machine_name.upper() / exp_name - ) + "directory_path": (ceph_data_dir / dir_type / machine_name.upper() / exp_name) .relative_to(ceph_dir) .as_posix(), "load_order": load_order, @@ -54,9 +52,7 @@ def create_new_social_experiment(experiment_name): new_experiment_entry, skip_duplicates=True, ) - acquisition.Experiment.Directory.insert( - experiment_directories, skip_duplicates=True - ) + acquisition.Experiment.Directory.insert(experiment_directories, skip_duplicates=True) acquisition.Experiment.DevicesSchema.insert1( { "experiment_name": experiment_name, diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py index cc69ced4..3b13a1f3 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py @@ -38,10 +38,7 @@ def create_new_experiment(): skip_duplicates=True, ) acquisition.Experiment.Subject.insert( - [ - {"experiment_name": experiment_name, "subject": s["subject"]} - for s in subject_list - ], + [{"experiment_name": experiment_name, "subject": s["subject"]} for s in subject_list], skip_duplicates=True, ) @@ -97,12 +94,8 @@ def add_arena_setup(): # manually update coordinates of foodpatch and nest patch_coordinates = {"Patch1": (1.13, 1.59, 0), "Patch2": (1.19, 0.50, 0)} - for patch_key in ( - acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name} - ).fetch("KEY"): - patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1( - "food_patch_description" - ) + for patch_key in (acquisition.ExperimentFoodPatch & {"experiment_name": experiment_name}).fetch("KEY"): + patch = (acquisition.ExperimentFoodPatch & patch_key).fetch1("food_patch_description") x, y, z = patch_coordinates[patch] acquisition.ExperimentFoodPatch.Position.update1( { diff --git a/aeon/dj_pipeline/lab.py b/aeon/dj_pipeline/lab.py index b5a4c3c5..141c40cd 100644 --- a/aeon/dj_pipeline/lab.py +++ b/aeon/dj_pipeline/lab.py @@ -85,9 +85,7 @@ class ArenaShape(dj.Lookup): definition = """ arena_shape: varchar(32) """ - contents = zip( - ["square", "circular", "rectangular", "linear", "octagon"], strict=False - ) + contents = zip(["square", "circular", "rectangular", "linear", "octagon"], strict=False) @schema diff --git a/aeon/dj_pipeline/populate/process.py b/aeon/dj_pipeline/populate/process.py index ae02eef2..d3699b25 100644 --- a/aeon/dj_pipeline/populate/process.py +++ b/aeon/dj_pipeline/populate/process.py @@ -81,9 +81,7 @@ def run(**kwargs): try: worker.run() except Exception: - logger.exception( - "action '{}' encountered an exception:".format(kwargs["worker_name"]) - ) + logger.exception("action '{}' encountered an exception:".format(kwargs["worker_name"])) logger.info("Ingestion process ended.") diff --git a/aeon/dj_pipeline/populate/worker.py b/aeon/dj_pipeline/populate/worker.py index 835b6610..92ba5f82 100644 --- a/aeon/dj_pipeline/populate/worker.py +++ b/aeon/dj_pipeline/populate/worker.py @@ -113,6 +113,4 @@ def get_workflow_operation_overview(): """Get the workflow operation overview for the worker schema.""" from datajoint_utilities.dj_worker.utils import get_workflow_operation_overview - return get_workflow_operation_overview( - worker_schema_name=worker_schema_name, db_prefixes=[db_prefix] - ) + return get_workflow_operation_overview(worker_schema_name=worker_schema_name, db_prefixes=[db_prefix]) diff --git a/aeon/dj_pipeline/qc.py b/aeon/dj_pipeline/qc.py index 50f493fe..cf357c0f 100644 --- a/aeon/dj_pipeline/qc.py +++ b/aeon/dj_pipeline/qc.py @@ -60,9 +60,7 @@ def key_source(self): return ( acquisition.Chunk * ( - streams.SpinnakerVideoSource.join( - streams.SpinnakerVideoSource.RemovalTime, left=True - ) + streams.SpinnakerVideoSource.join(streams.SpinnakerVideoSource.RemovalTime, left=True) & "spinnaker_video_source_name='CameraTop'" ) & "chunk_start >= spinnaker_video_source_install_time" @@ -71,21 +69,16 @@ def key_source(self): def make(self, key): """Perform quality control checks on the CameraTop stream.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - device_name = (streams.SpinnakerVideoSource & key).fetch1( - "spinnaker_video_source_name" - ) + device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name") data_dirs = acquisition.Experiment.get_data_directories(key) devices_schema = getattr( acquisition.aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(devices_schema, device_name).Video @@ -112,11 +105,9 @@ def make(self, key): **key, "drop_count": deltas.frame_offset.iloc[-1], "max_harp_delta": deltas.time_delta.max().total_seconds(), - "max_camera_delta": deltas.hw_timestamp_delta.max() - / 1e9, # convert to seconds + "max_camera_delta": deltas.hw_timestamp_delta.max() / 1e9, # convert to seconds "timestamps": videodata.index.values, - "time_delta": deltas.time_delta.values - / np.timedelta64(1, "s"), # convert to seconds + "time_delta": deltas.time_delta.values / np.timedelta64(1, "s"), # convert to seconds "frame_delta": deltas.frame_delta.values, "hw_counter_delta": deltas.hw_counter_delta.values, "hw_timestamp_delta": deltas.hw_timestamp_delta.values, diff --git a/aeon/dj_pipeline/report.py b/aeon/dj_pipeline/report.py index dc377176..6b1bce02 100644 --- a/aeon/dj_pipeline/report.py +++ b/aeon/dj_pipeline/report.py @@ -31,9 +31,7 @@ class InArenaSummaryPlot(dj.Computed): summary_plot_png: attach """ - key_source = ( - analysis.InArena & analysis.InArenaTimeDistribution & analysis.InArenaSummary - ) + key_source = analysis.InArena & analysis.InArenaTimeDistribution & analysis.InArenaSummary color_code = { "Patch1": "b", @@ -45,17 +43,15 @@ class InArenaSummaryPlot(dj.Computed): def make(self, key): """Make method for InArenaSummaryPlot table.""" - in_arena_start, in_arena_end = ( - analysis.InArena * analysis.InArenaEnd & key - ).fetch1("in_arena_start", "in_arena_end") + in_arena_start, in_arena_end = (analysis.InArena * analysis.InArenaEnd & key).fetch1( + "in_arena_start", "in_arena_end" + ) # subject's position data in the time_slices position = analysis.InArenaSubjectPosition.get_position(key) position.rename(columns={"position_x": "x", "position_y": "y"}, inplace=True) - position_minutes_elapsed = ( - position.index - in_arena_start - ).total_seconds() / 60 + position_minutes_elapsed = (position.index - in_arena_start).total_seconds() / 60 # figure fig = plt.figure(figsize=(20, 9)) @@ -70,16 +66,12 @@ def make(self, key): # position plot non_nan = np.logical_and(~np.isnan(position.x), ~np.isnan(position.y)) - analysis_plotting.heatmap( - position[non_nan], 50, ax=position_ax, bins=500, alpha=0.5 - ) + analysis_plotting.heatmap(position[non_nan], 50, ax=position_ax, bins=500, alpha=0.5) # event rate plots in_arena_food_patches = ( analysis.InArena - * acquisition.ExperimentFoodPatch.join( - acquisition.ExperimentFoodPatch.RemovalTime, left=True - ) + * acquisition.ExperimentFoodPatch.join(acquisition.ExperimentFoodPatch.RemovalTime, left=True) & key & "in_arena_start >= food_patch_install_time" & 'in_arena_start < IFNULL(food_patch_remove_time, "2200-01-01")' @@ -146,9 +138,7 @@ def make(self, key): color=self.color_code[food_patch_key["food_patch_description"]], alpha=0.3, ) - threshold_change_ind = np.where( - wheel_threshold[:-1] != wheel_threshold[1:] - )[0] + threshold_change_ind = np.where(wheel_threshold[:-1] != wheel_threshold[1:])[0] threshold_ax.vlines( wheel_time[threshold_change_ind + 1], ymin=wheel_threshold[threshold_change_ind], @@ -160,20 +150,17 @@ def make(self, key): ) # ethogram - in_arena, in_corridor, arena_time, corridor_time = ( - analysis.InArenaTimeDistribution & key - ).fetch1( + in_arena, in_corridor, arena_time, corridor_time = (analysis.InArenaTimeDistribution & key).fetch1( "in_arena", "in_corridor", "time_fraction_in_arena", "time_fraction_in_corridor", ) - nest_keys, in_nests, nests_times = ( - analysis.InArenaTimeDistribution.Nest & key - ).fetch("KEY", "in_nest", "time_fraction_in_nest") + nest_keys, in_nests, nests_times = (analysis.InArenaTimeDistribution.Nest & key).fetch( + "KEY", "in_nest", "time_fraction_in_nest" + ) patch_names, in_patches, patches_times = ( - analysis.InArenaTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch - & key + analysis.InArenaTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & key ).fetch("food_patch_description", "in_patch", "time_fraction_in_patch") ethogram_ax.plot( @@ -204,9 +191,7 @@ def make(self, key): alpha=0.6, label="nest", ) - for patch_idx, (patch_name, in_patch) in enumerate( - zip(patch_names, in_patches, strict=False) - ): + for patch_idx, (patch_name, in_patch) in enumerate(zip(patch_names, in_patches, strict=False)): ethogram_ax.plot( position_minutes_elapsed[in_patch], np.full_like(position_minutes_elapsed[in_patch], (patch_idx + 3)), @@ -247,9 +232,7 @@ def make(self, key): rate_ax.set_title("foraging rate (bin size = 10 min)") distance_ax.set_ylabel("distance travelled (m)") threshold_ax.set_ylabel("threshold") - threshold_ax.set_ylim( - [threshold_ax.get_ylim()[0] - 100, threshold_ax.get_ylim()[1] + 100] - ) + threshold_ax.set_ylim([threshold_ax.get_ylim()[0] - 100, threshold_ax.get_ylim()[1] + 100]) ethogram_ax.set_xlabel("time (min)") analysis_plotting.set_ymargin(distance_ax, 0.2, 0.1) for ax in (rate_ax, distance_ax, pellet_ax, time_dist_ax, threshold_ax): @@ -278,9 +261,7 @@ def make(self, key): # ---- Save fig and insert ---- save_dir = _make_path(key) - fig_dict = _save_figs( - (fig,), ("summary_plot_png",), save_dir=save_dir, prefix=save_dir.name - ) + fig_dict = _save_figs((fig,), ("summary_plot_png",), save_dir=save_dir, prefix=save_dir.name) self.insert1({**key, **fig_dict}) @@ -468,10 +449,7 @@ class VisitDailySummaryPlot(dj.Computed): """ key_source = ( - Visit - & analysis.VisitSummary - & (VisitEnd & "visit_duration > 24") - & "experiment_name= 'exp0.2-r0'" + Visit & analysis.VisitSummary & (VisitEnd & "visit_duration > 24") & "experiment_name= 'exp0.2-r0'" ) def make(self, key): @@ -580,12 +558,7 @@ def _make_path(in_arena_key): experiment_name, subject, in_arena_start = (analysis.InArena & in_arena_key).fetch1( "experiment_name", "subject", "in_arena_start" ) - output_dir = ( - store_stage - / experiment_name - / subject - / in_arena_start.strftime("%y%m%d_%H%M%S_%f") - ) + output_dir = store_stage / experiment_name / subject / in_arena_start.strftime("%y%m%d_%H%M%S_%f") output_dir.mkdir(parents=True, exist_ok=True) return output_dir diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py index 5ba5e5ce..66050cd3 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py @@ -108,8 +108,7 @@ def validate(): target_entry_count = len(target_tbl()) missing_entries[orig_schema_name][source_tbl.__name__] = { "entry_count_diff": source_entry_count - target_entry_count, - "db_size_diff": source_tbl().size_on_disk - - target_tbl().size_on_disk, + "db_size_diff": source_tbl().size_on_disk - target_tbl().size_on_disk, } return { diff --git a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py index 47aa6134..31ee8109 100644 --- a/aeon/dj_pipeline/scripts/update_timestamps_longblob.py +++ b/aeon/dj_pipeline/scripts/update_timestamps_longblob.py @@ -10,9 +10,7 @@ if dj.__version__ < "0.13.7": - raise ImportError( - f"DataJoint version must be at least 0.13.7, but found {dj.__version__}." - ) + raise ImportError(f"DataJoint version must be at least 0.13.7, but found {dj.__version__}.") schema = dj.schema("u_thinh_aeonfix") @@ -40,13 +38,7 @@ def main(): for schema_name in schema_names: vm = dj.create_virtual_module(schema_name, schema_name) table_names = [ - ".".join( - [ - dj.utils.to_camel_case(s) - for s in tbl_name.strip("`").split("__") - if s - ] - ) + ".".join([dj.utils.to_camel_case(s) for s in tbl_name.strip("`").split("__") if s]) for tbl_name in vm.schema.list_tables() ] for table_name in table_names: diff --git a/aeon/dj_pipeline/streams.py b/aeon/dj_pipeline/streams.py index 8a639f5e..8fdd1f1e 100644 --- a/aeon/dj_pipeline/streams.py +++ b/aeon/dj_pipeline/streams.py @@ -2,11 +2,11 @@ # ---- THIS FILE IS AUTO-GENERATED BY `streams_maker.py` ---- import re -import datajoint as dj -import pandas as pd from uuid import UUID import aeon +import datajoint as dj +import pandas as pd from aeon.dj_pipeline import acquisition, get_schema_name from aeon.io import api as io_api from aeon.schema import schemas as aeon_schemas diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index 6bb69214..135c7077 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -86,9 +86,7 @@ def make(self, key): ) return elif len(animal_resp) > 1: - raise ValueError( - f"Found {len(animal_resp)} with eartag {eartag_or_id}, expect one" - ) + raise ValueError(f"Found {len(animal_resp)} with eartag {eartag_or_id}, expect one") else: animal_resp = animal_resp[0] @@ -187,21 +185,17 @@ def get_reference_weight(cls, subject_name): """Get the reference weight for the subject.""" subj_key = {"subject": subject_name} - food_restrict_query = ( - SubjectProcedure & subj_key & "procedure_name = 'R02 - food restriction'" - ) + food_restrict_query = SubjectProcedure & subj_key & "procedure_name = 'R02 - food restriction'" if food_restrict_query: - ref_date = food_restrict_query.fetch( - "procedure_date", order_by="procedure_date DESC", limit=1 - )[0] + ref_date = food_restrict_query.fetch("procedure_date", order_by="procedure_date DESC", limit=1)[ + 0 + ] else: ref_date = datetime.now(timezone.utc).date() weight_query = SubjectWeight & subj_key & f"weight_time < '{ref_date}'" ref_weight = ( - weight_query.fetch("weight", order_by="weight_time DESC", limit=1)[0] - if weight_query - else -1 + weight_query.fetch("weight", order_by="weight_time DESC", limit=1)[0] if weight_query else -1 ) entry = { @@ -259,9 +253,7 @@ def _auto_schedule(self): ): return - PyratIngestionTask.insert1( - {"pyrat_task_scheduled_time": next_task_schedule_time} - ) + PyratIngestionTask.insert1({"pyrat_task_scheduled_time": next_task_schedule_time}) def make(self, key): """Automatically import or update entries in the Subject table.""" @@ -269,15 +261,11 @@ def make(self, key): new_eartags = [] for responsible_id in lab.User.fetch("responsible_id"): # 1 - retrieve all animals from this user - animal_resp = get_pyrat_data( - endpoint="animals", params={"responsible_id": responsible_id} - ) + animal_resp = get_pyrat_data(endpoint="animals", params={"responsible_id": responsible_id}) for animal_entry in animal_resp: # 2 - find animal with comment - Project Aeon eartag_or_id = animal_entry["eartag_or_id"] - comment_resp = get_pyrat_data( - endpoint=f"animals/{eartag_or_id}/comments" - ) + comment_resp = get_pyrat_data(endpoint=f"animals/{eartag_or_id}/comments") for comment in comment_resp: if comment["attributes"]: first_attr = comment["attributes"][0] @@ -306,9 +294,7 @@ def make(self, key): { **key, "execution_time": execution_time, - "execution_duration": ( - completion_time - execution_time - ).total_seconds(), + "execution_duration": (completion_time - execution_time).total_seconds(), "new_pyrat_entry_count": new_entry_count, } ) @@ -354,9 +340,7 @@ def make(self, key): for cmt in comment_resp: cmt["subject"] = eartag_or_id cmt["attributes"] = json.dumps(cmt["attributes"], default=str) - SubjectComment.insert( - comment_resp, skip_duplicates=True, allow_direct_insert=True - ) + SubjectComment.insert(comment_resp, skip_duplicates=True, allow_direct_insert=True) weight_resp = get_pyrat_data(endpoint=f"animals/{eartag_or_id}/weights") SubjectWeight.insert( @@ -365,9 +349,7 @@ def make(self, key): allow_direct_insert=True, ) - procedure_resp = get_pyrat_data( - endpoint=f"animals/{eartag_or_id}/procedures" - ) + procedure_resp = get_pyrat_data(endpoint=f"animals/{eartag_or_id}/procedures") SubjectProcedure.insert( [{**v, "subject": eartag_or_id} for v in procedure_resp], skip_duplicates=True, @@ -382,9 +364,7 @@ def make(self, key): { **key, "execution_time": execution_time, - "execution_duration": ( - completion_time - execution_time - ).total_seconds(), + "execution_duration": (completion_time - execution_time).total_seconds(), } ) @@ -397,9 +377,7 @@ class CreatePyratIngestionTask(dj.Computed): def make(self, key): """Create one new PyratIngestionTask for every newly added users.""" - PyratIngestionTask.insert1( - {"pyrat_task_scheduled_time": datetime.now(timezone.utc)} - ) + PyratIngestionTask.insert1({"pyrat_task_scheduled_time": datetime.now(timezone.utc)}) time.sleep(1) self.insert1(key) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 9fca07e1..56424a7d 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -80,18 +80,14 @@ def insert_new_params( ): """Insert a new set of parameters for a given tracking method.""" if tracking_paramset_id is None: - tracking_paramset_id = ( - dj.U().aggr(cls, n="max(tracking_paramset_id)").fetch1("n") or 0 - ) + 1 + tracking_paramset_id = (dj.U().aggr(cls, n="max(tracking_paramset_id)").fetch1("n") or 0) + 1 param_dict = { "tracking_method": tracking_method, "tracking_paramset_id": tracking_paramset_id, "paramset_description": paramset_description, "params": params, - "param_set_hash": dict_to_uuid( - {**params, "tracking_method": tracking_method} - ), + "param_set_hash": dict_to_uuid({**params, "tracking_method": tracking_method}), } param_query = cls & {"param_set_hash": param_dict["param_set_hash"]} @@ -161,9 +157,7 @@ def key_source(self): return ( acquisition.Chunk * ( - streams.SpinnakerVideoSource.join( - streams.SpinnakerVideoSource.RemovalTime, left=True - ) + streams.SpinnakerVideoSource.join(streams.SpinnakerVideoSource.RemovalTime, left=True) & "spinnaker_video_source_name='CameraTop'" ) * (TrackingParamSet & "tracking_paramset_id = 1") @@ -173,22 +167,17 @@ def key_source(self): def make(self, key): """Ingest SLEAP tracking data for a given chunk.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) - device_name = (streams.SpinnakerVideoSource & key).fetch1( - "spinnaker_video_source_name" - ) + device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name") devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) stream_reader = getattr(devices_schema, device_name).Pose @@ -200,9 +189,7 @@ def make(self, key): ) if not len(pose_data): - raise ValueError( - f"No SLEAP data found for {key['experiment_name']} - {device_name}" - ) + raise ValueError(f"No SLEAP data found for {key['experiment_name']} - {device_name}") # get identity names class_names = np.unique(pose_data.identity) @@ -235,9 +222,7 @@ def make(self, key): if part == anchor_part: identity_likelihood = part_position.identity_likelihood.values if isinstance(identity_likelihood[0], dict): - identity_likelihood = np.array( - [v[identity] for v in identity_likelihood] - ) + identity_likelihood = np.array([v[identity] for v in identity_likelihood]) pose_identity_entries.append( { @@ -316,9 +301,7 @@ def _get_position( start_query = table & obj_restriction & start_restriction end_query = table & obj_restriction & end_restriction if not (start_query and end_query): - raise ValueError( - f"No position data found for {object_name} between {start} and {end}" - ) + raise ValueError(f"No position data found for {object_name} between {start} and {end}") time_restriction = ( f'{start_attr} >= "{min(start_query.fetch(start_attr))}"' @@ -326,14 +309,10 @@ def _get_position( ) # subject's position data in the time slice - fetched_data = (table & obj_restriction & time_restriction).fetch( - *fetch_attrs, order_by=start_attr - ) + fetched_data = (table & obj_restriction & time_restriction).fetch(*fetch_attrs, order_by=start_attr) if not len(fetched_data[0]): - raise ValueError( - f"No position data found for {object_name} between {start} and {end}" - ) + raise ValueError(f"No position data found for {object_name} between {start} and {end}") timestamp_attr = next(attr for attr in fetch_attrs if "timestamps" in attr) diff --git a/aeon/dj_pipeline/utils/load_metadata.py b/aeon/dj_pipeline/utils/load_metadata.py index 1d392c80..a4792f6b 100644 --- a/aeon/dj_pipeline/utils/load_metadata.py +++ b/aeon/dj_pipeline/utils/load_metadata.py @@ -40,9 +40,7 @@ def insert_stream_types(): else: # If the existed stream type does not have the same name: # human error, trying to add the same content with different name - raise dj.DataJointError( - f"The specified stream type already exists - name: {pname}" - ) + raise dj.DataJointError(f"The specified stream type already exists - name: {pname}") else: streams.StreamType.insert1(entry) @@ -57,9 +55,7 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): streams = dj.VirtualModule("streams", streams_maker.schema_name) device_info: dict[dict] = get_device_info(devices_schema) - device_type_mapper, device_sn = get_device_mapper( - devices_schema, metadata_yml_filepath - ) + device_type_mapper, device_sn = get_device_mapper(devices_schema, metadata_yml_filepath) # Add device type to device_info. Only add if device types that are defined in Metadata.yml device_info = { @@ -96,8 +92,7 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): {"device_type": device_type, "stream_type": stream_type} for device_type, stream_list in device_stream_map.items() for stream_type in stream_list - if not streams.DeviceType.Stream - & {"device_type": device_type, "stream_type": stream_type} + if not streams.DeviceType.Stream & {"device_type": device_type, "stream_type": stream_type} ] new_devices = [ @@ -106,8 +101,7 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): "device_type": device_config["device_type"], } for device_name, device_config in device_info.items() - if device_sn[device_name] - and not streams.Device & {"device_serial_number": device_sn[device_name]} + if device_sn[device_name] and not streams.Device & {"device_serial_number": device_sn[device_name]} ] # Insert new entries. @@ -125,9 +119,7 @@ def insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): streams.Device.insert(new_devices) -def extract_epoch_config( - experiment_name: str, devices_schema: DotMap, metadata_yml_filepath: str -) -> dict: +def extract_epoch_config(experiment_name: str, devices_schema: DotMap, metadata_yml_filepath: str) -> dict: """Parse experiment metadata YAML file and extract epoch configuration. Args: @@ -139,9 +131,7 @@ def extract_epoch_config( dict: epoch_config [dict] """ metadata_yml_filepath = pathlib.Path(metadata_yml_filepath) - epoch_start = datetime.datetime.strptime( - metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S" - ) + epoch_start = datetime.datetime.strptime(metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S") epoch_config: dict = ( io_api.load( metadata_yml_filepath.parent.as_posix(), @@ -156,22 +146,16 @@ def extract_epoch_config( commit = epoch_config["metadata"]["Revision"] if not commit: - raise ValueError( - f'Neither "Commit" nor "Revision" found in {metadata_yml_filepath}' - ) + raise ValueError(f'Neither "Commit" nor "Revision" found in {metadata_yml_filepath}') devices: list[dict] = json.loads( - json.dumps( - epoch_config["metadata"]["Devices"], default=lambda x: x.__dict__, indent=4 - ) + json.dumps(epoch_config["metadata"]["Devices"], default=lambda x: x.__dict__, indent=4) ) # Maintain backward compatibility - In exp02, it is a list of dict. # From presocial onward, it's a dict of dict. if isinstance(devices, list): - devices: dict = { - d.pop("Name"): d for d in devices - } # {deivce_name: device_config} + devices: dict = {d.pop("Name"): d for d in devices} # {deivce_name: device_config} return { "experiment_name": experiment_name, @@ -195,17 +179,15 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath experiment_key = {"experiment_name": experiment_name} metadata_yml_filepath = pathlib.Path(metadata_yml_filepath) - epoch_config = extract_epoch_config( - experiment_name, devices_schema, metadata_yml_filepath - ) + epoch_config = extract_epoch_config(experiment_name, devices_schema, metadata_yml_filepath) previous_epoch = (acquisition.Experiment & experiment_key).aggr( acquisition.Epoch & f'epoch_start < "{epoch_config["epoch_start"]}"', epoch_start="MAX(epoch_start)", ) - if len(acquisition.EpochConfig.Meta & previous_epoch) and epoch_config[ - "commit" - ] == (acquisition.EpochConfig.Meta & previous_epoch).fetch1("commit"): + if len(acquisition.EpochConfig.Meta & previous_epoch) and epoch_config["commit"] == ( + acquisition.EpochConfig.Meta & previous_epoch + ).fetch1("commit"): # if identical commit -> no changes return @@ -239,9 +221,7 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath table_entry = { "experiment_name": experiment_name, **device_key, - f"{dj.utils.from_camel_case(table.__name__)}_install_time": epoch_config[ - "epoch_start" - ], + f"{dj.utils.from_camel_case(table.__name__)}_install_time": epoch_config["epoch_start"], f"{dj.utils.from_camel_case(table.__name__)}_name": device_name, } @@ -258,9 +238,7 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath { **table_entry, "attribute_name": "SamplingFrequency", - "attribute_value": video_controller[ - device_config["TriggerFrequency"] - ], + "attribute_value": video_controller[device_config["TriggerFrequency"]], } ) @@ -269,14 +247,10 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath If the same device serial number is currently installed check for changes in configuration. If not, skip this. """ - current_device_query = ( - table - table.RemovalTime & experiment_key & device_key - ) + current_device_query = table - table.RemovalTime & experiment_key & device_key if current_device_query: - current_device_config: list[dict] = ( - table.Attribute & current_device_query - ).fetch( + current_device_config: list[dict] = (table.Attribute & current_device_query).fetch( "experiment_name", "device_serial_number", "attribute_name", @@ -284,11 +258,7 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath as_dict=True, ) new_device_config: list[dict] = [ - { - k: v - for k, v in entry.items() - if dj.utils.from_camel_case(table.__name__) not in k - } + {k: v for k, v in entry.items() if dj.utils.from_camel_case(table.__name__) not in k} for entry in table_attribute_entry ] @@ -298,10 +268,7 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath for config in current_device_config } ) == dict_to_uuid( - { - config["attribute_name"]: config["attribute_value"] - for config in new_device_config - } + {config["attribute_name"]: config["attribute_value"] for config in new_device_config} ): # Skip if none of the configuration has changed. continue @@ -419,14 +386,10 @@ def _get_class_path(obj): "aeon.schema.social", ]: device_info[device_name]["stream_type"].append(stream_type) - device_info[device_name]["stream_reader"].append( - _get_class_path(stream_obj) - ) + device_info[device_name]["stream_reader"].append(_get_class_path(stream_obj)) required_args = [ - k - for k in inspect.signature(stream_obj.__init__).parameters - if k != "self" + k for k in inspect.signature(stream_obj.__init__).parameters if k != "self" ] pattern = schema_dict[device_name][stream_type].get("pattern") schema_dict[device_name][stream_type]["pattern"] = pattern.replace( @@ -434,35 +397,23 @@ def _get_class_path(obj): ) kwargs = { - k: v - for k, v in schema_dict[device_name][stream_type].items() - if k in required_args + k: v for k, v in schema_dict[device_name][stream_type].items() if k in required_args } device_info[device_name]["stream_reader_kwargs"].append(kwargs) # Add hash device_info[device_name]["stream_hash"].append( - dict_to_uuid( - {**kwargs, "stream_reader": _get_class_path(stream_obj)} - ) + dict_to_uuid({**kwargs, "stream_reader": _get_class_path(stream_obj)}) ) else: stream_type = device.__class__.__name__ device_info[device_name]["stream_type"].append(stream_type) device_info[device_name]["stream_reader"].append(_get_class_path(device)) - required_args = { - k: None - for k in inspect.signature(device.__init__).parameters - if k != "self" - } + required_args = {k: None for k in inspect.signature(device.__init__).parameters if k != "self"} pattern = schema_dict[device_name].get("pattern") - schema_dict[device_name]["pattern"] = pattern.replace( - device_name, "{pattern}" - ) + schema_dict[device_name]["pattern"] = pattern.replace(device_name, "{pattern}") - kwargs = { - k: v for k, v in schema_dict[device_name].items() if k in required_args - } + kwargs = {k: v for k, v in schema_dict[device_name].items() if k in required_args} device_info[device_name]["stream_reader_kwargs"].append(kwargs) # Add hash device_info[device_name]["stream_hash"].append( @@ -558,9 +509,7 @@ def ingest_epoch_metadata_octagon(experiment_name, metadata_yml_filepath): ("Wall8", "Wall"), ] - epoch_start = datetime.datetime.strptime( - metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S" - ) + epoch_start = datetime.datetime.strptime(metadata_yml_filepath.parent.name, "%Y-%m-%dT%H-%M-%S") for device_idx, (device_name, device_type) in enumerate(oct01_devices): device_sn = f"oct01_{device_idx}" @@ -569,13 +518,8 @@ def ingest_epoch_metadata_octagon(experiment_name, metadata_yml_filepath): skip_duplicates=True, ) experiment_table = getattr(streams, f"Experiment{device_type}") - if not ( - experiment_table - & {"experiment_name": experiment_name, "device_serial_number": device_sn} - ): - experiment_table.insert1( - (experiment_name, device_sn, epoch_start, device_name) - ) + if not (experiment_table & {"experiment_name": experiment_name, "device_serial_number": device_sn}): + experiment_table.insert1((experiment_name, device_sn, epoch_start, device_name)) # endregion diff --git a/aeon/dj_pipeline/utils/paths.py b/aeon/dj_pipeline/utils/paths.py index 1677d5f3..ebba44b5 100644 --- a/aeon/dj_pipeline/utils/paths.py +++ b/aeon/dj_pipeline/utils/paths.py @@ -65,6 +65,5 @@ def find_root_directory( except StopIteration as err: raise FileNotFoundError( - f"No valid root directory found (from {root_directories})" - f" for {full_path}" + f"No valid root directory found (from {root_directories})" f" for {full_path}" ) from err diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index c7273135..d3641826 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -37,17 +37,13 @@ def plot_reward_rate_differences(subject_keys): """ # noqa E501 subj_names, sess_starts, rate_timestamps, rate_diffs = ( analysis.InArenaRewardRate & subject_keys - ).fetch( - "subject", "in_arena_start", "pellet_rate_timestamps", "patch2_patch1_rate_diff" - ) + ).fetch("subject", "in_arena_start", "pellet_rate_timestamps", "patch2_patch1_rate_diff") nSessions = len(sess_starts) longest_rateDiff = np.max([len(t) for t in rate_timestamps]) max_session_idx = np.argmax([len(t) for t in rate_timestamps]) - max_session_elapsed_times = ( - rate_timestamps[max_session_idx] - rate_timestamps[max_session_idx][0] - ) + max_session_elapsed_times = rate_timestamps[max_session_idx] - rate_timestamps[max_session_idx][0] x_labels = [t.total_seconds() / 60 for t in max_session_elapsed_times] y_labels = [ @@ -92,15 +88,12 @@ def plot_wheel_travelled_distance(session_keys): ``` """ distance_travelled_query = ( - analysis.InArenaSummary.FoodPatch - * acquisition.ExperimentFoodPatch.proj("food_patch_description") + analysis.InArenaSummary.FoodPatch * acquisition.ExperimentFoodPatch.proj("food_patch_description") & session_keys ) distance_travelled_df = ( - distance_travelled_query.proj( - "food_patch_description", "wheel_distance_travelled" - ) + distance_travelled_query.proj("food_patch_description", "wheel_distance_travelled") .fetch(format="frame") .reset_index() ) @@ -164,8 +157,7 @@ def plot_average_time_distribution(session_keys): & session_keys ) .aggr( - analysis.InArenaTimeDistribution.FoodPatch - * acquisition.ExperimentFoodPatch, + analysis.InArenaTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch, avg_in_patch="AVG(time_fraction_in_patch)", ) .fetch("subject", "food_patch_description", "avg_in_patch") @@ -248,15 +240,11 @@ def plot_visit_daily_summary( .reset_index() ) else: - visit_per_day_df = ( - (VisitSummary & visit_key).fetch(format="frame").reset_index() - ) + visit_per_day_df = (VisitSummary & visit_key).fetch(format="frame").reset_index() if not attr.startswith("total"): attr = "total_" + attr - visit_per_day_df["day"] = ( - visit_per_day_df["visit_date"] - visit_per_day_df["visit_date"].min() - ) + visit_per_day_df["day"] = visit_per_day_df["visit_date"] - visit_per_day_df["visit_date"].min() visit_per_day_df["day"] = visit_per_day_df["day"].dt.days fig = px.bar( @@ -350,14 +338,10 @@ def plot_foraging_bouts_count( else [foraging_bouts["bout_start"].dt.floor("D")] ) - foraging_bouts_count = ( - foraging_bouts.groupby(group_by_attrs).size().reset_index(name="count") - ) + foraging_bouts_count = foraging_bouts.groupby(group_by_attrs).size().reset_index(name="count") visit_start = (VisitEnd & visit_key).fetch1("visit_start") - foraging_bouts_count["day"] = ( - foraging_bouts_count["bout_start"].dt.date - visit_start.date() - ).dt.days + foraging_bouts_count["day"] = (foraging_bouts_count["bout_start"].dt.date - visit_start.date()).dt.days fig = px.bar( foraging_bouts_count, @@ -371,10 +355,7 @@ def plot_foraging_bouts_count( width=700, height=400, template="simple_white", - title=visit_key["subject"] - + "
Foraging bouts: count (freq='" - + freq - + "')", + title=visit_key["subject"] + "
Foraging bouts: count (freq='" + freq + "')", ) fig.update_layout( @@ -448,9 +429,7 @@ def plot_foraging_bouts_distribution( fig = go.Figure() if per_food_patch: - patch_names = (acquisition.ExperimentFoodPatch & visit_key).fetch( - "food_patch_description" - ) + patch_names = (acquisition.ExperimentFoodPatch & visit_key).fetch("food_patch_description") for patch in patch_names: bouts = foraging_bouts[foraging_bouts["food_patch_description"] == patch] fig.add_trace( @@ -477,9 +456,7 @@ def plot_foraging_bouts_distribution( ) fig.update_layout( - title_text=visit_key["subject"] - + "
Foraging bouts: " - + attr.replace("_", " "), + title_text=visit_key["subject"] + "
Foraging bouts: " + attr.replace("_", " "), xaxis_title="date", yaxis_title=attr.replace("_", " "), violingap=0, @@ -518,17 +495,11 @@ def plot_visit_time_distribution(visit_key, freq="D"): region = _get_region_data(visit_key) # Compute time spent per region - time_spent = ( - region.groupby([region.index.floor(freq), "region"]) - .size() - .reset_index(name="count") + time_spent = region.groupby([region.index.floor(freq), "region"]).size().reset_index(name="count") + time_spent["time_fraction"] = time_spent["count"] / time_spent.groupby("timestamps")["count"].transform( + "sum" ) - time_spent["time_fraction"] = time_spent["count"] / time_spent.groupby( - "timestamps" - )["count"].transform("sum") - time_spent["day"] = ( - time_spent["timestamps"] - time_spent["timestamps"].min() - ).dt.days + time_spent["day"] = (time_spent["timestamps"] - time_spent["timestamps"].min()).dt.days fig = px.bar( time_spent, @@ -540,10 +511,7 @@ def plot_visit_time_distribution(visit_key, freq="D"): "time_fraction": "time fraction", "timestamps": "date" if freq == "D" else "time", }, - title=visit_key["subject"] - + "
Fraction of time spent in each region (freq='" - + freq - + "')", + title=visit_key["subject"] + "
Fraction of time spent in each region (freq='" + freq + "')", width=700, height=400, template="simple_white", @@ -587,9 +555,7 @@ def _get_region_data(visit_key, attrs=None): for attr in attrs: if attr == "in_nest": # Nest in_nest = np.concatenate( - (VisitTimeDistribution.Nest & visit_key).fetch( - attr, order_by="visit_date" - ) + (VisitTimeDistribution.Nest & visit_key).fetch(attr, order_by="visit_date") ) region = pd.concat( [ @@ -604,16 +570,14 @@ def _get_region_data(visit_key, attrs=None): elif attr == "in_patch": # Food patch # Find all patches patches = np.unique( - ( - VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch - & visit_key - ).fetch("food_patch_description") + (VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & visit_key).fetch( + "food_patch_description" + ) ) for patch in patches: in_patch = np.concatenate( ( - VisitTimeDistribution.FoodPatch - * acquisition.ExperimentFoodPatch + VisitTimeDistribution.FoodPatch * acquisition.ExperimentFoodPatch & visit_key & f"food_patch_description = '{patch}'" ).fetch("in_patch", order_by="visit_date") @@ -645,19 +609,13 @@ def _get_region_data(visit_key, attrs=None): region = region.sort_index().rename_axis("timestamps") # Exclude data during maintenance - maintenance_period = get_maintenance_periods( - visit_key["experiment_name"], visit_start, visit_end - ) - region = filter_out_maintenance_periods( - region, maintenance_period, visit_end, dropna=True - ) + maintenance_period = get_maintenance_periods(visit_key["experiment_name"], visit_start, visit_end) + region = filter_out_maintenance_periods(region, maintenance_period, visit_end, dropna=True) return region -def plot_weight_patch_data( - visit_key, freq="H", smooth_weight=True, min_weight=0, max_weight=35 -): +def plot_weight_patch_data(visit_key, freq="H", smooth_weight=True, min_weight=0, max_weight=35): """Plot subject weight and patch data (pellet trigger count) per visit. Args: @@ -674,9 +632,7 @@ def plot_weight_patch_data( >>> fig = plot_weight_patch_data(visit_key, freq="H", smooth_weight=True) >>> fig = plot_weight_patch_data(visit_key, freq="D") """ - subject_weight = _get_filtered_subject_weight( - visit_key, smooth_weight, min_weight, max_weight - ) + subject_weight = _get_filtered_subject_weight(visit_key, smooth_weight, min_weight, max_weight) # Count pellet trigger per patch per day/hour/... patch = _get_patch_data(visit_key) @@ -704,12 +660,8 @@ def plot_weight_patch_data( for p in patch_names: fig.add_trace( go.Bar( - x=patch_summary[patch_summary["food_patch_description"] == p][ - "event_time" - ], - y=patch_summary[patch_summary["food_patch_description"] == p][ - "event_type" - ], + x=patch_summary[patch_summary["food_patch_description"] == p]["event_time"], + y=patch_summary[patch_summary["food_patch_description"] == p]["event_type"], name=p, ), secondary_y=False, @@ -734,10 +686,7 @@ def plot_weight_patch_data( fig.update_layout( barmode="stack", hovermode="x", - title_text=visit_key["subject"] - + "
Weight and pellet count (freq='" - + freq - + "')", + title_text=visit_key["subject"] + "
Weight and pellet count (freq='" + freq + "')", xaxis_title="date" if freq == "D" else "time", yaxis={"title": "pellet count"}, yaxis2={"title": "weight"}, @@ -758,9 +707,7 @@ def plot_weight_patch_data( return fig -def _get_filtered_subject_weight( - visit_key, smooth_weight=True, min_weight=0, max_weight=35 -): +def _get_filtered_subject_weight(visit_key, smooth_weight=True, min_weight=0, max_weight=35): """Retrieve subject weight from WeightMeasurementFiltered table. Args: @@ -799,9 +746,7 @@ def _get_filtered_subject_weight( subject_weight = subject_weight.loc[visit_start:visit_end] # Exclude data during maintenance - maintenance_period = get_maintenance_periods( - visit_key["experiment_name"], visit_start, visit_end - ) + maintenance_period = get_maintenance_periods(visit_key["experiment_name"], visit_start, visit_end) subject_weight = filter_out_maintenance_periods( subject_weight, maintenance_period, visit_end, dropna=True ) @@ -818,9 +763,7 @@ def _get_filtered_subject_weight( subject_weight = subject_weight.resample("1T").mean().dropna() if smooth_weight: - subject_weight["weight_subject"] = savgol_filter( - subject_weight["weight_subject"], 10, 3 - ) + subject_weight["weight_subject"] = savgol_filter(subject_weight["weight_subject"], 10, 3) return subject_weight @@ -841,9 +784,7 @@ def _get_patch_data(visit_key): ( dj.U("event_time", "event_type", "food_patch_description") & ( - acquisition.FoodPatchEvent - * acquisition.EventType - * acquisition.ExperimentFoodPatch + acquisition.FoodPatchEvent * acquisition.EventType * acquisition.ExperimentFoodPatch & f'event_time BETWEEN "{visit_start}" AND "{visit_end}"' & 'event_type = "TriggerPellet"' ) @@ -856,11 +797,7 @@ def _get_patch_data(visit_key): # TODO: handle repeat attempts (pellet delivery trigger and beam break) # Exclude data during maintenance - maintenance_period = get_maintenance_periods( - visit_key["experiment_name"], visit_start, visit_end - ) - patch = filter_out_maintenance_periods( - patch, maintenance_period, visit_end, dropna=True - ) + maintenance_period = get_maintenance_periods(visit_key["experiment_name"], visit_start, visit_end) + patch = filter_out_maintenance_periods(patch, maintenance_period, visit_end, dropna=True) return patch diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index a0fe9aab..31ff161e 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -110,10 +110,7 @@ def get_device_stream_template(device_type: str, stream_type: str, streams_modul # DeviceDataStream table(s) stream_detail = ( streams_module.StreamType - & ( - streams_module.DeviceType.Stream - & {"device_type": device_type, "stream_type": stream_type} - ) + & (streams_module.DeviceType.Stream & {"device_type": device_type, "stream_type": stream_type}) ).fetch1() reader = aeon @@ -121,9 +118,7 @@ def get_device_stream_template(device_type: str, stream_type: str, streams_modul reader = getattr(reader, n) if reader is aeon.io.reader.Pose: - logger.warning( - "Automatic generation of stream table for Pose reader is not supported. Skipping..." - ) + logger.warning("Automatic generation of stream table for Pose reader is not supported. Skipping...") return None, None stream = reader(**stream_detail["stream_reader_kwargs"]) @@ -154,33 +149,25 @@ def key_source(self): """ # noqa B021 device_type_name = dj.utils.from_camel_case(device_type) return ( - acquisition.Chunk - * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) + acquisition.Chunk * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) & f"chunk_start >= {device_type_name}_install_time" & f'chunk_start < IFNULL({device_type_name}_removal_time,"2200-01-01")' ) def make(self, key): """Load and insert the data for the DeviceDataStream table.""" - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) - device_name = (ExperimentDevice & key).fetch1( - f"{dj.utils.from_camel_case(device_type)}_name" - ) + device_name = (ExperimentDevice & key).fetch1(f"{dj.utils.from_camel_case(device_type)}_name") devices_schema = getattr( aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), - ) - stream_reader = getattr( - getattr(devices_schema, device_name), "{stream_type}" + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), ) + stream_reader = getattr(getattr(devices_schema, device_name), "{stream_type}") stream_data = io_api.load( root=data_dirs, diff --git a/aeon/dj_pipeline/webapps/sciviz/specsheet.yaml b/aeon/dj_pipeline/webapps/sciviz/specsheet.yaml index 3515bde8..d150b054 100644 --- a/aeon/dj_pipeline/webapps/sciviz/specsheet.yaml +++ b/aeon/dj_pipeline/webapps/sciviz/specsheet.yaml @@ -806,7 +806,7 @@ SciViz: def dj_query(aeon_block_analysis): aeon_analysis = aeon_block_analysis return {'query': aeon_block_analysis.BlockSubjectPositionPlots(), 'fetch_args': ['position_plot']} - + VideoStream: route: /videostream diff --git a/aeon/io/reader.py b/aeon/io/reader.py index cda78869..0623178a 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -70,12 +70,8 @@ def read(self, file): payloadtype = _payloadtypes[data[4] & ~0x10] elementsize = payloadtype.itemsize payloadshape = (length, payloadsize // elementsize) - seconds = np.ndarray( - length, dtype=np.uint32, buffer=data, offset=5, strides=stride - ) - ticks = np.ndarray( - length, dtype=np.uint16, buffer=data, offset=9, strides=stride - ) + seconds = np.ndarray(length, dtype=np.uint32, buffer=data, offset=5, strides=stride) + ticks = np.ndarray(length, dtype=np.uint16, buffer=data, offset=9, strides=stride) seconds = ticks * _SECONDS_PER_TICK + seconds payload = np.ndarray( payloadshape, @@ -86,9 +82,7 @@ def read(self, file): ) if self.columns is not None and payloadshape[1] < len(self.columns): - data = pd.DataFrame( - payload, index=seconds, columns=self.columns[: payloadshape[1]] - ) + data = pd.DataFrame(payload, index=seconds, columns=self.columns[: payloadshape[1]]) data[self.columns[payloadshape[1] :]] = math.nan return data else: @@ -117,17 +111,13 @@ class Metadata(Reader): def __init__(self, pattern="Metadata"): """Initialize the object with the specified pattern.""" - super().__init__( - pattern, columns=["workflow", "commit", "metadata"], extension="yml" - ) + super().__init__(pattern, columns=["workflow", "commit", "metadata"], extension="yml") def read(self, file): """Returns metadata for the specified epoch.""" epoch_str = file.parts[-2] date_str, time_str = epoch_str.split("T") - time = datetime.datetime.fromisoformat( - date_str + "T" + time_str.replace("-", ":") - ) + time = datetime.datetime.fromisoformat(date_str + "T" + time_str.replace("-", ":")) with open(file) as fp: metadata = json.load(fp) workflow = metadata.pop("Workflow") @@ -267,9 +257,7 @@ class Position(Harp): def __init__(self, pattern): """Initialize the object with a specified pattern and columns.""" - super().__init__( - pattern, columns=["x", "y", "angle", "major", "minor", "area", "id"] - ) + super().__init__(pattern, columns=["x", "y", "angle", "major", "minor", "area", "id"]) class BitmaskEvent(Harp): @@ -328,9 +316,7 @@ class Video(Csv): def __init__(self, pattern): """Initialize the object with a specified pattern.""" - super().__init__( - pattern, columns=["hw_counter", "hw_timestamp", "_frame", "_path", "_epoch"] - ) + super().__init__(pattern, columns=["hw_counter", "hw_timestamp", "_frame", "_path", "_epoch"]) self._rawcolumns = ["time"] + self.columns[0:2] def read(self, file): @@ -355,9 +341,7 @@ class (int): Int ID of a subject in the environment. y (float): Y-coordinate of the bodypart. """ - def __init__( - self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/processed" - ): + def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/processed"): """Pose reader constructor.""" # `pattern` for this reader should typically be '_*' super().__init__(pattern, columns=None) @@ -396,16 +380,10 @@ def read(self, file: Path) -> pd.DataFrame: # Drop any repeat parts. unique_parts, unique_idxs = np.unique(parts, return_index=True) repeat_idxs = np.setdiff1d(np.arange(len(parts)), unique_idxs) - if ( - repeat_idxs - ): # drop x, y, and likelihood cols for repeat parts (skip first 5 cols) + if repeat_idxs: # drop x, y, and likelihood cols for repeat parts (skip first 5 cols) init_rep_part_col_idx = (repeat_idxs - 1) * 3 + 5 - rep_part_col_idxs = np.concatenate( - [np.arange(i, i + 3) for i in init_rep_part_col_idx] - ) - keep_part_col_idxs = np.setdiff1d( - np.arange(len(data.columns)), rep_part_col_idxs - ) + rep_part_col_idxs = np.concatenate([np.arange(i, i + 3) for i in init_rep_part_col_idx]) + keep_part_col_idxs = np.setdiff1d(np.arange(len(data.columns)), rep_part_col_idxs) data = data.iloc[:, keep_part_col_idxs] parts = unique_parts @@ -413,25 +391,18 @@ def read(self, file: Path) -> pd.DataFrame: data = self.class_int2str(data, config_file) n_parts = len(parts) part_data_list = [pd.DataFrame()] * n_parts - new_columns = pd.Series( - ["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"] - ) + new_columns = pd.Series(["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]) new_data = pd.DataFrame(columns=new_columns) for i, part in enumerate(parts): part_columns = ( - columns[0 : (len(identities) + 1)] - if bonsai_sleap_v == BONSAI_SLEAP_V3 - else columns[0:2] + columns[0 : (len(identities) + 1)] if bonsai_sleap_v == BONSAI_SLEAP_V3 else columns[0:2] ) part_columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"]) part_data = pd.DataFrame(data[part_columns]) if bonsai_sleap_v == BONSAI_SLEAP_V3: # combine all identity_likelihood cols into a single col as dict part_data["identity_likelihood"] = part_data.apply( - lambda row: { - identity: row[f"{identity}_likelihood"] - for identity in identities - }, + lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, axis=1, ) part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True) @@ -496,14 +467,10 @@ def class_int2str(data: pd.DataFrame, config_file: Path) -> pd.DataFrame: return data @classmethod - def get_config_file( - cls, config_file_dir: Path, config_file_names: None | list[str] = None - ) -> Path: + def get_config_file(cls, config_file_dir: Path, config_file_names: None | list[str] = None) -> Path: """Returns the config file from a model's config directory.""" if config_file_names is None: - config_file_names = [ - "confmap_config.json" - ] # SLEAP (add for other trackers to this list) + config_file_names = ["confmap_config.json"] # SLEAP (add for other trackers to this list) config_file = None for f in config_file_names: if (config_file_dir / f).exists(): @@ -522,21 +489,14 @@ def from_dict(data, pattern=None): return globals()[reader_type](pattern=pattern, **kwargs) return DotMap( - { - k: from_dict(v, f"{pattern}_{k}" if pattern is not None else k) - for k, v in data.items() - } + {k: from_dict(v, f"{pattern}_{k}" if pattern is not None else k) for k, v in data.items()} ) def to_dict(dotmap): """Converts a DotMap object to a dictionary.""" if isinstance(dotmap, Reader): - kwargs = { - k: v - for k, v in vars(dotmap).items() - if k not in ["pattern"] and not k.startswith("_") - } + kwargs = {k: v for k, v in vars(dotmap).items() if k not in ["pattern"] and not k.startswith("_")} kwargs["type"] = type(dotmap).__name__ return kwargs return {k: to_dict(v) for k, v in dotmap.items()} diff --git a/aeon/schema/social_03.py b/aeon/schema/social_03.py index e1f624f5..0f07e72c 100644 --- a/aeon/schema/social_03.py +++ b/aeon/schema/social_03.py @@ -11,7 +11,6 @@ def __init__(self, path): class EnvironmentActiveConfiguration(Stream): - def __init__(self, path): """Initializes the EnvironmentActiveConfiguration stream.""" super().__init__(_reader.JsonList(f"{path}_ActiveConfiguration_*", columns=["name"])) From fc49b511a4cd1c4445e65c0aafe5c61901499f83 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 31 Oct 2024 16:06:28 +0000 Subject: [PATCH 078/143] fix: revert changes in `streams.py` --- aeon/dj_pipeline/streams.py | 1053 +++++++++++++++++------------------ 1 file changed, 506 insertions(+), 547 deletions(-) diff --git a/aeon/dj_pipeline/streams.py b/aeon/dj_pipeline/streams.py index 8fdd1f1e..eed24dcd 100644 --- a/aeon/dj_pipeline/streams.py +++ b/aeon/dj_pipeline/streams.py @@ -1,5 +1,5 @@ -# ---- DO NOT MODIFY ---- -# ---- THIS FILE IS AUTO-GENERATED BY `streams_maker.py` ---- +#---- DO NOT MODIFY ---- +#---- THIS FILE IS AUTO-GENERATED BY `streams_maker.py` ---- import re from uuid import UUID @@ -57,7 +57,7 @@ class Device(dj.Lookup): @schema class RfidReader(dj.Manual): - definition = f""" + definition = f""" # rfid_reader placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-unknown) -> acquisition.Experiment -> Device @@ -66,16 +66,16 @@ class RfidReader(dj.Manual): rfid_reader_name : varchar(36) """ - class Attribute(dj.Part): - definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device + class Attribute(dj.Part): + definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device -> master attribute_name : varchar(32) --- attribute_value=null : longblob """ - class RemovalTime(dj.Part): - definition = f""" + class RemovalTime(dj.Part): + definition = f""" -> master --- rfid_reader_removal_time: datetime(6) # time of the rfid_reader being removed @@ -84,7 +84,7 @@ class RemovalTime(dj.Part): @schema class SpinnakerVideoSource(dj.Manual): - definition = f""" + definition = f""" # spinnaker_video_source placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-unknown) -> acquisition.Experiment -> Device @@ -93,16 +93,16 @@ class SpinnakerVideoSource(dj.Manual): spinnaker_video_source_name : varchar(36) """ - class Attribute(dj.Part): - definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device + class Attribute(dj.Part): + definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device -> master attribute_name : varchar(32) --- attribute_value=null : longblob """ - class RemovalTime(dj.Part): - definition = f""" + class RemovalTime(dj.Part): + definition = f""" -> master --- spinnaker_video_source_removal_time: datetime(6) # time of the spinnaker_video_source being removed @@ -111,7 +111,7 @@ class RemovalTime(dj.Part): @schema class UndergroundFeeder(dj.Manual): - definition = f""" + definition = f""" # underground_feeder placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-unknown) -> acquisition.Experiment -> Device @@ -120,16 +120,16 @@ class UndergroundFeeder(dj.Manual): underground_feeder_name : varchar(36) """ - class Attribute(dj.Part): - definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device + class Attribute(dj.Part): + definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device -> master attribute_name : varchar(32) --- attribute_value=null : longblob """ - class RemovalTime(dj.Part): - definition = f""" + class RemovalTime(dj.Part): + definition = f""" -> master --- underground_feeder_removal_time: datetime(6) # time of the underground_feeder being removed @@ -138,7 +138,7 @@ class RemovalTime(dj.Part): @schema class WeightScale(dj.Manual): - definition = f""" + definition = f""" # weight_scale placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-unknown) -> acquisition.Experiment -> Device @@ -147,16 +147,16 @@ class WeightScale(dj.Manual): weight_scale_name : varchar(36) """ - class Attribute(dj.Part): - definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device + class Attribute(dj.Part): + definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device -> master attribute_name : varchar(32) --- attribute_value=null : longblob """ - class RemovalTime(dj.Part): - definition = f""" + class RemovalTime(dj.Part): + definition = f""" -> master --- weight_scale_removal_time: datetime(6) # time of the weight_scale being removed @@ -165,7 +165,7 @@ class RemovalTime(dj.Part): @schema class RfidReaderRfidEvents(dj.Imported): - definition = """ # Raw per-chunk RfidEvents data stream from RfidReader (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk RfidEvents data stream from RfidReader (auto-generated with aeon_mecha-unknown) -> RfidReader -> acquisition.Chunk --- @@ -174,62 +174,59 @@ class RfidReaderRfidEvents(dj.Imported): rfid: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and RfidReader with overlapping time + Chunk(s) that started after RfidReader install time and ended before RfidReader remove time + Chunk(s) that started after RfidReader install time for RfidReader that are not yet removed """ - return ( - acquisition.Chunk * RfidReader.join(RfidReader.RemovalTime, left=True) - & "chunk_start >= rfid_reader_install_time" - & 'chunk_start < IFNULL(rfid_reader_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (RfidReader & key).fetch1("rfid_reader_name") - - devices_schema = getattr( - aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "RfidEvents") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") + return ( + acquisition.Chunk * RfidReader.join(RfidReader.RemovalTime, left=True) + & 'chunk_start >= rfid_reader_install_time' + & 'chunk_start < IFNULL(rfid_reader_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (RfidReader & key).fetch1('rfid_reader_name') + + devices_schema = getattr( + aeon_schemas, + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "RfidEvents") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") + }, }, - }, - ignore_extra_fields=True, - ) + ignore_extra_fields=True, + ) @schema class SpinnakerVideoSourceVideo(dj.Imported): - definition = """ # Raw per-chunk Video data stream from SpinnakerVideoSource (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk Video data stream from SpinnakerVideoSource (auto-generated with aeon_mecha-unknown) -> SpinnakerVideoSource -> acquisition.Chunk --- @@ -239,63 +236,59 @@ class SpinnakerVideoSourceVideo(dj.Imported): hw_timestamp: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and SpinnakerVideoSource with overlapping time + Chunk(s) that started after SpinnakerVideoSource install time and ended before SpinnakerVideoSource remove time + Chunk(s) that started after SpinnakerVideoSource install time for SpinnakerVideoSource that are not yet removed """ - return ( - acquisition.Chunk - * SpinnakerVideoSource.join(SpinnakerVideoSource.RemovalTime, left=True) - & "chunk_start >= spinnaker_video_source_install_time" - & 'chunk_start < IFNULL(spinnaker_video_source_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name") - - devices_schema = getattr( - aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "Video") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") + return ( + acquisition.Chunk * SpinnakerVideoSource.join(SpinnakerVideoSource.RemovalTime, left=True) + & 'chunk_start >= spinnaker_video_source_install_time' + & 'chunk_start < IFNULL(spinnaker_video_source_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (SpinnakerVideoSource & key).fetch1('spinnaker_video_source_name') + + devices_schema = getattr( + aeon_schemas, + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "Video") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") + }, }, - }, - ignore_extra_fields=True, - ) + ignore_extra_fields=True, + ) @schema class UndergroundFeederBeamBreak(dj.Imported): - definition = """ # Raw per-chunk BeamBreak data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk BeamBreak data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -304,63 +297,59 @@ class UndergroundFeederBeamBreak(dj.Imported): event: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & "chunk_start >= underground_feeder_install_time" - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") - - devices_schema = getattr( - aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "BeamBreak") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") + return ( + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & 'chunk_start >= underground_feeder_install_time' + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') + + devices_schema = getattr( + aeon_schemas, + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "BeamBreak") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") + }, }, - }, - ignore_extra_fields=True, - ) + ignore_extra_fields=True, + ) @schema class UndergroundFeederDeliverPellet(dj.Imported): - definition = """ # Raw per-chunk DeliverPellet data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk DeliverPellet data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -369,63 +358,59 @@ class UndergroundFeederDeliverPellet(dj.Imported): event: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & "chunk_start >= underground_feeder_install_time" - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") - - devices_schema = getattr( - aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "DeliverPellet") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") + return ( + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & 'chunk_start >= underground_feeder_install_time' + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') + + devices_schema = getattr( + aeon_schemas, + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "DeliverPellet") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") + }, }, - }, - ignore_extra_fields=True, - ) + ignore_extra_fields=True, + ) @schema class UndergroundFeederDepletionState(dj.Imported): - definition = """ # Raw per-chunk DepletionState data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk DepletionState data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -436,63 +421,59 @@ class UndergroundFeederDepletionState(dj.Imported): rate: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & "chunk_start >= underground_feeder_install_time" - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") - - devices_schema = getattr( - aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "DepletionState") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") + return ( + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & 'chunk_start >= underground_feeder_install_time' + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') + + devices_schema = getattr( + aeon_schemas, + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "DepletionState") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") + }, }, - }, - ignore_extra_fields=True, - ) + ignore_extra_fields=True, + ) @schema class UndergroundFeederEncoder(dj.Imported): - definition = """ # Raw per-chunk Encoder data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk Encoder data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -502,63 +483,59 @@ class UndergroundFeederEncoder(dj.Imported): intensity: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & "chunk_start >= underground_feeder_install_time" - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") - - devices_schema = getattr( - aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "Encoder") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") + return ( + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & 'chunk_start >= underground_feeder_install_time' + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') + + devices_schema = getattr( + aeon_schemas, + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "Encoder") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") + }, }, - }, - ignore_extra_fields=True, - ) + ignore_extra_fields=True, + ) @schema class UndergroundFeederManualDelivery(dj.Imported): - definition = """ # Raw per-chunk ManualDelivery data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk ManualDelivery data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -567,63 +544,59 @@ class UndergroundFeederManualDelivery(dj.Imported): manual_delivery: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & "chunk_start >= underground_feeder_install_time" - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") - - devices_schema = getattr( - aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "ManualDelivery") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") + return ( + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & 'chunk_start >= underground_feeder_install_time' + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') + + devices_schema = getattr( + aeon_schemas, + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "ManualDelivery") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") + }, }, - }, - ignore_extra_fields=True, - ) + ignore_extra_fields=True, + ) @schema class UndergroundFeederMissedPellet(dj.Imported): - definition = """ # Raw per-chunk MissedPellet data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk MissedPellet data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -632,63 +605,59 @@ class UndergroundFeederMissedPellet(dj.Imported): missed_pellet: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & "chunk_start >= underground_feeder_install_time" - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") - - devices_schema = getattr( - aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "MissedPellet") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") + return ( + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & 'chunk_start >= underground_feeder_install_time' + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') + + devices_schema = getattr( + aeon_schemas, + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "MissedPellet") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") + }, }, - }, - ignore_extra_fields=True, - ) + ignore_extra_fields=True, + ) @schema class UndergroundFeederRetriedDelivery(dj.Imported): - definition = """ # Raw per-chunk RetriedDelivery data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk RetriedDelivery data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder -> acquisition.Chunk --- @@ -697,63 +666,59 @@ class UndergroundFeederRetriedDelivery(dj.Imported): retried_delivery: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and UndergroundFeeder with overlapping time + Chunk(s) that started after UndergroundFeeder install time and ended before UndergroundFeeder remove time + Chunk(s) that started after UndergroundFeeder install time for UndergroundFeeder that are not yet removed """ - return ( - acquisition.Chunk - * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) - & "chunk_start >= underground_feeder_install_time" - & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (UndergroundFeeder & key).fetch1("underground_feeder_name") - - devices_schema = getattr( - aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "RetriedDelivery") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") + return ( + acquisition.Chunk * UndergroundFeeder.join(UndergroundFeeder.RemovalTime, left=True) + & 'chunk_start >= underground_feeder_install_time' + & 'chunk_start < IFNULL(underground_feeder_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') + + devices_schema = getattr( + aeon_schemas, + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "RetriedDelivery") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") + }, }, - }, - ignore_extra_fields=True, - ) + ignore_extra_fields=True, + ) @schema class WeightScaleWeightFiltered(dj.Imported): - definition = """ # Raw per-chunk WeightFiltered data stream from WeightScale (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk WeightFiltered data stream from WeightScale (auto-generated with aeon_mecha-unknown) -> WeightScale -> acquisition.Chunk --- @@ -763,62 +728,59 @@ class WeightScaleWeightFiltered(dj.Imported): stability: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and WeightScale with overlapping time + Chunk(s) that started after WeightScale install time and ended before WeightScale remove time + Chunk(s) that started after WeightScale install time for WeightScale that are not yet removed """ - return ( - acquisition.Chunk * WeightScale.join(WeightScale.RemovalTime, left=True) - & "chunk_start >= weight_scale_install_time" - & 'chunk_start < IFNULL(weight_scale_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (WeightScale & key).fetch1("weight_scale_name") - - devices_schema = getattr( - aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "WeightFiltered") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") + return ( + acquisition.Chunk * WeightScale.join(WeightScale.RemovalTime, left=True) + & 'chunk_start >= weight_scale_install_time' + & 'chunk_start < IFNULL(weight_scale_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (WeightScale & key).fetch1('weight_scale_name') + + devices_schema = getattr( + aeon_schemas, + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "WeightFiltered") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") + }, }, - }, - ignore_extra_fields=True, - ) + ignore_extra_fields=True, + ) @schema class WeightScaleWeightRaw(dj.Imported): - definition = """ # Raw per-chunk WeightRaw data stream from WeightScale (auto-generated with aeon_mecha-unknown) + definition = """ # Raw per-chunk WeightRaw data stream from WeightScale (auto-generated with aeon_mecha-unknown) -> WeightScale -> acquisition.Chunk --- @@ -828,54 +790,51 @@ class WeightScaleWeightRaw(dj.Imported): stability: longblob """ - @property - def key_source(self): - f""" + @property + def key_source(self): + f""" Only the combination of Chunk and WeightScale with overlapping time + Chunk(s) that started after WeightScale install time and ended before WeightScale remove time + Chunk(s) that started after WeightScale install time for WeightScale that are not yet removed """ - return ( - acquisition.Chunk * WeightScale.join(WeightScale.RemovalTime, left=True) - & "chunk_start >= weight_scale_install_time" - & 'chunk_start < IFNULL(weight_scale_removal_time, "2200-01-01")' - ) - - def make(self, key): - chunk_start, chunk_end = (acquisition.Chunk & key).fetch1( - "chunk_start", "chunk_end" - ) - - data_dirs = acquisition.Experiment.get_data_directories(key) - - device_name = (WeightScale & key).fetch1("weight_scale_name") - - devices_schema = getattr( - aeon_schemas, - ( - acquisition.Experiment.DevicesSchema - & {"experiment_name": key["experiment_name"]} - ).fetch1("devices_schema_name"), - ) - stream_reader = getattr(getattr(devices_schema, device_name), "WeightRaw") - - stream_data = io_api.load( - root=data_dirs, - reader=stream_reader, - start=pd.Timestamp(chunk_start), - end=pd.Timestamp(chunk_end), - ) - - self.insert1( - { - **key, - "sample_count": len(stream_data), - "timestamps": stream_data.index.values, - **{ - re.sub(r"\([^)]*\)", "", c): stream_data[c].values - for c in stream_reader.columns - if not c.startswith("_") + return ( + acquisition.Chunk * WeightScale.join(WeightScale.RemovalTime, left=True) + & 'chunk_start >= weight_scale_install_time' + & 'chunk_start < IFNULL(weight_scale_removal_time, "2200-01-01")' + ) + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (WeightScale & key).fetch1('weight_scale_name') + + devices_schema = getattr( + aeon_schemas, + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + stream_reader = getattr(getattr(devices_schema, device_name), "WeightRaw") + + stream_data = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + self.insert1( + { + **key, + "sample_count": len(stream_data), + "timestamps": stream_data.index.values, + **{ + re.sub(r"\([^)]*\)", "", c): stream_data[c].values + for c in stream_reader.columns + if not c.startswith("_") + }, }, - }, - ignore_extra_fields=True, - ) + ignore_extra_fields=True, + ) From 48493cbc0de8b6e2a5e20bcebc0f30f3f7e6719b Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 31 Oct 2024 17:50:35 +0000 Subject: [PATCH 079/143] fix: revert stream.py change in dependencies order --- aeon/dj_pipeline/streams.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/streams.py b/aeon/dj_pipeline/streams.py index eed24dcd..e3d6ba12 100644 --- a/aeon/dj_pipeline/streams.py +++ b/aeon/dj_pipeline/streams.py @@ -2,11 +2,11 @@ #---- THIS FILE IS AUTO-GENERATED BY `streams_maker.py` ---- import re +import datajoint as dj +import pandas as pd from uuid import UUID import aeon -import datajoint as dj -import pandas as pd from aeon.dj_pipeline import acquisition, get_schema_name from aeon.io import api as io_api from aeon.schema import schemas as aeon_schemas From 69dee9be202f76bc9587f8d06a4ce71c3b5f4605 Mon Sep 17 00:00:00 2001 From: lochhh Date: Fri, 1 Nov 2024 18:07:24 +0000 Subject: [PATCH 080/143] Fix docstring indent in `movies.py` --- aeon/analysis/movies.py | 28 ++++++++++++---------------- 1 file changed, 12 insertions(+), 16 deletions(-) diff --git a/aeon/analysis/movies.py b/aeon/analysis/movies.py index f8fbd7c4..692ba01a 100644 --- a/aeon/analysis/movies.py +++ b/aeon/analysis/movies.py @@ -15,9 +15,8 @@ def gridframes(frames, width, height, shape: None | int | tuple[int, int] = None :param list frames: A list of frames to include in the grid layout. :param int width: The width of the output grid image, in pixels. :param int height: The height of the output grid image, in pixels. - :param optional shape: - Either the number of frames to include, or the number of rows and columns - in the output grid image layout. + :param optional shape: Either the number of frames to include, + or the number of rows and columns in the output grid image layout. :return: A new image containing the arrangement of the frames in a grid. """ if shape is None: @@ -69,13 +68,12 @@ def groupframes(frames, n, fun): def triggerclip(data, events, before=None, after=None): """Split video data around the specified sequence of event timestamps. - :param DataFrame data: - A pandas DataFrame where each row specifies video acquisition path and frame number. + :param DataFrame data: A pandas DataFrame where each row specifies + video acquisition path and frame number. :param iterable events: A sequence of timestamps to extract. :param Timedelta before: The left offset from each timestamp used to clip the data. :param Timedelta after: The right offset from each timestamp used to clip the data. - :return: - A pandas DataFrame containing the frames, clip and sequence numbers for each event timestamp. + :return: A pandas DataFrame containing the frames, clip and sequence numbers for each event timestamp. """ if before is None: before = pd.Timedelta(0) @@ -102,9 +100,8 @@ def triggerclip(data, events, before=None, after=None): def collatemovie(clipdata, fun): """Collates a set of video clips into a single movie using the specified aggregation function. - :param DataFrame clipdata: - A pandas DataFrame where each row specifies video path, frame number, clip and sequence number. - This DataFrame can be obtained from the output of the triggerclip function. + :param DataFrame clipdata: A pandas DataFrame where each row specifies video path, frame number, + clip and sequence number. This DataFrame can be obtained from the output of the triggerclip function. :param callable fun: The aggregation function used to process the frames in each clip. :return: The sequence of processed frames representing the collated movie. """ @@ -116,14 +113,13 @@ def collatemovie(clipdata, fun): def gridmovie(clipdata, width, height, shape=None): """Collates a set of video clips into a grid movie with the specified pixel dimensions and grid layout. - :param DataFrame clipdata: - A pandas DataFrame where each row specifies video path, frame number, clip and sequence number. - This DataFrame can be obtained from the output of the triggerclip function. + :param DataFrame clipdata: A pandas DataFrame where each row specifies video path, frame number, + clip and sequence number. + This DataFrame can be obtained from the output of the triggerclip function. :param int width: The width of the output grid movie, in pixels. :param int height: The height of the output grid movie, in pixels. - :param optional shape: - Either the number of frames to include, or the number of rows and columns - in the output grid movie layout. + :param optional shape: Either the number of frames to include, + or the number of rows and columns in the output grid movie layout. :return: The sequence of processed frames representing the collated grid movie. """ return collatemovie(clipdata, lambda g: gridframes(g, width, height, shape)) From 6793ffd1d090f168f053be0e85a3323d4a9d89e5 Mon Sep 17 00:00:00 2001 From: lochhh Date: Fri, 1 Nov 2024 18:13:00 +0000 Subject: [PATCH 081/143] Fix docstring indent `plotting.py` --- aeon/analysis/plotting.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aeon/analysis/plotting.py b/aeon/analysis/plotting.py index d37c65c1..2b2bd361 100644 --- a/aeon/analysis/plotting.py +++ b/aeon/analysis/plotting.py @@ -73,7 +73,7 @@ def rateplot( :param datetime, optional end: The right bound of the time range for the continuous rate. :param datetime, optional smooth: The size of the smoothing kernel applied to the rate output. :param DateOffset, Timedelta or str, optional smooth: - The size of the smoothing kernel applied to the continuous rate output. + The size of the smoothing kernel applied to the continuous rate output. :param bool, optional center: Specifies whether to center the convolution kernels. :param Axes, optional ax: The Axes on which to draw the rate plot and raster. """ @@ -119,11 +119,11 @@ def colorline( :param array-like x, y: The horizontal / vertical coordinates of the data points. :param array-like, optional z: - The dynamic variable used to color each data point by indexing the color map. + The dynamic variable used to color each data point by indexing the color map. :param str or ~matplotlib.colors.Colormap, optional cmap: - The colormap used to map normalized data values to RGBA colors. + The colormap used to map normalized data values to RGBA colors. :param matplotlib.colors.Normalize, optional norm: - The normalizing object used to scale data to the range [0, 1] for indexing the color map. + The normalizing object used to scale data to the range [0, 1] for indexing the color map. :param Axes, optional ax: The Axes on which to draw the colored line. """ if ax is None: From c1cdb08ed61b5b129aec1d43006e71051be6bbe9 Mon Sep 17 00:00:00 2001 From: lochhh Date: Fri, 1 Nov 2024 18:27:07 +0000 Subject: [PATCH 082/143] Fix docstring indent + unapply black in `utils.py` --- aeon/analysis/utils.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/aeon/analysis/utils.py b/aeon/analysis/utils.py index 9f0b08e6..168fcdaa 100644 --- a/aeon/analysis/utils.py +++ b/aeon/analysis/utils.py @@ -61,11 +61,7 @@ def visits(data, onset="Enter", offset="Exit"): # duplicate offsets indicate missing data from previous pairing missing_data = data.duplicated(subset=time_offset, keep="last") if missing_data.any(): - data.loc[ - missing_data, - ["duration"] + [name for name in data.columns if rsuffix in name], - ] = pd.NA - + data.loc[missing_data, ["duration"] + [name for name in data.columns if rsuffix in name]] = pd.NA # rename columns and sort data data.rename({time_onset: lonset, id_onset: "id", time_offset: loffset}, axis=1, inplace=True) data = data[["id"] + [name for name in data.columns if "_" in name] + [lonset, loffset, "duration"]] @@ -88,7 +84,7 @@ def rate(events, window, frequency, weight=1, start=None, end=None, smooth=None, :param datetime, optional end: The right bound of the time range for the continuous rate. :param datetime, optional smooth: The size of the smoothing kernel applied to the rate output. :param DateOffset, Timedelta or str, optional smooth: - The size of the smoothing kernel applied to the continuous rate output. + The size of the smoothing kernel applied to the continuous rate output. :param bool, optional center: Specifies whether to center the convolution kernels. :return: A Series containing the continuous event rate over time. """ @@ -104,16 +100,20 @@ def rate(events, window, frequency, weight=1, start=None, end=None, smooth=None, def get_events_rates( - events, - window_len_sec, - frequency, - unit_len_sec=60, - start=None, - end=None, - smooth=None, - center=False, + events, window_len_sec, frequency, unit_len_sec=60, start=None, end=None, smooth=None, center=False ): - """Computes the event rate from a sequence of events over a specified window.""" + """Computes the event rate from a sequence of events over a specified window. + + :param Series events: The discrete sequence of events, with timestamps in seconds as index. + :param int window_len_sec: The length of the window over which the event rate is estimated. + :param DateOffset, Timedelta or str frequency: The sampling frequency for the continuous rate. + :param int, optional unit_len_sec: The length of one sample point. Default is 60 seconds. + :param datetime, optional start: The left bound of the time range for the continuous rate. + :param datetime, optional end: The right bound of the time range for the continuous rate. + :param int, optional smooth: The size of the smoothing kernel applied to the continuous rate output. + :param bool, optional center: Specifies whether to center the convolution kernels. + :return: A Series containing the continuous event rate over time. + """ # events is an array with the time (in seconds) of event occurence # window_len_sec is the size of the window over which the event rate is estimated # unit_len_sec is the length of one sample point From 8f9752f6a02e67a275a2f428cdde42a900ccbd14 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 6 Nov 2024 14:48:37 +0000 Subject: [PATCH 083/143] fix: revert changes in `aeon/README.md` --- aeon/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/aeon/README.md b/aeon/README.md index 5a608ffb..e69de29b 100644 --- a/aeon/README.md +++ b/aeon/README.md @@ -1 +0,0 @@ -# README # noqa D100 From 2e40630cc0230b27d4fc90c8cdcb3e829f92e9b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 14:49:35 +0000 Subject: [PATCH 084/143] Update aeon/analysis/block_plotting.py Co-authored-by: Chang Huan Lo --- aeon/analysis/block_plotting.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index c71cb629..7d83a5a2 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -34,11 +34,11 @@ def gen_hex_grad(hex_col, vals, min_lightness=0.3): ) grad = np.empty(shape=(len(vals),), dtype=" Date: Wed, 6 Nov 2024 14:51:27 +0000 Subject: [PATCH 085/143] Update aeon/README.md Co-authored-by: Chang Huan Lo --- aeon/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/aeon/README.md b/aeon/README.md index 5a608ffb..e69de29b 100644 --- a/aeon/README.md +++ b/aeon/README.md @@ -1 +0,0 @@ -# README # noqa D100 From efa5590a769d3a4b6c8690f86761ffc3a276c630 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 14:52:03 +0000 Subject: [PATCH 086/143] Update aeon/analysis/block_plotting.py Co-authored-by: Chang Huan Lo --- aeon/analysis/block_plotting.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index 7d83a5a2..26f2f742 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -63,13 +63,16 @@ def gen_subject_colors_dict(subject_names): def gen_patch_style_dict(patch_names): - """Based on a list of patches, generates a dictionary of the following items. + """ + Generates a dictionary of patch styles given a list of patch_names. - - patch_colors_dict: patch name to color - - patch_markers_dict: patch name to marker - - patch_symbols_dict: patch name to symbol - - patch_linestyles_dict: patch name to linestyle + The dictionary contains dictionaries which map patch names to their respective styles. + Below are the keys for each nested dictionary and their contents: + - colors: patch name to color + - markers: patch name to marker + - symbols: patch name to symbol + - linestyles: patch name to linestyle """ return { "colors": dict(zip(patch_names, patch_colors, strict=False)), From f21d79ac8f6ed6b4d2b30edad7323039dd29b57e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 14:52:22 +0000 Subject: [PATCH 087/143] Update aeon/dj_pipeline/__init__.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/__init__.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/__init__.py b/aeon/dj_pipeline/__init__.py index f225ab2b..a29bb2c8 100644 --- a/aeon/dj_pipeline/__init__.py +++ b/aeon/dj_pipeline/__init__.py @@ -51,10 +51,7 @@ def fetch_stream(query, drop_pk=True): df.set_index("time", inplace=True) df.sort_index(inplace=True) df = df.convert_dtypes( - convert_string=False, - convert_integer=False, - convert_boolean=False, - convert_floating=False, + convert_string=False, convert_integer=False, convert_boolean=False, convert_floating=False ) return df From 6ef1b3a53bf535f69c99c04636313c6aac2f35bd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 14:52:48 +0000 Subject: [PATCH 088/143] Update aeon/dj_pipeline/acquisition.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/acquisition.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py index 8c5056bc..b798db41 100644 --- a/aeon/dj_pipeline/acquisition.py +++ b/aeon/dj_pipeline/acquisition.py @@ -288,7 +288,7 @@ class Meta(dj.Part): -> master --- bonsai_workflow: varchar(36) - commit: varchar(64) # e.g., git commit hash of aeon_experiment used to generate this epoch + commit: varchar(64) # e.g. git commit hash of aeon_experiment used to generate this epoch source='': varchar(16) # e.g. aeon_experiment or aeon_acquisition (or others) metadata: longblob metadata_file_path: varchar(255) # path of the file, relative to the experiment repository From 8d468b9362b2f08572efdcbb02696d2247ab57ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 14:53:27 +0000 Subject: [PATCH 089/143] Update aeon/dj_pipeline/acquisition.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/acquisition.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py index b798db41..4239065b 100644 --- a/aeon/dj_pipeline/acquisition.py +++ b/aeon/dj_pipeline/acquisition.py @@ -625,9 +625,7 @@ def _get_all_chunks(experiment_name, device_name): directory_types = ["quality-control", "raw"] raw_data_dirs = { dir_type: Experiment.get_data_directory( - experiment_key={"experiment_name": experiment_name}, - directory_type=dir_type, - as_posix=False, + experiment_key={"experiment_name": experiment_name}, directory_type=dir_type, as_posix=False ) for dir_type in directory_types } From 9816f0f704e567628affaa2b20e1f677dccf74a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 14:54:04 +0000 Subject: [PATCH 090/143] Update aeon/dj_pipeline/analysis/block_analysis.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/analysis/block_analysis.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index e491bc9e..ebf35da2 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -129,7 +129,10 @@ class BlockAnalysis(dj.Computed): @property def key_source(self): - """Ensure that the chunk ingestion has caught up with this block before processing (there exists a chunk that ends after the block end time).""" # noqa 501 + """Ensures chunk ingestion is complete before processing the block. + + This is done by checking that there exists a chunk that ends after the block end time. + """ ks = Block.aggr(acquisition.Chunk, latest_chunk_end="MAX(chunk_end)") ks = ks * Block & "latest_chunk_end >= block_end" & "block_end IS NOT NULL" return ks From 3904810338e7020bc8295eef32c00b583ea34e39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 14:55:14 +0000 Subject: [PATCH 091/143] Update aeon/dj_pipeline/analysis/block_analysis.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/analysis/block_analysis.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index ebf35da2..a6b097bc 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -167,8 +167,11 @@ class Subject(dj.Part): """ def make(self, key): - """ - Restrict, fetch and aggregate data from different streams to produce intermediate data products at a per-block level (for different patches and different subjects). + """Collates data from various streams to produce per-block intermediate data products. + + The intermediate data products consist of data for each ``Patch`` + and each ``Subject`` within the ``Block``. + The steps to restrict, fetch, and aggregate data from various streams are as follows: 1. Query data for all chunks within the block. 2. Fetch streams, filter by maintenance period. From 6a72beb7b57c34ae4482b79c0ed2a0bfce56086e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 14:55:32 +0000 Subject: [PATCH 092/143] Update aeon/dj_pipeline/analysis/block_analysis.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/analysis/block_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index a6b097bc..4e335a2d 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -177,7 +177,7 @@ def make(self, key): 2. Fetch streams, filter by maintenance period. 3. Fetch subject position data (SLEAP). 4. Aggregate and insert into the table. - """ # noqa 501 + """ block_start, block_end = (Block & key).fetch1("block_start", "block_end") chunk_restriction = acquisition.create_chunk_restriction( From 714a8fece3a6d73699241522ecae861d0343569e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 14:55:56 +0000 Subject: [PATCH 093/143] Update aeon/dj_pipeline/analysis/block_analysis.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/analysis/block_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 4e335a2d..93f5b468 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -389,7 +389,7 @@ class Patch(dj.Part): -> BlockAnalysis.Patch -> BlockAnalysis.Subject --- - in_patch_timestamps: longblob # timestamps when a subject spends time at a specific patch + in_patch_timestamps: longblob # timestamps when a subject is at a specific patch in_patch_time: float # total seconds spent in this patch for this block pellet_count: int pellet_timestamps: longblob From b3da517b92c96fa813f7caddf70c8d8c4ffd2405 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 14:56:18 +0000 Subject: [PATCH 094/143] Update aeon/dj_pipeline/analysis/block_analysis.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/analysis/block_analysis.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 93f5b468..185d6d05 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -953,9 +953,7 @@ def calculate_running_preference(group, pref_col, out_col): patch_pref.groupby("subject_name") .apply( lambda group: calculate_running_preference( - group, - "cumulative_preference_by_wheel", - "running_preference_by_wheel", + group, "cumulative_preference_by_wheel", "running_preference_by_wheel" ) ) .droplevel(0) From c60d67c1ab55257164e5b249ee5790091cefa494 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 14:56:34 +0000 Subject: [PATCH 095/143] Update aeon/schema/octagon.py Co-authored-by: Chang Huan Lo --- aeon/schema/octagon.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aeon/schema/octagon.py b/aeon/schema/octagon.py index ac121dbe..50aba7cd 100644 --- a/aeon/schema/octagon.py +++ b/aeon/schema/octagon.py @@ -92,8 +92,7 @@ def __init__(self, pattern): """Initialises the Response class.""" super().__init__( _reader.Csv( - f"{pattern}_response_*", - columns=["typetag", "wall_id", "poke_id", "response_time"], + f"{pattern}_response_*", columns=["typetag", "wall_id", "poke_id", "response_time"] ) ) From d2559eefc6da50997754a901fdac1bef9786e9fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 14:56:47 +0000 Subject: [PATCH 096/143] Update aeon/schema/schemas.py Co-authored-by: Chang Huan Lo --- aeon/schema/schemas.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/aeon/schema/schemas.py b/aeon/schema/schemas.py index 1d37b1d1..d62dc6ec 100644 --- a/aeon/schema/schemas.py +++ b/aeon/schema/schemas.py @@ -186,12 +186,4 @@ ) -__all__ = [ - "exp01", - "exp02", - "octagon01", - "social01", - "social02", - "social03", - "social04", -] +__all__ = ["exp01", "exp02", "octagon01", "social01", "social02", "social03", "social04"] From ac12ae06215ca0d9cffc11e656a12ff76b390b2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 14:57:02 +0000 Subject: [PATCH 097/143] Update aeon/schema/social_02.py Co-authored-by: Chang Huan Lo --- aeon/schema/social_02.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/aeon/schema/social_02.py b/aeon/schema/social_02.py index c3b64f8d..fd667031 100644 --- a/aeon/schema/social_02.py +++ b/aeon/schema/social_02.py @@ -16,10 +16,7 @@ class BlockState(Stream): def __init__(self, path): """Initializes the BlockState stream.""" super().__init__( - _reader.Csv( - f"{path}_BlockState_*", - columns=["pellet_ct", "pellet_ct_thresh", "due_time"], - ) + _reader.Csv(f"{path}_BlockState_*", columns=["pellet_ct", "pellet_ct_thresh", "due_time"]) ) class LightEvents(Stream): From cb124085223023780488b2602ae424c771ce321b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:01:36 +0000 Subject: [PATCH 098/143] Update aeon/dj_pipeline/analysis/block_analysis.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/analysis/block_analysis.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 185d6d05..d8f6d897 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -1416,10 +1416,7 @@ def make(self, key): & "attribute_name = 'Location'" ) rfid_locs = dict( - zip( - *rfid_location_query.fetch("rfid_reader_name", "attribute_value"), - strict=True, - ) + zip(*rfid_location_query.fetch("rfid_reader_name", "attribute_value"), strict=True) ) ## Create position ethogram df From 20adbdf80950cbda786a22b873519501abfde5b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:03:05 +0000 Subject: [PATCH 099/143] Update aeon/dj_pipeline/analysis/visit.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/analysis/visit.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index 6942c5f4..db96129f 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -131,7 +131,8 @@ def ingest_environment_visits(experiment_names: list | None = None): Args: experiment_names (list, optional): list of names of the experiment - to populate into the Visit table. Defaults to None. + to populate into the ``Visit`` table. + If unspecified, defaults to ``None`` and ``['exp0.2-r0']`` is used. """ if experiment_names is None: experiment_names = ["exp0.2-r0"] From ed722959f8052b33df5253071803c56a7cba1fec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:04:56 +0000 Subject: [PATCH 100/143] Update aeon/dj_pipeline/analysis/visit_analysis.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/analysis/visit_analysis.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 9149af01..a0224c00 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -195,10 +195,21 @@ def make(self, key): @classmethod def get_position(cls, visit_key=None, subject=None, start=None, end=None): - """Return a Pandas df of the subject's position data for a specified Visit given its key. - - Given a key to a single Visit, return a Pandas DataFrame for - the position data of the subject for the specified Visit time period. + """Retrieves a Pandas DataFrame of a subject's position data for a specified ``Visit``. + + A ``Visit`` is specified by either a ``visit_key`` or + a combination of ``subject``, ``start``, and ``end``. + If all four arguments are provided, the ``visit_key`` is ignored. + + Args: + visit_key (dict, optional): key to a single ``Visit``. + Only required if ``subject``, ``start``, and ``end`` are not provided. + subject (str, optional): subject name. + Only required if ``visit_key`` is not provided. + start (datetime): start time of the period of interest. + Only required if ``visit_key`` is not provided. + end (datetime, optional): end time of the period of interest. + Only required if ``visit_key`` is not provided. """ if visit_key is not None: if len(Visit & visit_key) != 1: From 89ac8b3608817dcf878d54427b8bc73431d56930 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:05:41 +0000 Subject: [PATCH 101/143] Update aeon/dj_pipeline/analysis/visit_analysis.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/analysis/visit_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index a0224c00..a53269e4 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -536,7 +536,7 @@ def make(self, key): @schema class VisitForagingBout(dj.Computed): - """Time period from when the animal enters to when it leaves a food patch while moving the wheel.""" + """Time period when a subject enters a food patch, moves the wheel, and then leaves the patch.""" definition = """ # Time from animal's entry to exit of a food patch while moving the wheel. -> Visit From 8bfdfcff819314cef5e0c49906cb499a1dbb4f16 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:06:17 +0000 Subject: [PATCH 102/143] Update aeon/dj_pipeline/analysis/visit_analysis.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/analysis/visit_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index a53269e4..27fdfc44 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -538,7 +538,7 @@ def make(self, key): class VisitForagingBout(dj.Computed): """Time period when a subject enters a food patch, moves the wheel, and then leaves the patch.""" - definition = """ # Time from animal's entry to exit of a food patch while moving the wheel. + definition = """ # Time from subject's entry to exit of a food patch to interact with the wheel. -> Visit -> acquisition.ExperimentFoodPatch bout_start: datetime(6) # start time of bout From 08ae6968bb7ef2668a8e2b5b43428a933ae63ff7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:07:11 +0000 Subject: [PATCH 103/143] Update aeon/dj_pipeline/create_experiments/create_socialexperiment.py Co-authored-by: Chang Huan Lo --- .../create_experiments/create_socialexperiment.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py index 757166e2..26643769 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py @@ -54,9 +54,6 @@ def create_new_social_experiment(experiment_name): ) acquisition.Experiment.Directory.insert(experiment_directories, skip_duplicates=True) acquisition.Experiment.DevicesSchema.insert1( - { - "experiment_name": experiment_name, - "devices_schema_name": exp_name.replace(".", ""), - }, + {"experiment_name": experiment_name, "devices_schema_name": exp_name.replace(".", "")}, skip_duplicates=True, ) From da1663073e5956db1cfc7f1cc57468f5834f5f07 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:07:47 +0000 Subject: [PATCH 104/143] Update aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py Co-authored-by: Chang Huan Lo --- .../create_experiments/create_socialexperiment_0.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py index 3b13a1f3..ee2982a2 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py @@ -3,9 +3,7 @@ import pathlib from aeon.dj_pipeline import acquisition, lab, subject -from aeon.dj_pipeline.create_experiments.create_experiment_01 import ( - ingest_exp01_metadata, -) +from aeon.dj_pipeline.create_experiments.create_experiment_01 import ingest_exp01_metadata # ============ Manual and automatic steps to for experiment 0.1 populate ============ experiment_name = "social0-r1" From 65e2c6114f1b5e41674a0a8df9aa0bc33a32f5b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:09:38 +0000 Subject: [PATCH 105/143] Update aeon/dj_pipeline/populate/worker.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/populate/worker.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/aeon/dj_pipeline/populate/worker.py b/aeon/dj_pipeline/populate/worker.py index 92ba5f82..d84b7fce 100644 --- a/aeon/dj_pipeline/populate/worker.py +++ b/aeon/dj_pipeline/populate/worker.py @@ -1,11 +1,7 @@ """This module defines the workers for the AEON pipeline.""" import datajoint as dj -from datajoint_utilities.dj_worker import ( - DataJointWorker, - ErrorLog, - WorkerLog, -) +from datajoint_utilities.dj_worker import DataJointWorker, ErrorLog, WorkerLog from datajoint_utilities.dj_worker.worker_schema import is_djtable from aeon.dj_pipeline import acquisition, db_prefix, qc, subject, tracking From 2981366c9f7036f8cc54350f90595345dd7da1fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:12:23 +0000 Subject: [PATCH 106/143] Update aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py index 34ee1878..6e906f92 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp01.py @@ -13,15 +13,7 @@ schema_name_mapper = { source_db_prefix + schema_name: target_db_prefix + schema_name - for schema_name in ( - "lab", - "subject", - "acquisition", - "tracking", - "qc", - "report", - "analysis", - ) + for schema_name in ("lab", "subject", "acquisition", "tracking", "qc", "report", "analysis") } restriction = {"experiment_name": "exp0.1-r0"} From 26b90ea4f2efb9fffda1b8da7b4f789ea5e6051c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:12:56 +0000 Subject: [PATCH 107/143] Update aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py index 66050cd3..82d7815c 100644 --- a/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py +++ b/aeon/dj_pipeline/scripts/clone_and_freeze_exp02.py @@ -19,15 +19,7 @@ schema_name_mapper = { source_db_prefix + schema_name: target_db_prefix + schema_name - for schema_name in ( - "lab", - "subject", - "acquisition", - "tracking", - "qc", - "analysis", - "report", - ) + for schema_name in ("lab", "subject", "acquisition", "tracking", "qc", "analysis", "report") } restriction = [{"experiment_name": "exp0.2-r0"}, {"experiment_name": "social0-r1"}] From 4167c01bdf0d057128afeb3625575d2c53332831 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:17:04 +0000 Subject: [PATCH 108/143] Update aeon/dj_pipeline/subject.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/subject.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index 135c7077..d25c5742 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -99,10 +99,7 @@ def make(self, key): } ) Strain.insert1( - { - "strain_id": animal_resp["strain_id"], - "strain_name": animal_resp["strain_id"], - }, + {"strain_id": animal_resp["strain_id"], "strain_name": animal_resp["strain_id"]}, skip_duplicates=True, ) entry = { From ec5e143fa90981bdad10a4e824546f421ba72df4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:17:37 +0000 Subject: [PATCH 109/143] Update aeon/dj_pipeline/analysis/visit.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/analysis/visit.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index db96129f..68881021 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -123,11 +123,11 @@ def make(self, key): def ingest_environment_visits(experiment_names: list | None = None): - """Function to populate into `Visit` and `VisitEnd` for specified experiments (default: 'exp0.2-r0'). + """Populates ``Visit`` and ``VisitEnd`` for the specified experiment names. - This ingestion routine handles only those "complete" visits, - not ingesting any "on-going" visits using "analyze" method: - `aeon.analyze.utils.visits()`. + This ingestion routine includes only "complete" visits and + does not ingest any "on-going" visits. + Visits are retrieved using :func:`aeon.analysis.utils.visits`. Args: experiment_names (list, optional): list of names of the experiment From 8350deeb61d204a2e1b32cb0562c4ec07fdb3d2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:18:40 +0000 Subject: [PATCH 110/143] Update aeon/dj_pipeline/subject.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/subject.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index d25c5742..d26e0fb8 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -111,10 +111,7 @@ def make(self, key): } if animal_resp["gen_bg_id"] is not None: GeneticBackground.insert1( - { - "gen_bg_id": animal_resp["gen_bg_id"], - "gen_bg": animal_resp["gen_bg"], - }, + {"gen_bg_id": animal_resp["gen_bg_id"], "gen_bg": animal_resp["gen_bg"]}, skip_duplicates=True, ) entry["gen_bg_id"] = animal_resp["gen_bg_id"] From 43e6e9b942d555e8ed23d59e44bae33411cdab57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:20:35 +0000 Subject: [PATCH 111/143] Update aeon/dj_pipeline/tracking.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/tracking.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 56424a7d..8543a6cc 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -5,13 +5,7 @@ import numpy as np import pandas as pd -from aeon.dj_pipeline import ( - acquisition, - dict_to_uuid, - get_schema_name, - lab, - streams, -) +from aeon.dj_pipeline import acquisition, dict_to_uuid, get_schema_name, lab, streams from aeon.io import api as io_api from aeon.schema import schemas as aeon_schemas From 32d2ae8a9948e3941d2c5134e85fbb462bf575aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 6 Nov 2024 15:21:35 +0000 Subject: [PATCH 112/143] Update aeon/dj_pipeline/tracking.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/tracking.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 8543a6cc..4e3a2cbf 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -237,7 +237,14 @@ def make(self, key): def compute_distance(position_df, target, xcol="x", ycol="y"): - """Compute the distance of the position data from a target coordinate (X,Y).""" + """Compute the distance between the position and the target. + + Args: + position_df (pd.DataFrame): DataFrame containing the position data. + target (tuple): Tuple of length 2 indicating the target x and y position. + xcol (str): x column name in ``position_df``. Default is 'x'. + ycol (str): y column name in ``position_df``. Default is 'y'. + """ if len(target) != 2: # noqa PLR2004 raise ValueError("Target must be a list of tuple of length 2.") return np.sqrt(np.square(position_df[[xcol, ycol]] - target).sum(axis=1)) From 52b001736e1b085a555727d9452f0b2334573a4d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 17:18:39 +0000 Subject: [PATCH 113/143] Update aeon/dj_pipeline/create_experiments/create_experiment_02.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/create_experiments/create_experiment_02.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/create_experiments/create_experiment_02.py b/aeon/dj_pipeline/create_experiments/create_experiment_02.py index 10877546..081e6911 100644 --- a/aeon/dj_pipeline/create_experiments/create_experiment_02.py +++ b/aeon/dj_pipeline/create_experiments/create_experiment_02.py @@ -1,4 +1,4 @@ -"""Function to create new experiments for experiment0.2.""" +"""Functions to create new experiments for experiment0.2.""" from aeon.dj_pipeline import acquisition, lab, subject From 9e8183dee0b7a7531e5f1eeefb66ba1fed4d89e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 17:22:00 +0000 Subject: [PATCH 114/143] Update aeon/dj_pipeline/tracking.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/tracking.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 4e3a2cbf..4d0bf690 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -110,14 +110,10 @@ def insert_new_params( @schema class SLEAPTracking(dj.Imported): - """Tracking data from SLEAP for multi-animal experiments. + """Tracking data from SLEAP for multi-animal experiments.""" - Tracked objects position data from a particular - VideoSource for multi-animal experiment using the SLEAP tracking - method per chunk. - """ - - definition = """ + definition = """ # Tracked objects position data from a particular +VideoSource for multi-animal experiment using the SLEAP tracking method per chunk. -> acquisition.Chunk -> streams.SpinnakerVideoSource -> TrackingParamSet From a7ad6f30e2e0dc663ed8606088457990abadcb80 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 17:28:16 +0000 Subject: [PATCH 115/143] Update aeon/dj_pipeline/tracking.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/tracking.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 4d0bf690..ce452030 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -249,7 +249,14 @@ def compute_distance(position_df, target, xcol="x", ycol="y"): def is_position_in_patch( position_df, patch_position, wheel_distance_travelled, patch_radius=0.2 ) -> pd.Series: - """The function returns a boolean array indicating whether the position is inside the patch.""" + """Returns a boolean array of whether a given position is inside the patch and the wheel is moving. + + Args: + position_df (pd.DataFrame): DataFrame containing the position data. + patch_position (tuple): Tuple of length 2 indicating the patch x and y position. + wheel_distance_travelled (pd.Series): distance travelled by the wheel. + patch_radius (float): Radius of the patch. Default is 0.2. + """ distance_from_patch = compute_distance(position_df, patch_position) in_patch = distance_from_patch < patch_radius exit_patch = in_patch.astype(np.int8).diff() < 0 From cf15a811ef402a6c54461c95e8b1c39d8079b9f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 17:34:56 +0000 Subject: [PATCH 116/143] Update aeon/dj_pipeline/utils/load_metadata.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/utils/load_metadata.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/aeon/dj_pipeline/utils/load_metadata.py b/aeon/dj_pipeline/utils/load_metadata.py index a4792f6b..1a7f2bde 100644 --- a/aeon/dj_pipeline/utils/load_metadata.py +++ b/aeon/dj_pipeline/utils/load_metadata.py @@ -242,11 +242,9 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath } ) - """ - Check if this device is currently installed. - If the same device serial number is currently installed check for changes in configuration. - If not, skip this. - """ + # Check if this device is currently installed. + # If the same device serial number is currently installed check for changes in configuration. + # If not, skip this. current_device_query = table - table.RemovalTime & experiment_key & device_key if current_device_query: From fd3baa0789b3fe4e7d77db76503756c1ae84696f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 17:37:12 +0000 Subject: [PATCH 117/143] Update aeon/dj_pipeline/utils/plotting.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/utils/plotting.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index d3641826..be9567e6 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -101,9 +101,7 @@ def plot_wheel_travelled_distance(session_keys): distance_travelled_df["in_arena"] = [ f'{subj_name}_{sess_start.strftime("%m/%d/%Y")}' for subj_name, sess_start in zip( - distance_travelled_df.subject, - distance_travelled_df.in_arena_start, - strict=False, + distance_travelled_df.subject, distance_travelled_df.in_arena_start, strict=False ) ] From f8d947d4a5bacf4582524d6e4fd2474ef8b54e92 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 17:37:58 +0000 Subject: [PATCH 118/143] Update aeon/dj_pipeline/utils/plotting.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/utils/plotting.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index be9567e6..976bd4b3 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -127,7 +127,7 @@ def plot_wheel_travelled_distance(session_keys): def plot_average_time_distribution(session_keys): - """Plotting the average time spent in different regions.""" + """Plots the average time spent in different regions.""" subject_list, arena_location_list, avg_time_spent_list = [], [], [] # Time spent in arena and corridor From d8f5887381953c6aa5dce4326d1e8b8e3a9adfee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 17:39:08 +0000 Subject: [PATCH 119/143] Update aeon/dj_pipeline/utils/plotting.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/utils/plotting.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index 976bd4b3..ea9d2103 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -218,12 +218,13 @@ def plot_visit_daily_summary( fig: Figure object Examples: - >>> fig = plot_visit_daily_summary(visit_key, attr='pellet_count', - per_food_patch=True) - >>> fig = plot_visit_daily_summary(visit_key, - attr='wheel_distance_travelled', per_food_patch=True) - >>> fig = plot_visit_daily_summary(visit_key, - attr='total_distance_travelled') + >>> fig = plot_visit_daily_summary(visit_key, attr='pellet_count', per_food_patch=True) + >>> fig = plot_visit_daily_summary( + ... visit_key, + ... attr="wheel_distance_travelled" + ... per_food_patch=True, + ... ) + >>> fig = plot_visit_daily_summary(visit_key, attr='total_distance_travelled') """ per_food_patch = not attr.startswith("total") color = "food_patch_description" if per_food_patch else None From 7e0e9c4978b8fa86849c74e59b7b85d093f82811 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 20:27:21 +0000 Subject: [PATCH 120/143] Update aeon/dj_pipeline/utils/plotting.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/utils/plotting.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index ea9d2103..03fe7b43 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -304,8 +304,13 @@ def plot_foraging_bouts_count( fig: Figure object Examples: - >>> fig = plot_foraging_bouts_count(visit_key, freq="D", - per_food_patch=True, min_bout_duration=1, min_wheel_dist=1) + >>> fig = plot_foraging_bouts_count( + ... visit_key, + ... freq="D", + ... per_food_patch=True, + ... min_bout_duration=1, + ... min_wheel_dist=1 + ... ) """ # Get all foraging bouts for the visit foraging_bouts = ( From ce1a22f6216eb0194f50596c738c27d366717a86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 20:28:09 +0000 Subject: [PATCH 121/143] Update aeon/dj_pipeline/utils/plotting.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/utils/plotting.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index 03fe7b43..82073962 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -469,13 +469,7 @@ def plot_foraging_bouts_distribution( width=700, height=400, template="simple_white", - legend={ - "orientation": "h", - "yanchor": "bottom", - "y": 1, - "xanchor": "right", - "x": 1, - }, + legend={"orientation": "h", "yanchor": "bottom", "y": 1, "xanchor": "right", "x": 1}, ) return fig From d5cab69cd0c1720cf2d1d2dcb06bb3119d2cdcf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 20:28:56 +0000 Subject: [PATCH 122/143] Update aeon/dj_pipeline/utils/plotting.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/utils/plotting.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index 82073962..4704b221 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -538,7 +538,8 @@ def _get_region_data(visit_key, attrs=None): Args: visit_key (dict): Key from the Visit table attrs (list, optional): List of column names (in VisitTimeDistribution tables) to retrieve. - Defaults is None, which will create a new list with the desired default values inside the function. + If unspecified, defaults to `None` and ``["in_nest", "in_arena", "in_corridor", "in_patch"]`` + is used. Returns: region (pd.DataFrame): Timestamped region info From 5f8f1277eefd78e721524f18facaca8cf07f976d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 20:29:41 +0000 Subject: [PATCH 123/143] Update aeon/dj_pipeline/utils/streams_maker.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/utils/streams_maker.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index 31ff161e..7cd01a79 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -25,10 +25,10 @@ class StreamType(dj.Lookup): """Catalog of all stream types used across Project Aeon. - Catalog of all steam types for the different device types used across Project Aeon. - One StreamType corresponds to one reader class in `aeon.io.reader`.The - combination of `stream_reader` and `stream_reader_kwargs` should fully specify the data - loading routine for a particular device, using the `aeon.io.utils`. + Catalog of all stream types for the different device types used across Project Aeon. + One StreamType corresponds to one Reader class in :mod:`aeon.io.reader`. + The combination of ``stream_reader`` and ``stream_reader_kwargs`` should fully specify the data + loading routine for a particular device, using :func:`aeon.io.api.load`. """ definition = """ # Catalog of all stream types used across Project Aeon From 89022c77d363779197f4548b42fe6b8d365b153d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 20:30:37 +0000 Subject: [PATCH 124/143] Update aeon/dj_pipeline/utils/streams_maker.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/utils/streams_maker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index 7cd01a79..adab3e54 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -34,7 +34,7 @@ class StreamType(dj.Lookup): definition = """ # Catalog of all stream types used across Project Aeon stream_type : varchar(20) --- - stream_reader : varchar(256) # name of the reader class found in `aeon_mecha` package (e.g. aeon.io.reader.Video) + stream_reader : varchar(256) # reader class name in aeon.io.reader (e.g. aeon.io.reader.Video) stream_reader_kwargs : longblob # keyword arguments to instantiate the reader class stream_description='': varchar(256) stream_hash : uuid # hash of dict(stream_reader_kwargs, stream_reader=stream_reader) From ff2ef8cbe3190517195b05bdafdc1bc0a8033ce8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 20:38:15 +0000 Subject: [PATCH 125/143] Update aeon/dj_pipeline/utils/streams_maker.py Co-authored-by: Chang Huan Lo --- aeon/dj_pipeline/utils/streams_maker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index adab3e54..5f654acc 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -39,7 +39,7 @@ class StreamType(dj.Lookup): stream_description='': varchar(256) stream_hash : uuid # hash of dict(stream_reader_kwargs, stream_reader=stream_reader) unique index (stream_hash) - """ # noqa: E501 + """ class DeviceType(dj.Lookup): From 42fdb2ceb402b6941b3b6a8556a9bc75b2cb6f74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 22:15:28 +0000 Subject: [PATCH 126/143] Update aeon/io/api.py Co-authored-by: Chang Huan Lo --- aeon/io/api.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aeon/io/api.py b/aeon/io/api.py index 2c1814e2..c45ec6f2 100644 --- a/aeon/io/api.py +++ b/aeon/io/api.py @@ -149,8 +149,7 @@ def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=No if not data.index.has_duplicates: warnings.warn( - f"data index for {reader.pattern} contains out-of-order timestamps!", - stacklevel=2, + f"data index for {reader.pattern} contains out-of-order timestamps!", stacklevel=2 ) data = data.sort_index() else: From 07e657efe434ac31d6fe3be6d9d3d50bcd0cf8e5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 22:21:13 +0000 Subject: [PATCH 127/143] Update aeon/io/api.py Co-authored-by: Chang Huan Lo --- aeon/io/api.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/aeon/io/api.py b/aeon/io/api.py index c45ec6f2..6cd4d1d5 100644 --- a/aeon/io/api.py +++ b/aeon/io/api.py @@ -153,10 +153,7 @@ def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=No ) data = data.sort_index() else: - warnings.warn( - f"data index for {reader.pattern} contains duplicate keys!", - stacklevel=2, - ) + warnings.warn(f"data index for {reader.pattern} contains duplicate keys!", stacklevel=2) data = data[~data.index.duplicated(keep="first")] return data.loc[start:end] return data From 29876d9843132d9cb4c16bb0c31668b8e305566b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 22:24:24 +0000 Subject: [PATCH 128/143] Update aeon/io/reader.py Co-authored-by: Chang Huan Lo --- aeon/io/reader.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/aeon/io/reader.py b/aeon/io/reader.py index 0623178a..5509fbea 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -74,11 +74,7 @@ def read(self, file): ticks = np.ndarray(length, dtype=np.uint16, buffer=data, offset=9, strides=stride) seconds = ticks * _SECONDS_PER_TICK + seconds payload = np.ndarray( - payloadshape, - dtype=payloadtype, - buffer=data, - offset=11, - strides=(stride, elementsize), + payloadshape, dtype=payloadtype, buffer=data, offset=11, strides=(stride, elementsize) ) if self.columns is not None and payloadshape[1] < len(self.columns): From 6d35eb906ddfc4e0c9b5aa3ef81af55b9b3e2598 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 22:25:08 +0000 Subject: [PATCH 129/143] Update aeon/schema/foraging.py Co-authored-by: Chang Huan Lo --- aeon/schema/foraging.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aeon/schema/foraging.py b/aeon/schema/foraging.py index 900af684..85601db6 100644 --- a/aeon/schema/foraging.py +++ b/aeon/schema/foraging.py @@ -54,7 +54,7 @@ class _Weight(_reader.Harp): """ def __init__(self, pattern): - """Initializes the Weight class.""" + """Initializes the Weight class.""" super().__init__(pattern, columns=["value", "stable"]) From e8ba3574772d9b7a5ead1b166692185bcc45d8cf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 22:27:45 +0000 Subject: [PATCH 130/143] Update aeon/io/reader.py Co-authored-by: Chang Huan Lo --- aeon/io/reader.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/aeon/io/reader.py b/aeon/io/reader.py index 5509fbea..a3ae37f2 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -398,8 +398,7 @@ def read(self, file: Path) -> pd.DataFrame: if bonsai_sleap_v == BONSAI_SLEAP_V3: # combine all identity_likelihood cols into a single col as dict part_data["identity_likelihood"] = part_data.apply( - lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, - axis=1, + lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, axis=1 ) part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True) part_data = part_data[ # reorder columns From 90a4ad58359f2e1537fd98d9f11b7d6325651327 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 22:30:22 +0000 Subject: [PATCH 131/143] Update aeon/io/reader.py Co-authored-by: Chang Huan Lo --- aeon/io/reader.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/aeon/io/reader.py b/aeon/io/reader.py index a3ae37f2..099b0fc1 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -402,13 +402,7 @@ def read(self, file: Path) -> pd.DataFrame: ) part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True) part_data = part_data[ # reorder columns - [ - "identity", - "identity_likelihood", - f"{part}_x", - f"{part}_y", - f"{part}_likelihood", - ] + ["identity", "identity_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"] ] part_data.insert(2, "part", part) part_data.columns = new_columns From 3aa1be23c48feed12a4323a30f1b1f169faa191f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Fri, 8 Nov 2024 22:31:09 +0000 Subject: [PATCH 132/143] Update aeon/schema/octagon.py Co-authored-by: Chang Huan Lo --- aeon/schema/octagon.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/aeon/schema/octagon.py b/aeon/schema/octagon.py index 50aba7cd..643ff77d 100644 --- a/aeon/schema/octagon.py +++ b/aeon/schema/octagon.py @@ -19,10 +19,7 @@ class BackgroundColor(Stream): def __init__(self, pattern): """Initializes the BackgroundColor stream.""" super().__init__( - _reader.Csv( - f"{pattern}_backgroundcolor_*", - columns=["typetag", "r", "g", "b", "a"], - ) + _reader.Csv(f"{pattern}_backgroundcolor_*", columns=["typetag", "r", "g", "b", "a"]) ) class ChangeSubjectState(Stream): From a1d9f75a8ea090b5db20552104e8a376e91ce65d Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Fri, 8 Nov 2024 23:05:07 +0000 Subject: [PATCH 133/143] fix: fix conflicts --- aeon/analysis/block_plotting.py | 3 +-- aeon/dj_pipeline/analysis/visit_analysis.py | 17 +++++++---------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/aeon/analysis/block_plotting.py b/aeon/analysis/block_plotting.py index 26f2f742..378c4713 100644 --- a/aeon/analysis/block_plotting.py +++ b/aeon/analysis/block_plotting.py @@ -63,8 +63,7 @@ def gen_subject_colors_dict(subject_names): def gen_patch_style_dict(patch_names): - """ - Generates a dictionary of patch styles given a list of patch_names. + """Generates a dictionary of patch styles given a list of patch_names. The dictionary contains dictionaries which map patch names to their respective styles. Below are the keys for each nested dictionary and their contents: diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 27fdfc44..6d9c77ce 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -6,16 +6,13 @@ import datajoint as dj import numpy as np import pandas as pd - from aeon.dj_pipeline import acquisition, lab, tracking -from aeon.dj_pipeline.analysis.visit import ( - Visit, - VisitEnd, - filter_out_maintenance_periods, - get_maintenance_periods, -) +from aeon.dj_pipeline.analysis.visit import (Visit, VisitEnd, + filter_out_maintenance_periods, + get_maintenance_periods) logger = dj.logger + # schema = dj.schema(get_schema_name("analysis")) schema = dj.schema() @@ -197,14 +194,14 @@ def make(self, key): def get_position(cls, visit_key=None, subject=None, start=None, end=None): """Retrieves a Pandas DataFrame of a subject's position data for a specified ``Visit``. - A ``Visit`` is specified by either a ``visit_key`` or - a combination of ``subject``, ``start``, and ``end``. + A ``Visit`` is specified by either a ``visit_key`` or + a combination of ``subject``, ``start``, and ``end``. If all four arguments are provided, the ``visit_key`` is ignored. Args: visit_key (dict, optional): key to a single ``Visit``. Only required if ``subject``, ``start``, and ``end`` are not provided. - subject (str, optional): subject name. + subject (str, optional): subject name. Only required if ``visit_key`` is not provided. start (datetime): start time of the period of interest. Only required if ``visit_key`` is not provided. From 73751c121339d24d27d864a48714960a8be9f4cb Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 12 Nov 2024 17:22:49 +0000 Subject: [PATCH 134/143] fix: two minor ruff checks --- aeon/dj_pipeline/analysis/visit_analysis.py | 10 +++++++--- aeon/dj_pipeline/tracking.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 6d9c77ce..1c76f50c 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -6,10 +6,14 @@ import datajoint as dj import numpy as np import pandas as pd + from aeon.dj_pipeline import acquisition, lab, tracking -from aeon.dj_pipeline.analysis.visit import (Visit, VisitEnd, - filter_out_maintenance_periods, - get_maintenance_periods) +from aeon.dj_pipeline.analysis.visit import ( + Visit, + VisitEnd, + filter_out_maintenance_periods, + get_maintenance_periods, +) logger = dj.logger diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index ce452030..ac545af2 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -234,7 +234,7 @@ def make(self, key): def compute_distance(position_df, target, xcol="x", ycol="y"): """Compute the distance between the position and the target. - + Args: position_df (pd.DataFrame): DataFrame containing the position data. target (tuple): Tuple of length 2 indicating the target x and y position. From dd830dfc69317d776109e41f4f9c07309435dec8 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 13 Nov 2024 13:05:42 +0000 Subject: [PATCH 135/143] fix: fix new ruff checks from latest PRs merged 446 --- aeon/dj_pipeline/__init__.py | 7 +++- aeon/dj_pipeline/analysis/block_analysis.py | 38 +++++++++---------- .../scripts/reingest_fullpose_sleap_data.py | 15 +++++--- .../scripts/sync_ingested_and_raw_epochs.py | 9 +++-- aeon/dj_pipeline/subject.py | 4 +- aeon/dj_pipeline/tracking.py | 13 ++++--- aeon/dj_pipeline/utils/load_metadata.py | 6 ++- aeon/schema/ingestion_schemas.py | 11 ++++-- aeon/schema/social_02.py | 4 +- tests/io/test_reader.py | 12 +++--- 10 files changed, 70 insertions(+), 49 deletions(-) diff --git a/aeon/dj_pipeline/__init__.py b/aeon/dj_pipeline/__init__.py index 9e8ed278..0c9500eb 100644 --- a/aeon/dj_pipeline/__init__.py +++ b/aeon/dj_pipeline/__init__.py @@ -57,7 +57,12 @@ def fetch_stream(query, drop_pk=True, round_microseconds=True): df.rename(columns={"timestamps": "time"}, inplace=True) df.set_index("time", inplace=True) df.sort_index(inplace=True) - df = df.convert_dtypes(convert_string=False, convert_integer=False, convert_boolean=False, convert_floating=False) + df = df.convert_dtypes( + convert_string=False, + convert_integer=False, + convert_boolean=False, + convert_floating=False + ) if not df.empty and round_microseconds: logging.warning("Rounding timestamps to microseconds is now enabled by default." " To disable, set round_microseconds=False.") diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 67605344..25641829 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -21,17 +21,8 @@ gen_subject_colors_dict, subject_colors, ) -from aeon.dj_pipeline import ( - acquisition, - fetch_stream, - get_schema_name, - streams, - tracking, -) -from aeon.dj_pipeline.analysis.visit import ( - filter_out_maintenance_periods, - get_maintenance_periods, -) +from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, tracking +from aeon.dj_pipeline.analysis.visit import filter_out_maintenance_periods, get_maintenance_periods from aeon.io import api as io_api schema = dj.schema(get_schema_name("block_analysis")) @@ -205,7 +196,9 @@ def make(self, key): if len(tracking.SLEAPTracking & chunk_keys) < len(tracking.SLEAPTracking.key_source & chunk_keys): if len(tracking.BlobPosition & chunk_keys) < len(tracking.BlobPosition.key_source & chunk_keys): raise ValueError( - f"BlockAnalysis Not Ready - SLEAPTracking (and BlobPosition) not yet fully ingested for block: {key}. Skipping (to retry later)..." + "BlockAnalysis Not Ready - " + f"SLEAPTracking (and BlobPosition) not yet fully ingested for block: {key}. " + "Skipping (to retry later)..." ) else: use_blob_position = True @@ -325,7 +318,8 @@ def make(self, key): if use_blob_position and len(subject_names) > 1: raise ValueError( - f"Without SLEAPTracking, BlobPosition can only handle single-subject block. Found {len(subject_names)} subjects." + f"Without SLEAPTracking, BlobPosition can only handle a single-subject block. " + f"Found {len(subject_names)} subjects." ) block_subject_entries = [] @@ -345,7 +339,9 @@ def make(self, key): pos_df = fetch_stream(pos_query)[block_start:block_end] pos_df["likelihood"] = np.nan # keep only rows with area between 0 and 1000 - likely artifacts otherwise - pos_df = pos_df[(pos_df.area > 0) & (pos_df.area < 1000)] + MIN_AREA = 0 + MAX_AREA = 1000 + pos_df = pos_df[(pos_df.area > MIN_AREA) & (pos_df.area < MAX_AREA)] else: pos_query = ( streams.SpinnakerVideoSource @@ -477,15 +473,19 @@ def make(self, key): # Ensure wheel_timestamps are of the same length across all patches wheel_lens = [len(p["wheel_timestamps"]) for p in block_patches] + MAX_WHEEL_DIFF = 10 + if len(set(wheel_lens)) > 1: max_diff = max(wheel_lens) - min(wheel_lens) - if max_diff > 10: + if max_diff > MAX_WHEEL_DIFF: # if diff is more than 10 samples, raise error, this is unexpected, some patches crash? - raise ValueError(f"Wheel data lengths are not consistent across patches ({max_diff} samples diff)") + raise ValueError( + f"Inconsistent wheel data lengths across patches ({max_diff} samples diff)" + ) + min_wheel_len = min(wheel_lens) for p in block_patches: - p["wheel_timestamps"] = p["wheel_timestamps"][: min(wheel_lens)] - p["wheel_cumsum_distance_travelled"] = p["wheel_cumsum_distance_travelled"][: min(wheel_lens)] - + p["wheel_timestamps"] = p["wheel_timestamps"][:min_wheel_len] + p["wheel_cumsum_distance_travelled"] = p["wheel_cumsum_distance_travelled"][:min_wheel_len] self.insert1(key) in_patch_radius = 130 # pixels diff --git a/aeon/dj_pipeline/scripts/reingest_fullpose_sleap_data.py b/aeon/dj_pipeline/scripts/reingest_fullpose_sleap_data.py index b3586f82..9411064b 100644 --- a/aeon/dj_pipeline/scripts/reingest_fullpose_sleap_data.py +++ b/aeon/dj_pipeline/scripts/reingest_fullpose_sleap_data.py @@ -1,4 +1,7 @@ +"""Functions to find and delete orphaned epochs that have been ingested but are no longer valid.""" + from datetime import datetime + from aeon.dj_pipeline import acquisition, tracking aeon_schemas = acquisition.aeon_schemas @@ -8,11 +11,10 @@ def find_chunks_to_reingest(exp_key, delete_not_fullpose=False): - """ - Find chunks with newly available full pose data to reingest. + """Find chunks with newly available full pose data to reingest. + If available, fullpose data can be found in `processed` folder """ - device_name = "CameraTop" devices_schema = getattr( @@ -21,13 +23,14 @@ def find_chunks_to_reingest(exp_key, delete_not_fullpose=False): "devices_schema_name" ), ) - stream_reader = getattr(getattr(devices_schema, device_name), "Pose") + stream_reader = getattr(devices_schema, device_name).Pose # special ingestion case for social0.2 full-pose data (using Pose reader from social03) if exp_key["experiment_name"].startswith("social0.2"): from aeon.io import reader as io_reader - stream_reader = getattr(getattr(devices_schema, device_name), "Pose03") - assert isinstance(stream_reader, io_reader.Pose), "Pose03 is not a Pose reader" + stream_reader = getattr(devices_schema, device_name).Pose03 + if not isinstance(stream_reader, io_reader.Pose): + raise TypeError("Pose03 is not a Pose reader") # find processed path for exp_key processed_dir = acquisition.Experiment.get_data_directory(exp_key, "processed") diff --git a/aeon/dj_pipeline/scripts/sync_ingested_and_raw_epochs.py b/aeon/dj_pipeline/scripts/sync_ingested_and_raw_epochs.py index 186355ce..3b247b7c 100644 --- a/aeon/dj_pipeline/scripts/sync_ingested_and_raw_epochs.py +++ b/aeon/dj_pipeline/scripts/sync_ingested_and_raw_epochs.py @@ -1,6 +1,9 @@ -import datajoint as dj +"""Functions to find and delete orphaned epochs that have been ingested but are no longer valid.""" + from datetime import datetime +import datajoint as dj + from aeon.dj_pipeline import acquisition, streams from aeon.dj_pipeline.analysis import block_analysis @@ -11,8 +14,8 @@ def find_orphaned_ingested_epochs(exp_key, delete_invalid_epochs=False): - """ - Find ingested epochs that are no longer valid + """Find ingested epochs that are no longer valid. + This is due to the raw epoch/chunk files/directories being deleted for whatever reason (e.g. bad data, testing, etc.) """ diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index c7cd1dac..71f2e1a4 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -461,8 +461,8 @@ def make(self, key): def get_pyrat_data(endpoint: str, params: dict = None, **kwargs): - """ - Get data from PyRat API. + """Get data from PyRat API. + See docs at: https://swc.pyrat.cloud/api/v3/docs (production) """ base_url = "https://swc.pyrat.cloud/api/v3/" diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 91358fd4..e5b0bf52 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd -from aeon.dj_pipeline import acquisition, dict_to_uuid, get_schema_name, lab, qc, streams, fetch_stream +from aeon.dj_pipeline import acquisition, dict_to_uuid, fetch_stream, get_schema_name, lab, qc, streams from aeon.io import api as io_api aeon_schemas = acquisition.aeon_schemas @@ -171,14 +171,15 @@ def make(self, key): ), ) - stream_reader = getattr(getattr(devices_schema, device_name), "Pose") + stream_reader = getattr(devices_schema, device_name).Pose # special ingestion case for social0.2 full-pose data (using Pose reader from social03) # fullpose for social0.2 has a different "pattern" for non-fullpose, hence the Pose03 reader if key["experiment_name"].startswith("social0.2"): from aeon.io import reader as io_reader - stream_reader = getattr(getattr(devices_schema, device_name), "Pose03") - assert isinstance(stream_reader, io_reader.Pose), "Pose03 is not a Pose reader" + stream_reader = getattr(devices_schema, device_name).Pose03 + if not isinstance(stream_reader, io_reader.Pose): + raise TypeError("Pose03 is not a Pose reader") data_dirs = [acquisition.Experiment.get_data_directory(key, "processed")] pose_data = io_api.load( @@ -258,7 +259,7 @@ class BlobPosition(dj.Imported): class Object(dj.Part): definition = """ # Position data of object tracked by a particular camera tracking -> master - object_id: int # object with id = -1 means "unknown/not sure", could potentially be the same object as those with other id value + object_id: int # id=-1 means "unknown"; could be the same object as those with other values --- identity_name='': varchar(16) sample_count: int # number of data points acquired from this stream for a given chunk @@ -270,6 +271,7 @@ class Object(dj.Part): @property def key_source(self): + """Return the keys to be processed.""" ks = ( acquisition.Chunk * ( @@ -282,6 +284,7 @@ def key_source(self): return ks - SLEAPTracking # do this only when SLEAPTracking is not available def make(self, key): + """Ingest blob position data for a given chunk.""" chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") data_dirs = acquisition.Experiment.get_data_directories(key) diff --git a/aeon/dj_pipeline/utils/load_metadata.py b/aeon/dj_pipeline/utils/load_metadata.py index edf24a22..e30f91ca 100644 --- a/aeon/dj_pipeline/utils/load_metadata.py +++ b/aeon/dj_pipeline/utils/load_metadata.py @@ -6,6 +6,7 @@ import pathlib from collections import defaultdict from pathlib import Path + import datajoint as dj import numpy as np from dotmap import DotMap @@ -37,8 +38,9 @@ def insert_stream_types(): existing_stream = (streams.StreamType.proj( "stream_reader", "stream_reader_kwargs") & {"stream_type": entry["stream_type"]}).fetch1() - if existing_stream["stream_reader_kwargs"].get("columns") != entry["stream_reader_kwargs"].get( - "columns"): + existing_columns = existing_stream["stream_reader_kwargs"].get("columns") + entry_columns = entry["stream_reader_kwargs"].get("columns") + if existing_columns != entry_columns: logger.warning(f"Stream type already exists:\n\t{entry}\n\t{existing_stream}") diff --git a/aeon/schema/ingestion_schemas.py b/aeon/schema/ingestion_schemas.py index fe2ee3dd..4a2d45d7 100644 --- a/aeon/schema/ingestion_schemas.py +++ b/aeon/schema/ingestion_schemas.py @@ -6,7 +6,8 @@ import aeon.schema.core as stream from aeon.io import reader -from aeon.io.api import aeon as aeon_time, chunk as aeon_chunk +from aeon.io.api import aeon as aeon_time +from aeon.io.api import chunk as aeon_chunk from aeon.schema import foraging, octagon, social_01, social_02, social_03 from aeon.schema.streams import Device, Stream, StreamGroup @@ -26,7 +27,8 @@ def read(self, file: PathLike[str], sr_hz: int = 50) -> pd.DataFrame: freq = 1 / sr_hz * 1e3 # convert to ms if first_index is not None: chunk_origin = aeon_chunk(first_index) - data = data.resample(f"{freq}ms", origin=chunk_origin).first() # take first sample in each resampled bin + data = data.resample(f"{freq}ms", origin=chunk_origin).first() + # take first sample in each resampled bin return data @@ -214,7 +216,10 @@ def __init__(self, path): social04 = DotMap( [ Device("Metadata", stream.Metadata), - Device("Environment", social_02.Environment, social_02.SubjectData, social_03.EnvironmentActiveConfiguration), + Device("Environment", + social_02.Environment, + social_02.SubjectData, + social_03.EnvironmentActiveConfiguration), Device("CameraTop", Video, stream.Position, social_03.Pose), Device("CameraNorth", Video), Device("CameraSouth", Video), diff --git a/aeon/schema/social_02.py b/aeon/schema/social_02.py index 0faafb21..4b5dbfe1 100644 --- a/aeon/schema/social_02.py +++ b/aeon/schema/social_02.py @@ -60,14 +60,14 @@ def __init__(self, path): class Pose03(Stream): - def __init__(self, path): + """Initializes the Pose stream.""" super().__init__(_reader.Pose(f"{path}_202_*")) class WeightRaw(Stream): def __init__(self, path): - """Initialize the WeightRaw stream.""" + """Initializes the WeightRaw stream.""" super().__init__(_reader.Harp(f"{path}_200_*", ["weight(g)", "stability"])) diff --git a/tests/io/test_reader.py b/tests/io/test_reader.py index 640768ab..e702c7e1 100644 --- a/tests/io/test_reader.py +++ b/tests/io/test_reader.py @@ -1,25 +1,25 @@ +"""Tests for the Pose stream.""" + from pathlib import Path import pytest -from pytest import mark import aeon from aeon.schema.schemas import social02, social03 pose_path = Path(__file__).parent.parent / "data" / "pose" - -@mark.api +@pytest.mark.api def test_Pose_read_local_model_dir(): + """Test that the Pose stream can read a local model directory.""" data = aeon.load(pose_path, social02.CameraTop.Pose) assert len(data) > 0 - -@mark.api +@pytest.mark.api def test_Pose_read_local_model_dir_with_register_prefix(): + """Test that the Pose stream can read a local model directory with a register prefix.""" data = aeon.load(pose_path, social03.CameraTop.Pose) assert len(data) > 0 - if __name__ == "__main__": pytest.main() From 50ce61935f9c89ffbc71fe41107945add73e2695 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 13 Nov 2024 13:07:08 +0000 Subject: [PATCH 136/143] chore: update gitignore with `.coverage` --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 2d1a3823..2599b3ba 100644 --- a/.gitignore +++ b/.gitignore @@ -24,4 +24,7 @@ dj_local_conf.json log*.txt scratch/ scratch*.py -**/*.nfs* \ No newline at end of file +**/*.nfs* + +# Test +.coverage \ No newline at end of file From 387277b5bcc85193cd3c8d0a32c06335a209aff0 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 13 Nov 2024 13:35:27 +0000 Subject: [PATCH 137/143] fix: add suggestions from reviewer --- aeon/dj_pipeline/analysis/block_analysis.py | 23 +++++++++-------- aeon/dj_pipeline/analysis/visit.py | 18 +++---------- aeon/dj_pipeline/analysis/visit_analysis.py | 11 +++----- .../create_experiments/create_octagon_1.py | 2 +- .../create_experiments/create_presocial.py | 2 +- .../create_socialexperiment.py | 2 +- .../create_socialexperiment_0.py | 2 +- aeon/dj_pipeline/qc.py | 2 +- aeon/dj_pipeline/subject.py | 19 +++++++------- aeon/dj_pipeline/utils/paths.py | 4 +-- aeon/dj_pipeline/utils/streams_maker.py | 25 ++++++++++--------- pyproject.toml | 1 - 12 files changed, 49 insertions(+), 62 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 25641829..2a233d8b 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -3,7 +3,7 @@ import itertools import json from collections import defaultdict -from datetime import datetime, timezone +from datetime import UTC, datetime import datajoint as dj import numpy as np @@ -265,7 +265,7 @@ def make(self, key): # log a note and pick the first rate to move forward AnalysisNote.insert1( { - "note_timestamp": datetime.now(timezone.utc), + "note_timestamp": datetime.now(UTC), "note_type": "Multiple patch rates", "note": ( f"Found multiple patch rates for block {key} " @@ -1621,18 +1621,20 @@ class AnalysisNote(dj.Manual): def get_threshold_associated_pellets(patch_key, start, end): - """Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time. + """Gets pellet delivery timestamps for each patch threshold update within the specified time range. 1. Get all patch state update timestamps (DepletionState): let's call these events "A" - - Remove all events within 1 second of each other - - Remove all events without threshold value (NaN) + + - Remove all events within 1 second of each other + - Remove all events without threshold value (NaN) 2. Get all pellet delivery timestamps (DeliverPellet): let's call these events "B" - - Find matching beam break timestamps within 1.2s after each pellet delivery + + - Find matching beam break timestamps within 1.2s after each pellet delivery 3. For each event "A", find the nearest event "B" within 100ms before or after the event "A" - - These are the pellet delivery events "B" associated with the previous threshold update - event "A" + + - These are the pellet delivery events "B" associated with the previous threshold update event "A" 4. Shift back the pellet delivery timestamps by 1 to match the pellet delivery with the - previous threshold update + previous threshold update 5. Remove all threshold updates events "A" without a corresponding pellet delivery event "B" Args: @@ -1642,12 +1644,13 @@ def get_threshold_associated_pellets(patch_key, start, end): Returns: pd.DataFrame: DataFrame with the following columns: + - threshold_update_timestamp (index) - pellet_timestamp - beam_break_timestamp - offset - rate - """ # noqa 501 + """ chunk_restriction = acquisition.create_chunk_restriction(patch_key["experiment_name"], start, end) # Step 1 - fetch data diff --git a/aeon/dj_pipeline/analysis/visit.py b/aeon/dj_pipeline/analysis/visit.py index 68881021..2d3d43fd 100644 --- a/aeon/dj_pipeline/analysis/visit.py +++ b/aeon/dj_pipeline/analysis/visit.py @@ -1,21 +1,14 @@ """Module for visit-related tables in the analysis schema.""" from collections import deque -from datetime import datetime, timezone +from datetime import UTC, datetime import datajoint as dj import numpy as np import pandas as pd from aeon.analysis import utils as analysis_utils -from aeon.dj_pipeline import ( - acquisition, - fetch_stream, - get_schema_name, - lab, - qc, - tracking, -) +from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name schema = dj.schema(get_schema_name("analysis")) @@ -146,7 +139,7 @@ def ingest_environment_visits(experiment_names: list | None = None): .fetch("last_visit") ) start = min(subjects_last_visits) if len(subjects_last_visits) else "1900-01-01" - end = datetime.now(timezone.utc) if start else "2200-01-01" + end = datetime.now(UTC) if start else "2200-01-01" enter_exit_query = ( acquisition.SubjectEnterExit.Time * acquisition.EventType @@ -161,10 +154,7 @@ def ingest_environment_visits(experiment_names: list | None = None): enter_exit_df = pd.DataFrame( zip( *enter_exit_query.fetch( - "subject", - "enter_exit_time", - "event_type", - order_by="enter_exit_time", + "subject", "enter_exit_time", "event_type", order_by="enter_exit_time" ), strict=False, ) diff --git a/aeon/dj_pipeline/analysis/visit_analysis.py b/aeon/dj_pipeline/analysis/visit_analysis.py index 1c76f50c..fe6db2c3 100644 --- a/aeon/dj_pipeline/analysis/visit_analysis.py +++ b/aeon/dj_pipeline/analysis/visit_analysis.py @@ -219,15 +219,10 @@ def get_position(cls, visit_key=None, subject=None, start=None, end=None): Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") & visit_key ).fetch1("visit_start", "visit_end") subject = visit_key["subject"] - elif all((subject, start, end)): - start = start # noqa PLW0127 - end = end # noqa PLW0127 - subject = subject # noqa PLW0127 - else: + elif not all((subject, start, end)): raise ValueError( - 'Either "visit_key" or all three "subject", "start" and "end" has to be specified' - ) - + 'Either "visit_key" or all three "subject", "start", and "end" must be specified.' + ) return tracking._get_position( cls.TimeSlice, object_attr="subject", diff --git a/aeon/dj_pipeline/create_experiments/create_octagon_1.py b/aeon/dj_pipeline/create_experiments/create_octagon_1.py index edbfdd64..3fc01ef6 100644 --- a/aeon/dj_pipeline/create_experiments/create_octagon_1.py +++ b/aeon/dj_pipeline/create_experiments/create_octagon_1.py @@ -1,4 +1,4 @@ -"""Function to create new experiments for octagon1.0.""" +"""Functions to create new experiments for octagon1.0.""" from aeon.dj_pipeline import acquisition, subject diff --git a/aeon/dj_pipeline/create_experiments/create_presocial.py b/aeon/dj_pipeline/create_experiments/create_presocial.py index 7a1ce9a3..3e9b8f76 100644 --- a/aeon/dj_pipeline/create_experiments/create_presocial.py +++ b/aeon/dj_pipeline/create_experiments/create_presocial.py @@ -1,4 +1,4 @@ -"""Function to create new experiments for presocial0.1.""" +"""Functions to create new experiments for presocial0.1.""" from aeon.dj_pipeline import acquisition, lab, subject diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py index 26643769..a3e60a32 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment.py @@ -1,4 +1,4 @@ -"""Function to create new social experiments.""" +"""Functions to create new social experiments.""" from datetime import datetime diff --git a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py index ee2982a2..7d10f734 100644 --- a/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py +++ b/aeon/dj_pipeline/create_experiments/create_socialexperiment_0.py @@ -1,4 +1,4 @@ -"""Function to create new experiments for social0-r1.""" +"""Functions to create new experiments for social0-r1.""" import pathlib diff --git a/aeon/dj_pipeline/qc.py b/aeon/dj_pipeline/qc.py index cf357c0f..b54951dd 100644 --- a/aeon/dj_pipeline/qc.py +++ b/aeon/dj_pipeline/qc.py @@ -39,7 +39,7 @@ class QCRoutine(dj.Lookup): @schema class CameraQC(dj.Imported): - definition = """ # Quality controls performed on a particular camera for a particular acquisition chunk + definition = """ # Quality controls performed on a particular camera for one acquisition chunk -> acquisition.Chunk -> streams.SpinnakerVideoSource --- diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index 71f2e1a4..3ff95770 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -3,7 +3,7 @@ import json import os import time -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta import datajoint as dj import requests @@ -187,7 +187,7 @@ def get_reference_weight(cls, subject_name): 0 ] else: - ref_date = datetime.now(timezone.utc).date() + ref_date = datetime.now(UTC).date() weight_query = SubjectWeight & subj_key & f"weight_time < '{ref_date}'" ref_weight = ( @@ -197,7 +197,7 @@ def get_reference_weight(cls, subject_name): entry = { "subject": subject_name, "reference_weight": ref_weight, - "last_updated_time": datetime.now(timezone.utc), + "last_updated_time": datetime.now(UTC), } cls.update1(entry) if cls & {"subject": subject_name} else cls.insert1(entry) @@ -240,7 +240,7 @@ class PyratIngestion(dj.Imported): def _auto_schedule(self): """Automatically schedule the next task.""" - utc_now = datetime.now(timezone.utc) + utc_now = datetime.now(UTC) next_task_schedule_time = utc_now + timedelta(hours=self.schedule_interval) if ( @@ -253,7 +253,7 @@ def _auto_schedule(self): def make(self, key): """Automatically import or update entries in the Subject table.""" - execution_time = datetime.now(timezone.utc) + execution_time = datetime.now(UTC) new_eartags = [] for responsible_id in lab.User.fetch("responsible_id"): # 1 - retrieve all animals from this user @@ -288,7 +288,7 @@ def make(self, key): new_entry_count += 1 logger.info(f"Inserting {new_entry_count} new subject(s) from Pyrat") - completion_time = datetime.now(timezone.utc) + completion_time = datetime.now(UTC) self.insert1( { **key, @@ -319,7 +319,7 @@ class PyratCommentWeightProcedure(dj.Imported): def make(self, key): """Automatically import or update entries in the PyratCommentWeightProcedure table.""" - execution_time = datetime.now(timezone.utc) + execution_time = datetime.now(UTC) logger.info("Extracting weights/comments/procedures") eartag_or_id = key["subject"] @@ -372,8 +372,7 @@ def make(self, key): "lab_id": animal_resp["labid"], } ) - - completion_time = datetime.now(timezone.utc) + completion_time = datetime.now(UTC) self.insert1( { **key, @@ -391,7 +390,7 @@ class CreatePyratIngestionTask(dj.Computed): def make(self, key): """Create one new PyratIngestionTask for every newly added users.""" - PyratIngestionTask.insert1({"pyrat_task_scheduled_time": datetime.now(timezone.utc)}) + PyratIngestionTask.insert1({"pyrat_task_scheduled_time": datetime.now(UTC)}) time.sleep(1) self.insert1(key) diff --git a/aeon/dj_pipeline/utils/paths.py b/aeon/dj_pipeline/utils/paths.py index ebba44b5..75b459d4 100644 --- a/aeon/dj_pipeline/utils/paths.py +++ b/aeon/dj_pipeline/utils/paths.py @@ -34,7 +34,7 @@ def get_repository_path(repository_name: str) -> pathlib.Path: def find_root_directory( root_directories: str | pathlib.Path, full_path: str | pathlib.Path ) -> pathlib.Path: - """Given multiple potential root directories and a full-path, search and return one directory that is the parent of the given path. + """Finds the parent directory of a given full path among multiple potential root directories. Args: root_directories (str | pathlib.Path): A list of potential root directories. @@ -46,7 +46,7 @@ def find_root_directory( Returns: pathlib.Path: The full path to the discovered root directory. - """ # noqa E501 + """ full_path = pathlib.Path(full_path) if not full_path.exists(): diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index 6f3134c4..cffc5345 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -75,21 +75,21 @@ def get_device_template(device_type: str): device_type = dj.utils.from_camel_case(device_type) class ExperimentDevice(dj.Manual): - definition = f""" # {device_title} placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-{aeon.__version__}) + definition = f"""# {device_title} operation for time,location, experiment (v{aeon.__version__}) -> acquisition.Experiment -> Device - {device_type}_install_time : datetime(6) # time of the {device_type} placed and started operation at this position + {device_type}_install_time : datetime(6) # {device_type} time of placement and start operation --- - {device_type}_name : varchar(36) - """ # noqa: E501 + {device_type}_name : varchar(36) + """ class Attribute(dj.Part): - definition = """ # metadata/attributes (e.g. FPS, config, calibration, etc.) associated with this experimental device + definition = """ # Metadata (e.g. FPS, config, calibration) for this experimental device -> master attribute_name : varchar(32) --- attribute_value=null : longblob - """ # noqa: E501 + """ class RemovalTime(dj.Part): definition = f""" @@ -270,17 +270,18 @@ def main(create_tables=True): device_stream_table_def = inspect.getsource(table_class).lstrip() # Replace the definition + device_type_name = dj.utils.from_camel_case(device_type) replacements = { "DeviceDataStream": f"{device_type}{stream_type}", "ExperimentDevice": device_type, - 'f"chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time"': ( - f"'chunk_start >= {dj.utils.from_camel_case(device_type)}_install_time'" + 'f"chunk_start >= {device_type_name}_install_time"': ( + f"'chunk_start >= {device_type_name}_install_time'" ), - """f'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time, "2200-01-01")'""": ( # noqa E501 - f"""'chunk_start < IFNULL({dj.utils.from_camel_case(device_type)}_removal_time,"2200-01-01")'""" # noqa E501 + """f'chunk_start < IFNULL({device_type_name}_removal_time, "2200-01-01")'""": ( + f"""'chunk_start < IFNULL({device_type_name}_removal_time,"2200-01-01")'""" ), - 'f"{dj.utils.from_camel_case(device_type)}_name"': ( - f"'{dj.utils.from_camel_case(device_type)}_name'" + 'f"{device_type_name}_name"': ( + f"'{device_type_name}_name'" ), "{device_type}": device_type, "{stream_type}": stream_type, diff --git a/pyproject.toml b/pyproject.toml index 96c4a98c..0d7d7ec7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -102,7 +102,6 @@ lint.ignore = [ "PLR0912", "PLR0913", "PLR0915", - "UP017" # skip `datetime.UTC` alias ] extend-exclude = [ ".git", From 50f5ff1aadfe4826a0dc337294d5f47a3f0bafb8 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 13 Nov 2024 14:02:49 +0000 Subject: [PATCH 138/143] fix: add suggested doctstring from reviewer --- aeon/dj_pipeline/utils/plotting.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/aeon/dj_pipeline/utils/plotting.py b/aeon/dj_pipeline/utils/plotting.py index 4704b221..5a160b70 100644 --- a/aeon/dj_pipeline/utils/plotting.py +++ b/aeon/dj_pipeline/utils/plotting.py @@ -25,16 +25,17 @@ def plot_reward_rate_differences(subject_keys): - """Plotting the reward rate differences between food patches (Patch 2 - Patch 1) for all sessions from all subjects specified in "subject_keys". + """Plots the reward rate differences between two food patches (Patch 2 - Patch 1). - Examples: - ``` - subject_keys = - (acquisition.Experiment.Subject & 'experiment_name = "exp0.1-r0"').fetch('KEY') + The reward rate differences between the two food patches are plotted + for all sessions from all subjects in ``subject_keys``. - fig = plot_reward_rate_differences(subject_keys) - ``` - """ # noqa E501 + Examples: + >>> subject_keys = ( + ... acquisition.Experiment.Subject + ... & 'experiment_name = "exp0.1-r0"').fetch('KEY') + >>> fig = plot_reward_rate_differences(subject_keys) + """ subj_names, sess_starts, rate_timestamps, rate_diffs = ( analysis.InArenaRewardRate & subject_keys ).fetch("subject", "in_arena_start", "pellet_rate_timestamps", "patch2_patch1_rate_diff") From 7202221b6a3249d59a75a0d193450f147e6062bb Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 13 Nov 2024 14:11:37 +0000 Subject: [PATCH 139/143] fix: fix definition comment --- aeon/dj_pipeline/tracking.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index e5b0bf52..350c215b 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -5,7 +5,7 @@ import numpy as np import pandas as pd -from aeon.dj_pipeline import acquisition, dict_to_uuid, fetch_stream, get_schema_name, lab, qc, streams +from aeon.dj_pipeline import acquisition, dict_to_uuid, fetch_stream, get_schema_name, lab, streams from aeon.io import api as io_api aeon_schemas = acquisition.aeon_schemas @@ -113,8 +113,7 @@ def insert_new_params( class SLEAPTracking(dj.Imported): """Tracking data from SLEAP for multi-animal experiments.""" - definition = """ # Tracked objects position data from a particular -VideoSource for multi-animal experiment using the SLEAP tracking method per chunk. + definition = """ # Position data from a VideoSource for multi-animal experiments using SLEAP per chunk -> acquisition.Chunk -> streams.SpinnakerVideoSource -> TrackingParamSet From 6be26d1033f2d8e5d5fbcc2636157cb3b6b53075 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 13 Nov 2024 14:26:46 +0000 Subject: [PATCH 140/143] fix: remove noqa --- aeon/dj_pipeline/utils/streams_maker.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index cffc5345..f04af930 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -123,13 +123,14 @@ def get_device_stream_template(device_type: str, stream_type: str, streams_modul stream = reader(**stream_detail["stream_reader_kwargs"]) - table_definition = f""" # Raw per-chunk {stream_type} data stream from {device_type} (auto-generated with aeon_mecha-{aeon.__version__}) + ver = aeon.__version__ + table_definition = f""" # Raw per-chunk {stream_type} from {device_type}(auto-generated with v{ver}) -> {device_type} -> acquisition.Chunk --- sample_count: int # number of data points acquired from this stream for a given chunk timestamps: longblob # (datetime) timestamps of {stream_type} data - """ # noqa: E501 + """ for col in stream.columns: if col.startswith("_"): @@ -142,11 +143,12 @@ class DeviceDataStream(dj.Imported): @property def key_source(self): - f"""Only the combination of Chunk and {device_type} with overlapping time. + docstring = f"""Only the combination of Chunk and {device_type} with overlapping time. - + Chunk(s) that started after {device_type} install time and ended before {device_type} remove time - + Chunk(s) that started after {device_type} install time for {device_type} that are not yet removed - """ # noqa B021 + + Chunk(s) started after {device_type} install time & ended before {device_type} remove time + + Chunk(s) started after {device_type} install time for {device_type} and not yet removed + """ + self.__doc__ = docstring device_type_name = dj.utils.from_camel_case(device_type) return ( acquisition.Chunk * ExperimentDevice.join(ExperimentDevice.RemovalTime, left=True) From 6279bda3faa16916a2a84226e9ee82e669a88367 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 13 Nov 2024 14:35:01 +0000 Subject: [PATCH 141/143] fix: solve #noqa's --- aeon/__init__.py | 4 +++- aeon/dj_pipeline/tracking.py | 3 ++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/aeon/__init__.py b/aeon/__init__.py index 48a87f97..f5cb7fe7 100644 --- a/aeon/__init__.py +++ b/aeon/__init__.py @@ -12,4 +12,6 @@ del version, PackageNotFoundError # Set functions available directly under the 'aeon' top-level namespace -from aeon.io.api import load as load # noqa: PLC0414 +from aeon.io.api import load + +__all__ = ["load"] diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 350c215b..559d08fb 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -370,7 +370,8 @@ def compute_distance(position_df, target, xcol="x", ycol="y"): xcol (str): x column name in ``position_df``. Default is 'x'. ycol (str): y column name in ``position_df``. Default is 'y'. """ - if len(target) != 2: # noqa PLR2004 + COORDS = 2 # x, y + if len(target) != COORDS: raise ValueError("Target must be a list of tuple of length 2.") return np.sqrt(np.square(position_df[[xcol, ycol]] - target).sum(axis=1)) From 3149f29187d5f317cdc9412826d6ea07dbf43740 Mon Sep 17 00:00:00 2001 From: lochhh Date: Mon, 18 Nov 2024 17:22:33 +0000 Subject: [PATCH 142/143] Fix indentations in docstrings --- aeon/io/api.py | 2 +- aeon/io/reader.py | 72 +++++++++++++++++++++++++++-------------------- aeon/io/video.py | 8 +++--- 3 files changed, 46 insertions(+), 36 deletions(-) diff --git a/aeon/io/api.py b/aeon/io/api.py index 6cd4d1d5..22a11ce7 100644 --- a/aeon/io/api.py +++ b/aeon/io/api.py @@ -75,7 +75,7 @@ def load(root, reader, start=None, end=None, time=None, tolerance=None, epoch=No :param datetime, optional end: The right bound of the time range to extract. :param datetime, optional time: An object or series specifying the timestamps to extract. :param datetime, optional tolerance: - The maximum distance between original and new timestamps for inexact matches. + The maximum distance between original and new timestamps for inexact matches. :param str, optional epoch: A wildcard pattern to use when searching epoch data. :param optional kwargs: Optional keyword arguments to forward to the reader when reading chunk data. :return: A pandas data frame containing epoch event metadata, sorted by time. diff --git a/aeon/io/reader.py b/aeon/io/reader.py index 53d400ad..dbf574ec 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -169,10 +169,11 @@ class Subject(Csv): """Extracts metadata for subjects entering and exiting the environment. Columns: - id (str): Unique identifier of a subject in the environment. - weight (float): Weight measurement of the subject on entering - or exiting the environment. - event (str): Event type. Can be one of `Enter`, `Exit` or `Remain`. + + - id (str): Unique identifier of a subject in the environment. + - weight (float): Weight measurement of the subject on entering + or exiting the environment. + - event (str): Event type. Can be one of `Enter`, `Exit` or `Remain`. """ def __init__(self, pattern): @@ -184,10 +185,11 @@ class Log(Csv): """Extracts message log data. Columns: - priority (str): Priority level of the message. - type (str): Type of the log message. - message (str): Log message data. Can be structured using tab - separated values. + + - priority (str): Priority level of the message. + - type (str): Type of the log message. + - message (str): Log message data. Can be structured using tab + separated values. """ def __init__(self, pattern): @@ -199,7 +201,8 @@ class Heartbeat(Harp): """Extract periodic heartbeat event data. Columns: - second (int): The whole second corresponding to the heartbeat, in seconds. + + - second (int): The whole second corresponding to the heartbeat, in seconds. """ def __init__(self, pattern): @@ -211,8 +214,9 @@ class Encoder(Harp): """Extract magnetic encoder data. Columns: - angle (float): Absolute angular position, in radians, of the magnetic encoder. - intensity (float): Intensity of the magnetic field. + + - angle (float): Absolute angular position, in radians, of the magnetic encoder. + - intensity (float): Intensity of the magnetic field. """ def __init__(self, pattern): @@ -224,15 +228,16 @@ class Position(Harp): """Extract 2D position tracking data for a specific camera. Columns: - x (float): x-coordinate of the object center of mass. - y (float): y-coordinate of the object center of mass. - angle (float): angle, in radians, of the ellipse fit to the object. - major (float): length, in pixels, of the major axis of the ellipse - fit to the object. - minor (float): length, in pixels, of the minor axis of the ellipse - fit to the object. - area (float): number of pixels in the object mass. - id (float): unique tracking ID of the object in a frame. + + - x (float): x-coordinate of the object center of mass. + - y (float): y-coordinate of the object center of mass. + - angle (float): angle, in radians, of the ellipse fit to the object. + - major (float): length, in pixels, of the major axis of the ellipse + fit to the object. + - minor (float): length, in pixels, of the minor axis of the ellipse + fit to the object. + - area (float): number of pixels in the object mass. + - id (float): unique tracking ID of the object in a frame. """ def __init__(self, pattern): @@ -244,7 +249,8 @@ class BitmaskEvent(Harp): """Extracts event data matching a specific digital I/O bitmask. Columns: - event (str): Unique identifier for the event code. + + - event (str): Unique identifier for the event code. """ def __init__(self, pattern, value, tag): @@ -268,7 +274,8 @@ class DigitalBitmask(Harp): """Extracts event data matching a specific digital I/O bitmask. Columns: - event (str): Unique identifier for the event code. + + - event (str): Unique identifier for the event code. """ def __init__(self, pattern, mask, columns): @@ -290,8 +297,9 @@ class Video(Csv): """Extracts video frame metadata. Columns: - hw_counter (int): Hardware frame counter value for the current frame. - hw_timestamp (int): Internal camera timestamp for the current frame. + + - hw_counter (int): Hardware frame counter value for the current frame. + - hw_timestamp (int): Internal camera timestamp for the current frame. """ def __init__(self, pattern): @@ -313,12 +321,13 @@ class Pose(Harp): """Reader for Harp-binarized tracking data given a model that outputs id, parts, and likelihoods. Columns: - class (int): Int ID of a subject in the environment. - class_likelihood (float): Likelihood of the subject's identity. - part (str): Bodypart on the subject. - part_likelihood (float): Likelihood of the specified bodypart. - x (float): X-coordinate of the bodypart. - y (float): Y-coordinate of the bodypart. + + - class (int): Int ID of a subject in the environment. + - class_likelihood (float): Likelihood of the subject's identity. + - part (str): Bodypart on the subject. + - part_likelihood (float): Likelihood of the specified bodypart. + - x (float): X-coordinate of the bodypart. + - y (float): Y-coordinate of the bodypart. """ def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/processed"): @@ -401,7 +410,8 @@ def read(self, file: Path) -> pd.DataFrame: if bonsai_sleap_v == BONSAI_SLEAP_V3: # combine all identity_likelihood cols into a single col as dict part_data["identity_likelihood"] = part_data.apply( - lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, axis=1 + lambda row: {identity: row[f"{identity}_likelihood"] for identity in identities}, + axis=1, ) part_data.drop(columns=columns[1 : (len(identities) + 1)], inplace=True) part_data = part_data[ # reorder columns diff --git a/aeon/io/video.py b/aeon/io/video.py index 658379e7..36f709ab 100644 --- a/aeon/io/video.py +++ b/aeon/io/video.py @@ -7,10 +7,10 @@ def frames(data): """Extracts the raw frames corresponding to the provided video metadata. :param DataFrame data: - A pandas DataFrame where each row specifies video acquisition path and frame number. + A pandas DataFrame where each row specifies video acquisition path and frame number. :return: - An object to iterate over numpy arrays for each row in the DataFrame, - containing the raw video frame data. + An object to iterate over numpy arrays for each row in the DataFrame, + containing the raw video frame data. """ capture = None filename = None @@ -44,7 +44,7 @@ def export(frames, file, fps, fourcc=None): :param str file: The path to the exported video file. :param fps: The frame rate of the exported video. :param optional fourcc: - Specifies the four character code of the codec used to compress the frames. + Specifies the four character code of the codec used to compress the frames. """ writer = None try: From c41bd02a45646b62a0edf51c7e7e3044b3dc6c15 Mon Sep 17 00:00:00 2001 From: lochhh Date: Tue, 19 Nov 2024 10:34:23 +0000 Subject: [PATCH 143/143] Standardise module descriptions --- aeon/io/video.py | 2 +- aeon/schema/octagon.py | 2 +- aeon/schema/social_01.py | 2 +- aeon/schema/social_02.py | 2 +- aeon/schema/social_03.py | 2 +- aeon/schema/streams.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/aeon/io/video.py b/aeon/io/video.py index 36f709ab..bd5326fa 100644 --- a/aeon/io/video.py +++ b/aeon/io/video.py @@ -1,4 +1,4 @@ -"""This module provides functions to read and write video files using OpenCV.""" +"""Module for reading and writing video files using OpenCV.""" import cv2 diff --git a/aeon/schema/octagon.py b/aeon/schema/octagon.py index 643ff77d..ae085abe 100644 --- a/aeon/schema/octagon.py +++ b/aeon/schema/octagon.py @@ -1,4 +1,4 @@ -"""Octagon schema definition.""" +"""Schema definition for octagon experiments-specific data streams.""" import aeon.io.reader as _reader from aeon.schema.streams import Stream, StreamGroup diff --git a/aeon/schema/social_01.py b/aeon/schema/social_01.py index f7f44b80..3230e1aa 100644 --- a/aeon/schema/social_01.py +++ b/aeon/schema/social_01.py @@ -1,4 +1,4 @@ -"""This module contains the schema for the social_01 dataset.""" +"""Schema definition for social_01 experiments-specific data streams.""" import aeon.io.reader as _reader from aeon.schema.streams import Stream diff --git a/aeon/schema/social_02.py b/aeon/schema/social_02.py index 4b5dbfe1..0df58e32 100644 --- a/aeon/schema/social_02.py +++ b/aeon/schema/social_02.py @@ -1,4 +1,4 @@ -"""This module defines the schema for the social_02 dataset.""" +"""Schema definition for social_02 experiments-specific data streams.""" import aeon.io.reader as _reader from aeon.schema import core, foraging diff --git a/aeon/schema/social_03.py b/aeon/schema/social_03.py index 0f07e72c..5954d35b 100644 --- a/aeon/schema/social_03.py +++ b/aeon/schema/social_03.py @@ -1,4 +1,4 @@ -"""This module contains the schema for the social_03 dataset.""" +"""Schema definition for social_03 experiments-specific data streams.""" import aeon.io.reader as _reader from aeon.schema.streams import Stream diff --git a/aeon/schema/streams.py b/aeon/schema/streams.py index 0269a2f6..c29e2779 100644 --- a/aeon/schema/streams.py +++ b/aeon/schema/streams.py @@ -1,4 +1,4 @@ -"""Contains classes for defining data streams and devices.""" +"""Classes for defining data streams and devices.""" import inspect from itertools import chain