Merge pull request #432 from ttngu207/datajoint_pipeline

feat(block_analysis): add table for BlockForagingBout
jkbhagatio authored Oct 5, 2024
2 parents c9d0d4a + 24061c5 commit 5dacb45
Showing 1 changed file with 68 additions and 23 deletions.
91 changes: 68 additions & 23 deletions aeon/dj_pipeline/analysis/block_analysis.py
@@ -1496,6 +1496,44 @@ def make(self, key):
        self.insert1(entry)


# ---- Foraging Bout Analysis ----

@schema
class BlockForaging(dj.Computed):
    definition = """
    -> BlockSubjectAnalysis
    ---
    bout_count: int  # number of foraging bouts in the block
    """

    class Bout(dj.Part):
        definition = """
        -> master
        -> BlockAnalysis.Subject
        bout_start: datetime(6)
        ---
        bout_end: datetime(6)
        pellet_count: int  # number of pellets consumed during the bout
        cum_wheel_dist: float  # cumulative distance travelled during the bout
        """

    def make(self, key):
        foraging_bout_df = get_foraging_bouts(key)
        # Rename the columns returned by get_foraging_bouts to the attribute
        # names used in the Bout table definition
        foraging_bout_df.rename(
            columns={
                "subject": "subject_name",
                "start": "bout_start",
                "end": "bout_end",
                "n_pellets": "pellet_count",
            },
            inplace=True,
        )

        self.insert1({**key, "bout_count": len(foraging_bout_df)})
        self.Bout.insert({**key, **row} for _, row in foraging_bout_df.iterrows())
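
A minimal usage sketch for the new table (assuming an activated schema and a populated upstream BlockSubjectAnalysis; the experiment name below is a hypothetical example):

    # Compute foraging bouts for all blocks with pending analysis
    BlockForaging.populate(display_progress=True)

    # Fetch the per-bout rows for one experiment as a pandas DataFrame
    bouts_df = (BlockForaging.Bout & {"experiment_name": "social0.2-aeon3"}).fetch(format="frame")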


# ---- AnalysisNote ----


@@ -1511,7 +1549,6 @@ class AnalysisNote(dj.Manual):

# ---- Helper Functions ----


def get_threshold_associated_pellets(patch_key, start, end):
"""Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time.
@@ -1538,35 +1575,53 @@ def get_threshold_associated_pellets(patch_key, start, end):
"""
    chunk_restriction = acquisition.create_chunk_restriction(patch_key["experiment_name"], start, end)

    # Step 1 - fetch data
    # pellet delivery trigger
    delivered_pellet_df = fetch_stream(
        streams.UndergroundFeederDeliverPellet & patch_key & chunk_restriction
    )[start:end]
    # beambreak
    beambreak_df = fetch_stream(streams.UndergroundFeederBeamBreak & patch_key & chunk_restriction)[
        start:end
    ]
    # patch threshold
    depletion_state_df = fetch_stream(
        streams.UndergroundFeederDepletionState & patch_key & chunk_restriction
    )[start:end]
    # manual delivery
    manual_delivery_df = fetch_stream(
        streams.UndergroundFeederManualDelivery & patch_key & chunk_restriction
    )[start:end]

    # Return empty if no data
    if delivered_pellet_df.empty or beambreak_df.empty or depletion_state_df.empty:
        return acquisition.io_api._empty(
            ["threshold", "offset", "rate", "pellet_timestamp", "beam_break_timestamp"]
        )

    # Step 2 - Remove invalid rows (back-to-back events)
    # pellet delivery trigger - time difference is less than 1.2 seconds
    invalid_rows = delivered_pellet_df.index.to_series().diff().dt.total_seconds() < 1.2
    delivered_pellet_df = delivered_pellet_df[~invalid_rows]
    # exclude manual deliveries
    delivered_pellet_df = delivered_pellet_df.loc[
        delivered_pellet_df.index.difference(manual_delivery_df.index)
    ]
    # beambreak - time difference is less than 1 second
    invalid_rows = beambreak_df.index.to_series().diff().dt.total_seconds() < 1
    beambreak_df = beambreak_df[~invalid_rows]
    # patch threshold - time difference is less than 1 second
    depletion_state_df = depletion_state_df.dropna(subset=["threshold"])
    invalid_rows = depletion_state_df.index.to_series().diff().dt.total_seconds() < 1
    depletion_state_df = depletion_state_df[~invalid_rows]
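    # Illustration of the debounce rule above (hypothetical timestamps):
    #   ts = pd.Series(pd.to_datetime(["12:00:00.0", "12:00:00.5", "12:00:02.0"]))
    #   ts.diff().dt.total_seconds() < 1.2  ->  [False, True, False]
    # i.e. only the middle event, 0.5 s after its predecessor, is discarded
    # (the first diff is NaT, whose total_seconds() is NaN and compares False).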

    # Return empty if no data remains after cleaning
    if delivered_pellet_df.empty or beambreak_df.empty or depletion_state_df.empty:
        return acquisition.io_api._empty(
            ["threshold", "offset", "rate", "pellet_timestamp", "beam_break_timestamp"]
        )

    # Step 3 - event matching
    # Find pellet delivery triggers with matching beambreaks within 1.2s after each pellet delivery
    pellet_beam_break_df = (
        pd.merge_asof(
@@ -1582,16 +1637,6 @@ def get_threshold_associated_pellets(patch_key, start, end):
    )
    pellet_beam_break_df.drop_duplicates(subset="beam_break_timestamp", keep="last", inplace=True)
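    # Note: the merge_asof call above (arguments elided in this hunk) presumably
    # matches with direction="forward" and a 1.2 s tolerance, pairing each pellet
    # trigger with the first beambreak at most 1.2 s later; drop_duplicates then
    # keeps only the last trigger competing for the same beambreak.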

    # Find pellet delivery triggers that approximately coincide with each threshold update
    # i.e. nearest pellet delivery within 100ms before or after threshold update
    pellet_ts_threshold_df = (
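
A self-contained toy sketch of the Step 3 event-matching pattern (hypothetical timestamps; the actual merge_asof arguments are elided in the hunk above):

    import pandas as pd

    pellets = pd.DataFrame(
        {"pellet_timestamp": pd.to_datetime(["2024-10-05 12:00:00", "2024-10-05 12:00:05"])}
    )
    beambreaks = pd.DataFrame(
        {"beam_break_timestamp": pd.to_datetime(["2024-10-05 12:00:00.800"])}
    )

    # Pair each pellet trigger with the first beambreak within 1.2 s after it;
    # the second trigger finds no beambreak in its window and is left as NaT
    matched = pd.merge_asof(
        pellets,
        beambreaks,
        left_on="pellet_timestamp",
        right_on="beam_break_timestamp",
        tolerance=pd.Timedelta("1.2s"),
        direction="forward",
    )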
