Commit 9002be9

Merge pull request #458 from ttngu207/datajoint_pipeline
add in_patch_rfid_timestamps in `BlockSubjectAnalysis`
ttngu207 authored Dec 13, 2024
2 parents bf4e56c + fef4465 commit 9002be9
Showing 2 changed files with 123 additions and 12 deletions.
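The change in brief: every patch has an RFID reader, so a subject's tag detections at that reader give an RFID-based record of patch occupancy, stored alongside the existing position-derived `in_patch_timestamps`. A minimal sketch of the idea with made-up data (the real pipeline fetches these streams from DataJoint tables; all names and values here are illustrative only):

```python
import pandas as pd

# Hypothetical RFID event stream for one patch: tag detections indexed by time.
rfid_df = pd.DataFrame(
    {"rfid": [982000111, 982000222, 982000111]},
    index=pd.to_datetime(
        ["2024-12-01 10:00:00", "2024-12-01 10:05:00", "2024-12-01 10:10:00"]
    ),
)

# Map RFID tag numbers (lab_id) to subject names, as rfid2subj_map does below.
rfid2subj_map = {982000111: "subject_a", 982000222: "subject_b"}
rfid_df["subject"] = rfid_df.rfid.map(rfid2subj_map)

# RFID-based in-patch timestamps for one subject at this patch.
in_patch_rfid_timestamps = rfid_df[rfid_df.subject == "subject_a"].index.values
```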
57 changes: 45 additions & 12 deletions aeon/dj_pipeline/analysis/block_analysis.py
@@ -21,7 +21,7 @@
gen_subject_colors_dict,
subject_colors,
)
from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, tracking
from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, subject, tracking
from aeon.dj_pipeline.analysis.visit import filter_out_maintenance_periods, get_maintenance_periods
from aeon.io import api as io_api

@@ -439,6 +439,7 @@ class Patch(dj.Part):
---
in_patch_timestamps: longblob # timestamps when a subject is at a specific patch
in_patch_time: float # total seconds spent in this patch for this block
in_patch_rfid_timestamps=null: longblob # in_patch_timestamps based on RFID
pellet_count: int
pellet_timestamps: longblob
patch_threshold: longblob # patch threshold value at each pellet delivery
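The `=null` default makes the new attribute nullable, so `Patch` rows computed before this change stay valid and can be backfilled later (see the script added below). A hedged sketch of reading it back, assuming a configured pipeline connection and a hypothetical restriction key:

```python
from aeon.dj_pipeline.analysis.block_analysis import BlockSubjectAnalysis

# Hypothetical restriction; any valid block/patch/subject key works here.
key = {"experiment_name": "social0.2-aeon4", "patch_name": "Patch1"}
rfid_ts = (BlockSubjectAnalysis.Patch & key).fetch("in_patch_rfid_timestamps")
# Rows not yet backfilled fetch as None instead of raising an error.
```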
@@ -464,10 +465,16 @@ class Preference(dj.Part):

def make(self, key):
"""Compute preference scores for each subject at each patch within a block."""
block_start, block_end = (Block & key).fetch1("block_start", "block_end")
chunk_restriction = acquisition.create_chunk_restriction(
key["experiment_name"], block_start, block_end
)

block_patches = (BlockAnalysis.Patch & key).fetch(as_dict=True)
block_subjects = (BlockAnalysis.Subject & key).fetch(as_dict=True)
subject_names = [s["subject_name"] for s in block_subjects]
patch_names = [p["patch_name"] for p in block_patches]

# Construct subject position dataframe
subjects_positions_df = pd.concat(
[
@@ -503,6 +510,7 @@ def make(self, key):
for p in block_patches:
p["wheel_timestamps"] = p["wheel_timestamps"][:min_wheel_len]
p["wheel_cumsum_distance_travelled"] = p["wheel_cumsum_distance_travelled"][:min_wheel_len]

self.insert1(key)

in_patch_radius = 130 # pixels
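For context, the existing `in_patch_timestamps` are position-based: a subject counts as in a patch while its tracked position lies within `in_patch_radius` (130 pixels) of the patch. The RFID timestamps added by this PR are an independent occupancy measure from the patch's reader. A toy illustration of the distance test (made-up coordinates, not the pipeline's tracking code):

```python
import numpy as np

in_patch_radius = 130  # pixels, as above

subject_xy = np.array([[100.0, 90.0], [400.0, 420.0]])  # hypothetical positions
patch_xy = np.array([120.0, 100.0])  # hypothetical patch location

dist_to_patch = np.linalg.norm(subject_xy - patch_xy, axis=1)
in_patch = dist_to_patch < in_patch_radius  # array([ True, False])
```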
@@ -518,6 +526,18 @@ def make(self, key):
p: {s: {a: pd.Series() for a in pref_attrs} for s in subject_names} for p in patch_names
}

# subject-rfid mapping
rfid2subj_map = {
int(lab_id): subj_name
for subj_name, lab_id in zip(
*(subject.SubjectDetail.proj("lab_id")
& f"subject in {tuple(list(subject_names) + [''])}").fetch(
"subject", "lab_id"
),
strict=False,
)
}

for patch in block_patches:
cum_wheel_dist = pd.Series(
index=patch["wheel_timestamps"],
@@ -601,6 +621,18 @@ def make(self, key):
# In patch time
in_patch = dist_to_patch_wheel_ts_id_df < in_patch_radius
dt = np.median(np.diff(cum_wheel_dist.index)).astype(int) / 1e9 # s

# In patch time from RFID
rfid_query = (
streams.RfidReader.proj(rfid_name="REPLACE(rfid_reader_name, 'Rfid', '')")
* streams.RfidReaderRfidEvents
& key
& {"rfid_name": patch["patch_name"]}
& chunk_restriction
)
rfid_df = fetch_stream(rfid_query)[block_start:block_end]
rfid_df["subject"] = rfid_df.rfid.map(rfid2subj_map)

# Fill in `all_subj_patch_pref`
for subject_name in subject_names:
all_subj_patch_pref_dict[patch["patch_name"]][subject_name]["cum_dist"] = (
@@ -623,6 +655,7 @@ def make(self, key):
"subject_name": subject_name,
"in_patch_timestamps": subject_in_patch[in_patch[subject_name]].index.values,
"in_patch_time": subject_in_patch_cum_time[-1],
"in_patch_rfid_timestamps": rfid_df[rfid_df.subject == subject_name].index.values,
"pellet_count": len(subj_pellets),
"pellet_timestamps": subj_pellets.index.values,
"patch_threshold": subj_patch_thresh,
@@ -1151,11 +1184,11 @@ def calculate_running_preference(group, pref_col, out_col):
pel_patches = [p for p in patch_names if "dummy" not in p.lower()] # exclude dummy patches
data = []
for patch in pel_patches:
for subject in subject_names:
for subject_name in subject_names:
data.append(
{
"patch_name": patch,
"subject_name": subject,
"subject_name": subject_name,
"time": wheel_ts[patch],
"weighted_dist": np.empty_like(wheel_ts[patch]),
}
@@ -1255,18 +1288,18 @@ def norm_inv_norm(group):
df = subj_wheel_pel_weighted_dist
# Iterate through patches and subjects to create plots
for i, patch in enumerate(pel_patches, start=1):
for j, subject in enumerate(subject_names, start=1):
for j, subject_name in enumerate(subject_names, start=1):
# Filter data for this patch and subject
times = df.loc[patch].loc[subject]["time"]
norm_values = df.loc[patch].loc[subject]["norm_value"]
wheel_prefs = df.loc[patch].loc[subject]["wheel_pref"]
times = df.loc[patch].loc[subject_name]["time"]
norm_values = df.loc[patch].loc[subject_name]["norm_value"]
wheel_prefs = df.loc[patch].loc[subject_name]["wheel_pref"]

# Add wheel_pref trace
weighted_patch_pref_fig.add_trace(
go.Scatter(
x=times,
y=wheel_prefs,
name=f"{subject} - wheel_pref",
name=f"{subject_name} - wheel_pref",
line={
"color": subject_colors[i - 1],
"dash": patch_linestyles_dict[patch],
@@ -1284,7 +1317,7 @@ def norm_inv_norm(group):
go.Scatter(
x=times,
y=norm_values,
name=f"{subject} - norm_value",
name=f"{subject_name} - norm_value",
line={
"color": subject_colors[i - 1],
"dash": patch_linestyles_dict[patch],
@@ -1814,8 +1847,8 @@ def get_foraging_bouts(
# - For the foraging bout end time, we need to account for the final pellet delivery time
# - Filter out events with < `min_pellets`
# - For final events, get: duration, n_pellets, cum_wheel_distance -> add to returned DF
for subject in subject_patch_data.index.unique("subject_name"):
cur_subject_data = subject_patch_data.xs(subject, level="subject_name")
for subject_name in subject_patch_data.index.unique("subject_name"):
cur_subject_data = subject_patch_data.xs(subject_name, level="subject_name")
n_pels = sum([arr.size for arr in cur_subject_data["pellet_timestamps"].values])
if n_pels < min_pellets:
continue
@@ -1897,7 +1930,7 @@ def get_foraging_bouts(
"end": bout_starts_ends[:, 1],
"n_pellets": bout_pellets,
"cum_wheel_dist": bout_cum_wheel_dist,
"subject": subject,
"subject": subject_name,
}
),
]
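Two idioms in the `rfid2subj_map` construction above deserve a note. DataJoint's `fetch("subject", "lab_id")` returns one array per attribute, so `zip(*...)` pairs them row-wise; and the empty string appended to `subject_names` keeps the SQL `IN` clause valid when a block has a single subject, since a one-element Python tuple renders with a trailing comma that SQL rejects. A standalone sketch (hypothetical subject and tag values):

```python
subject_names = ["BAA-0000001"]  # hypothetical single-subject block

# A 1-tuple renders with a trailing comma -> invalid SQL in an IN clause.
print(f"subject in {tuple(subject_names)}")               # subject in ('BAA-0000001',)
print(f"subject in {tuple(list(subject_names) + [''])}")  # subject in ('BAA-0000001', '')

# zip(*fetch(...)): fetch returns per-attribute arrays; zip pairs them row-wise.
subjects = ["BAA-0000001"]  # stand-in for the fetched "subject" array
lab_ids = ["982000111"]     # stand-in for the fetched "lab_id" array
rfid2subj_map = {int(lab_id): s for s, lab_id in zip(subjects, lab_ids, strict=False)}
```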
78 changes: 78 additions & 0 deletions aeon/dj_pipeline/scripts/update_in_patch_rfid_timestamps.py
@@ -0,0 +1,78 @@
"""Script to update in_patch_rfid_timestamps for all blocks that are missing it."""

import datajoint as dj

from aeon.dj_pipeline import acquisition, fetch_stream, streams, subject
from aeon.dj_pipeline.analysis.block_analysis import Block, BlockAnalysis, BlockSubjectAnalysis

logger = dj.logger


def update_in_patch_rfid_timestamps(block_key):
"""Update in_patch_rfid_timestamps for a given block_key.
Args:
block_key (dict): block key
"""
logger.info(f"Updating in_patch_rfid_timestamps for {block_key}")

block_key = (Block & block_key).fetch1("KEY")
block_start, block_end = (Block & block_key).fetch1("block_start", "block_end")
chunk_restriction = acquisition.create_chunk_restriction(
block_key["experiment_name"], block_start, block_end
)
patch_names = (BlockAnalysis.Patch & block_key).fetch("patch_name")
subject_names = (BlockAnalysis.Subject & block_key).fetch("subject_name")

rfid2subj_map = {
int(lab_id): subj_name
for subj_name, lab_id in zip(
*(subject.SubjectDetail.proj("lab_id")
& f"subject in {tuple(list(subject_names) + [''])}").fetch(
"subject", "lab_id"
),
strict=False,
)
}

entries = []
for patch_name in patch_names:
# In patch time from RFID
rfid_query = (
streams.RfidReader.proj(rfid_name="REPLACE(rfid_reader_name, 'Rfid', '')")
* streams.RfidReaderRfidEvents
& block_key
& {"rfid_name": patch_name}
& chunk_restriction
)
rfid_df = fetch_stream(rfid_query)[block_start:block_end]
rfid_df["subject"] = rfid_df.rfid.map(rfid2subj_map)

for subject_name in subject_names:
k = {
**block_key,
"patch_name": patch_name,
"subject_name": subject_name,
}
if BlockSubjectAnalysis.Patch & k:
entries.append(
{**k, "in_patch_rfid_timestamps": rfid_df[rfid_df.subject == subject_name].index.values}
)

with BlockSubjectAnalysis.connection.transaction:
for e in entries:
BlockSubjectAnalysis.Patch.update1(e)
logger.info(f"\tUpdated {len(entries)} entries.")


def main():
"""Update in_patch_rfid_timestamps for all blocks that are missing it."""
block_keys = (
BlockSubjectAnalysis
& (BlockSubjectAnalysis.Patch & "in_patch_rfid_timestamps IS NULL")
).fetch("KEY")
for block_key in block_keys:
try:
update_in_patch_rfid_timestamps(block_key)
except Exception as e:
logger.error(f"Error updating {block_key}: {e}")

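A hedged usage sketch for the script (assumes a configured DataJoint connection; the block key is hypothetical). Note that `update1()` only modifies existing rows, which is why the script checks `if BlockSubjectAnalysis.Patch & k` before queuing an entry, and the `connection.transaction` context makes each block's backfill all-or-nothing:

```python
from aeon.dj_pipeline.scripts.update_in_patch_rfid_timestamps import (
    main,
    update_in_patch_rfid_timestamps,
)

# Backfill every block whose Patch rows are missing in_patch_rfid_timestamps.
main()

# Or target a single block (hypothetical key):
update_in_patch_rfid_timestamps(
    {"experiment_name": "social0.2-aeon4", "block_start": "2024-06-01 10:00:00"}
)
```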