From bcdd1605c34ffbe94510f67fa25c68dac77e25dd Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 12 Dec 2024 11:49:13 -0600 Subject: [PATCH 1/5] feat(block_analysis): add `in_patch_rfid_timestamps` --- aeon/dj_pipeline/analysis/block_analysis.py | 34 ++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 4267f883..f0caab31 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -21,7 +21,7 @@ gen_subject_colors_dict, subject_colors, ) -from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, tracking +from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, subject, tracking from aeon.dj_pipeline.analysis.visit import filter_out_maintenance_periods, get_maintenance_periods from aeon.io import api as io_api @@ -439,6 +439,7 @@ class Patch(dj.Part): --- in_patch_timestamps: longblob # timestamps when a subject is at a specific patch in_patch_time: float # total seconds spent in this patch for this block + in_patch_rfid_timestamps=null: longblob # timestamps when a subject is at a specific patch based on RFID pellet_count: int pellet_timestamps: longblob patch_threshold: longblob # patch threshold value at each pellet delivery @@ -464,10 +465,16 @@ class Preference(dj.Part): def make(self, key): """Compute preference scores for each subject at each patch within a block.""" + block_start, block_end = (Block & key).fetch1("block_start", "block_end") + chunk_restriction = acquisition.create_chunk_restriction( + key["experiment_name"], block_start, block_end + ) + block_patches = (BlockAnalysis.Patch & key).fetch(as_dict=True) block_subjects = (BlockAnalysis.Subject & key).fetch(as_dict=True) subject_names = [s["subject_name"] for s in block_subjects] patch_names = [p["patch_name"] for p in block_patches] + # Construct subject position dataframe subjects_positions_df = pd.concat( [ @@ -503,6 +510,7 @@ def make(self, key): for p in block_patches: p["wheel_timestamps"] = p["wheel_timestamps"][:min_wheel_len] p["wheel_cumsum_distance_travelled"] = p["wheel_cumsum_distance_travelled"][:min_wheel_len] + self.insert1(key) in_patch_radius = 130 # pixels @@ -518,6 +526,17 @@ def make(self, key): p: {s: {a: pd.Series() for a in pref_attrs} for s in subject_names} for p in patch_names } + # subject-rfid mapping + rfid2subj_map = { + int(l): s + for s, l in zip( + *(subject.SubjectDetail.proj("lab_id") & f"subject in {tuple(subject_names)}").fetch( + "subject", "lab_id" + ), + strict=False, + ) + } + for patch in block_patches: cum_wheel_dist = pd.Series( index=patch["wheel_timestamps"], @@ -601,6 +620,18 @@ def make(self, key): # In patch time in_patch = dist_to_patch_wheel_ts_id_df < in_patch_radius dt = np.median(np.diff(cum_wheel_dist.index)).astype(int) / 1e9 # s + + # In patch time from RFID + rfid_query = ( + streams.RfidReader.proj(rfid_name="REPLACE(rfid_reader_name, 'Rfid', '')") + * streams.RfidReaderRfidEvents + & key + & {"rfid_name": patch["patch_name"]} + & chunk_restriction + ) + rfid_df = fetch_stream(rfid_query)[block_start:block_end] + rfid_df["subject"] = rfid_df.rfid.map(rfid2subj_map) + # Fill in `all_subj_patch_pref` for subject_name in subject_names: all_subj_patch_pref_dict[patch["patch_name"]][subject_name]["cum_dist"] = ( @@ -623,6 +654,7 @@ def make(self, key): "subject_name": subject_name, "in_patch_timestamps": 
subject_in_patch[in_patch[subject_name]].index.values, "in_patch_time": subject_in_patch_cum_time[-1], + "in_patch_rfid_timestamps": rfid_df[rfid_df.subject == subject_name].index.values, "pellet_count": len(subj_pellets), "pellet_timestamps": subj_pellets.index.values, "patch_threshold": subj_patch_thresh, From 95345ff23ed4b4cf3745ef819556ae1232e64e4d Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 12 Dec 2024 11:51:57 -0600 Subject: [PATCH 2/5] feat(block_analysis): script to add `in_patch_rfid_timestamps` with `update1` --- .../update_in_patch_rfid_timestamps.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py diff --git a/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py b/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py new file mode 100644 index 00000000..39a42ba4 --- /dev/null +++ b/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py @@ -0,0 +1,63 @@ +from aeon.dj_pipeline.analysis.block_analysis import * + +logger = dj.logger + + +def update_in_patch_rfid_timestamps(block_key): + logger.info(f"Updating in_patch_rfid_timestamps for {block_key}") + + block_start, block_end = (Block & block_key).fetch1("block_start", "block_end") + chunk_restriction = acquisition.create_chunk_restriction( + block_key["experiment_name"], block_start, block_end + ) + patch_names = (BlockAnalysis.Patch & block_key).fetch("patch_name") + subject_names = (BlockAnalysis.Subject & block_key).fetch("subject_name") + + rfid2subj_map = { + int(l): s + for s, l in zip( + *(subject.SubjectDetail.proj("lab_id") & f"subject in {tuple(subject_names)}").fetch( + "subject", "lab_id" + ), + strict=False, + ) + } + + entries = [] + for patch_name in patch_names: + # In patch time from RFID + rfid_query = ( + streams.RfidReader.proj(rfid_name="REPLACE(rfid_reader_name, 'Rfid', '')") + * streams.RfidReaderRfidEvents + & block_key + & {"rfid_name": patch_name} + & chunk_restriction + ) + rfid_df = fetch_stream(rfid_query)[block_start:block_end] + rfid_df["subject"] = rfid_df.rfid.map(rfid2subj_map) + + for subject_name in subject_names: + k = { + **block_key, + "patch_name": patch_name, + "subject_name": subject_name, + } + if BlockSubjectAnalysis.Patch & k: + entries.append( + {**k, "in_patch_rfid_timestamps": rfid_df[rfid_df.subject == subject_name].index.values} + ) + + with BlockSubjectAnalysis.connection.transaction: + for e in entries: + BlockSubjectAnalysis.Patch.update1(e) + + +def main(): + block_keys = BlockSubjectAnalysis & ( + BlockSubjectAnalysis.Patch & "in_patch_rfid_timestamps IS NULL" + ).fetch("KEY") + for block_key in block_keys: + try: + update_in_patch_rfid_timestamps(block_key) + except Exception as e: + logger.error(f"Error updating {block_key}: {e}") From 642589d0a01e4f0a67a98d4ce208deda05739032 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 12 Dec 2024 11:57:31 -0600 Subject: [PATCH 3/5] chore: minor logging updates --- aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py b/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py index 39a42ba4..c7dbe6d4 100644 --- a/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py +++ b/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py @@ -6,6 +6,7 @@ def update_in_patch_rfid_timestamps(block_key): logger.info(f"Updating in_patch_rfid_timestamps for {block_key}") + block_key = (Block & 
block_key).fetch1("KEY") block_start, block_end = (Block & block_key).fetch1("block_start", "block_end") chunk_restriction = acquisition.create_chunk_restriction( block_key["experiment_name"], block_start, block_end @@ -50,6 +51,7 @@ def update_in_patch_rfid_timestamps(block_key): with BlockSubjectAnalysis.connection.transaction: for e in entries: BlockSubjectAnalysis.Patch.update1(e) + logger.info(f"\tUpdated {len(entries)} entries.") def main(): From 09106955e2086d8df7c9e7d50ec301f010010c3c Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Thu, 12 Dec 2024 12:05:11 -0600 Subject: [PATCH 4/5] chore: ruff --- aeon/dj_pipeline/analysis/block_analysis.py | 28 +++++++++---------- .../update_in_patch_rfid_timestamps.py | 17 +++++++++-- 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index f0caab31..0a888cf6 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -439,7 +439,7 @@ class Patch(dj.Part): --- in_patch_timestamps: longblob # timestamps when a subject is at a specific patch in_patch_time: float # total seconds spent in this patch for this block - in_patch_rfid_timestamps=null: longblob # timestamps when a subject is at a specific patch based on RFID + in_patch_rfid_timestamps=null: longblob # in_patch_timestamps based on RFID pellet_count: int pellet_timestamps: longblob patch_threshold: longblob # patch threshold value at each pellet delivery @@ -528,8 +528,8 @@ def make(self, key): # subject-rfid mapping rfid2subj_map = { - int(l): s - for s, l in zip( + int(lab_id): subj_name + for subj_name, lab_id in zip( *(subject.SubjectDetail.proj("lab_id") & f"subject in {tuple(subject_names)}").fetch( "subject", "lab_id" ), @@ -1183,11 +1183,11 @@ def calculate_running_preference(group, pref_col, out_col): pel_patches = [p for p in patch_names if "dummy" not in p.lower()] # exclude dummy patches data = [] for patch in pel_patches: - for subject in subject_names: + for subject_name in subject_names: data.append( { "patch_name": patch, - "subject_name": subject, + "subject_name": subject_name, "time": wheel_ts[patch], "weighted_dist": np.empty_like(wheel_ts[patch]), } @@ -1287,18 +1287,18 @@ def norm_inv_norm(group): df = subj_wheel_pel_weighted_dist # Iterate through patches and subjects to create plots for i, patch in enumerate(pel_patches, start=1): - for j, subject in enumerate(subject_names, start=1): + for j, subject_name in enumerate(subject_names, start=1): # Filter data for this patch and subject - times = df.loc[patch].loc[subject]["time"] - norm_values = df.loc[patch].loc[subject]["norm_value"] - wheel_prefs = df.loc[patch].loc[subject]["wheel_pref"] + times = df.loc[patch].loc[subject_name]["time"] + norm_values = df.loc[patch].loc[subject_name]["norm_value"] + wheel_prefs = df.loc[patch].loc[subject_name]["wheel_pref"] # Add wheel_pref trace weighted_patch_pref_fig.add_trace( go.Scatter( x=times, y=wheel_prefs, - name=f"{subject} - wheel_pref", + name=f"{subject_name} - wheel_pref", line={ "color": subject_colors[i - 1], "dash": patch_linestyles_dict[patch], @@ -1316,7 +1316,7 @@ def norm_inv_norm(group): go.Scatter( x=times, y=norm_values, - name=f"{subject} - norm_value", + name=f"{subject_name} - norm_value", line={ "color": subject_colors[i - 1], "dash": patch_linestyles_dict[patch], @@ -1846,8 +1846,8 @@ def get_foraging_bouts( # - For the foraging bout end time, we need to account for the final pellet delivery time # 
- Filter out events with < `min_pellets` # - For final events, get: duration, n_pellets, cum_wheel_distance -> add to returned DF - for subject in subject_patch_data.index.unique("subject_name"): - cur_subject_data = subject_patch_data.xs(subject, level="subject_name") + for subject_name in subject_patch_data.index.unique("subject_name"): + cur_subject_data = subject_patch_data.xs(subject_name, level="subject_name") n_pels = sum([arr.size for arr in cur_subject_data["pellet_timestamps"].values]) if n_pels < min_pellets: continue @@ -1929,7 +1929,7 @@ def get_foraging_bouts( "end": bout_starts_ends[:, 1], "n_pellets": bout_pellets, "cum_wheel_dist": bout_cum_wheel_dist, - "subject": subject, + "subject": subject_name, } ), ] diff --git a/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py b/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py index c7dbe6d4..a09a650d 100644 --- a/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py +++ b/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py @@ -1,9 +1,19 @@ -from aeon.dj_pipeline.analysis.block_analysis import * +"""Script to update in_patch_rfid_timestamps for all blocks that are missing it.""" + +import datajoint as dj + +from aeon.dj_pipeline import acquisition, fetch_stream, streams, subject +from aeon.dj_pipeline.analysis.block_analysis import Block, BlockAnalysis, BlockSubjectAnalysis logger = dj.logger def update_in_patch_rfid_timestamps(block_key): + """Update in_patch_rfid_timestamps for a given block_key. + + Args: + block_key (dict): block key + """ logger.info(f"Updating in_patch_rfid_timestamps for {block_key}") block_key = (Block & block_key).fetch1("KEY") @@ -15,8 +25,8 @@ def update_in_patch_rfid_timestamps(block_key): subject_names = (BlockAnalysis.Subject & block_key).fetch("subject_name") rfid2subj_map = { - int(l): s - for s, l in zip( + int(lab_id): subj_name + for subj_name, lab_id in zip( *(subject.SubjectDetail.proj("lab_id") & f"subject in {tuple(subject_names)}").fetch( "subject", "lab_id" ), @@ -55,6 +65,7 @@ def update_in_patch_rfid_timestamps(block_key): def main(): + """Update in_patch_rfid_timestamps for all blocks that are missing it.""" block_keys = BlockSubjectAnalysis & ( BlockSubjectAnalysis.Patch & "in_patch_rfid_timestamps IS NULL" ).fetch("KEY") From fef4465d3b85cc01178736b562ba4f22828de673 Mon Sep 17 00:00:00 2001 From: Thinh Nguyen Date: Fri, 13 Dec 2024 08:38:07 -0600 Subject: [PATCH 5/5] fix(analysis): safeguard query when there's only one subject --- aeon/dj_pipeline/analysis/block_analysis.py | 3 ++- .../scripts/update_in_patch_rfid_timestamps.py | 8 +++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py index 0a888cf6..a7dfd27c 100644 --- a/aeon/dj_pipeline/analysis/block_analysis.py +++ b/aeon/dj_pipeline/analysis/block_analysis.py @@ -530,7 +530,8 @@ def make(self, key): rfid2subj_map = { int(lab_id): subj_name for subj_name, lab_id in zip( - *(subject.SubjectDetail.proj("lab_id") & f"subject in {tuple(subject_names)}").fetch( + *(subject.SubjectDetail.proj("lab_id") + & f"subject in {tuple(list(subject_names) + [''])}").fetch( "subject", "lab_id" ), strict=False, diff --git a/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py b/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py index a09a650d..e617ba14 100644 --- a/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py +++ b/aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py @@ 
-27,7 +27,8 @@ def update_in_patch_rfid_timestamps(block_key): rfid2subj_map = { int(lab_id): subj_name for subj_name, lab_id in zip( - *(subject.SubjectDetail.proj("lab_id") & f"subject in {tuple(subject_names)}").fetch( + *(subject.SubjectDetail.proj("lab_id") + & f"subject in {tuple(list(subject_names) + [''])}").fetch( "subject", "lab_id" ), strict=False, @@ -66,8 +67,9 @@ def update_in_patch_rfid_timestamps(block_key): def main(): """Update in_patch_rfid_timestamps for all blocks that are missing it.""" - block_keys = BlockSubjectAnalysis & ( - BlockSubjectAnalysis.Patch & "in_patch_rfid_timestamps IS NULL" + block_keys = ( + BlockSubjectAnalysis + & (BlockSubjectAnalysis.Patch & "in_patch_rfid_timestamps IS NULL") ).fetch("KEY") for block_key in block_keys: try:
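
A minimal usage sketch of the backfill flow introduced by these patches, assuming a configured DataJoint connection to the aeon pipeline database. The module paths, `BlockSubjectAnalysis.Patch`, and the `main()` entry point come from the patches above; running them from an ad-hoc driver and the follow-up count query are illustrative, not part of the committed scripts.

    # Hypothetical driver: backfill the new nullable column, then spot-check coverage.
    # Assumes dj.config already points at the aeon pipeline database.
    from aeon.dj_pipeline.analysis.block_analysis import BlockSubjectAnalysis
    from aeon.dj_pipeline.scripts.update_in_patch_rfid_timestamps import main

    # Run the backfill: iterates over blocks whose Patch rows still have
    # in_patch_rfid_timestamps = NULL and fills them in via update1().
    main()

    # Spot-check: how many Patch rows are still missing the RFID-based timestamps?
    n_missing = len(BlockSubjectAnalysis.Patch & "in_patch_rfid_timestamps IS NULL")
    print(f"{n_missing} BlockSubjectAnalysis.Patch rows still missing in_patch_rfid_timestamps")

Declaring the attribute as `in_patch_rfid_timestamps=null` is what makes this retrofit possible: existing `BlockSubjectAnalysis.Patch` rows stay valid, and the script can fill the column with `update1()` inside a transaction instead of deleting and repopulating the computed table.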