Skip to content

Commit

Permalink
feat(block_analysis): improve BlockDetection logic to better track ne…
Browse files Browse the repository at this point in the history
…wly identified blocks
  • Loading branch information
ttngu207 committed Nov 21, 2024
1 parent e73c099 commit fa53e23
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions aeon/dj_pipeline/analysis/block_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,18 @@ class Block(dj.Manual):

@schema
class BlockDetection(dj.Computed):
definition = """
definition = """ # Detecting new block(s) for each new Chunk
-> acquisition.Environment
---
execution_time=null: datetime
"""

class IdentifiedBlock(dj.Part):
definition = """ # the block(s) identified in this BlockDetection
-> master
-> Block
"""

key_source = acquisition.Environment - {"experiment_name": "social0.1-aeon3"}

def make(self, key):
Expand All @@ -70,12 +78,9 @@ def make(self, key):
block_state_query = acquisition.Environment.BlockState & exp_key & chunk_restriction
block_state_df = fetch_stream(block_state_query)
if block_state_df.empty:
self.insert1(key)
# self.insert1(key)
return

block_state_df.index = block_state_df.index.round(
"us"
) # timestamp precision in DJ is only at microseconds
block_state_df = block_state_df.loc[
(block_state_df.index > chunk_start) & (block_state_df.index <= chunk_end)
]
Expand Down Expand Up @@ -103,7 +108,10 @@ def make(self, key):
)

Block.insert(block_entries, skip_duplicates=True)
self.insert1(key)
# self.insert1({**key, "execution_time": datetime.now(UTC)})
self.IdentifiedBlock.insert(
{**key, "block_start": entry["block_start"]} for entry in block_entries
)


# ---- Block Analysis and Visualization ----
Expand Down Expand Up @@ -316,6 +324,15 @@ def make(self, key):
if _df.type.iloc[-1] != "Exit":
subject_names.append(subject_name)

# Check for ExperimentTimeline to validate subjects in this block
timeline_query = (acquisition.ExperimentTimeline
& acquisition.ExperimentTimeline.Subject
& key
& f"start <= '{block_start}' AND end >= '{block_end}'")
timeline_subjects = (acquisition.ExperimentTimeline.Subject & timeline_query).fetch("subject")
if len(timeline_subjects):
subject_names = [s for s in subject_names if s in timeline_subjects]

if use_blob_position and len(subject_names) > 1:
raise ValueError(
f"Without SLEAPTracking, BlobPosition can only handle a single-subject block. "
Expand Down

0 comments on commit fa53e23

Please sign in to comment.