

Merge pull request #438 from ttngu207/datajoint_pipeline
New `ingestion_schemas` + MANY minor fixes and improvements
jkbhagatio authored Nov 6, 2024
2 parents 5dacb45 + b0952eb commit c2e90b6
Showing 21 changed files with 996 additions and 255 deletions.
13 changes: 12 additions & 1 deletion aeon/dj_pipeline/__init__.py
@@ -1,4 +1,5 @@
import hashlib
import logging
import os
import uuid

@@ -30,11 +31,17 @@ def dict_to_uuid(key) -> uuid.UUID:
return uuid.UUID(hex=hashed.hexdigest())


def fetch_stream(query, drop_pk=True):
def fetch_stream(query, drop_pk=True, round_microseconds=True):
"""Fetches data from a Stream table based on a query and returns it as a DataFrame.
Provided a query containing data from a Stream table,
fetch and aggregate the data into one DataFrame indexed by "time"
Args:
query (datajoint.Query): A query object containing data from a Stream table
drop_pk (bool, optional): Drop primary key columns. Defaults to True.
round_microseconds (bool, optional): Round timestamps to microseconds. Defaults to True.
(this is important as timestamps in MySQL are only accurate to microseconds)
"""
df = (query & "sample_count > 0").fetch(format="frame").reset_index()
cols2explode = [
@@ -47,6 +54,10 @@ def fetch_stream(query, drop_pk=True):
df.set_index("time", inplace=True)
df.sort_index(inplace=True)
df = df.convert_dtypes(convert_string=False, convert_integer=False, convert_boolean=False, convert_floating=False)
if not df.empty and round_microseconds:
logging.warning("Rounding timestamps to microseconds is now enabled by default."
" To disable, set round_microseconds=False.")
df.index = df.index.round("us")
return df


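The new `round_microseconds` option pairs with MySQL's datetime precision, which tops out at microseconds. A minimal usage sketch, assuming a configured pipeline and the `acquisition` and `streams` modules referenced elsewhere in this diff (the experiment name and times are hypothetical):

from aeon.dj_pipeline import acquisition, fetch_stream, streams

# Hypothetical experiment and time window.
exp_key = {"experiment_name": "social0.2-aeon4"}
chunk_restriction = acquisition.create_chunk_restriction(
    exp_key["experiment_name"], "2024-02-01 10:00:00", "2024-02-01 14:00:00"
)
encoder_query = streams.UndergroundFeederEncoder & exp_key & chunk_restriction
encoder_df = fetch_stream(encoder_query)  # index rounded to microseconds (default)
raw_df = fetch_stream(encoder_query, round_microseconds=False)  # keep full-resolution timestamps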
10 changes: 7 additions & 3 deletions aeon/dj_pipeline/acquisition.py
@@ -12,7 +12,7 @@
from aeon.dj_pipeline.utils import paths
from aeon.io import api as io_api
from aeon.io import reader as io_reader
from aeon.schema import schemas as aeon_schemas
from aeon.schema import ingestion_schemas as aeon_schemas

logger = dj.logger
schema = dj.schema(get_schema_name("acquisition"))
@@ -646,10 +646,14 @@ def _match_experiment_directory(experiment_name, path, directories):

def create_chunk_restriction(experiment_name, start_time, end_time):
"""Create a time restriction string for the chunks between the specified "start" and "end" times."""
exp_key = {"experiment_name": experiment_name}
start_restriction = f'"{start_time}" BETWEEN chunk_start AND chunk_end'
end_restriction = f'"{end_time}" BETWEEN chunk_start AND chunk_end'
start_query = Chunk & {"experiment_name": experiment_name} & start_restriction
end_query = Chunk & {"experiment_name": experiment_name} & end_restriction
start_query = Chunk & exp_key & start_restriction
end_query = Chunk & exp_key & end_restriction
if not end_query:
# No chunk contains the end time, so we need to find the last chunk that starts before the end time
end_query = Chunk & exp_key & f'chunk_end BETWEEN "{start_time}" AND "{end_time}"'
if not (start_query and end_query):
raise ValueError(f"No Chunk found between {start_time} and {end_time}")
time_restriction = (
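For reference, the restrictions assembled above are plain SQL fragments that DataJoint applies to the `Chunk` table; with hypothetical times they look like this:

# Illustrative only - hypothetical start/end times.
start_time, end_time = "2024-02-01 10:00:00", "2024-02-01 14:00:00"
start_restriction = f'"{start_time}" BETWEEN chunk_start AND chunk_end'
# -> '"2024-02-01 10:00:00" BETWEEN chunk_start AND chunk_end'
fallback_end_restriction = f'chunk_end BETWEEN "{start_time}" AND "{end_time}"'
# -> 'chunk_end BETWEEN "2024-02-01 10:00:00" AND "2024-02-01 14:00:00"'
# The fallback applies when no ingested chunk spans end_time, e.g. the final chunk is still being acquired.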
158 changes: 102 additions & 56 deletions aeon/dj_pipeline/analysis/block_analysis.py
@@ -53,6 +53,8 @@ class BlockDetection(dj.Computed):
-> acquisition.Environment
"""

key_source = acquisition.Environment - {"experiment_name": "social0.1-aeon3"}

def make(self, key):
"""On a per-chunk basis, check for the presence of new block, insert into Block table.
@@ -88,8 +90,7 @@ def make(self, key):
blocks_df = block_state_df[block_state_df.pellet_ct == 0]
# account for the double 0s - find any 0s that are within 1 second of each other, remove the 2nd one
double_0s = blocks_df.index.to_series().diff().dt.total_seconds() < 1
# find the indices of the 2nd 0s and remove
double_0s = double_0s.shift(-1).fillna(False)
# keep the first 0s
blocks_df = blocks_df[~double_0s]

block_entries = []
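The double-0 filter above keeps only the first of any pair of zero-pellet states logged within one second of each other; a self-contained pandas sketch with made-up timestamps shows the idiom:

import pandas as pd

# Made-up block-state frame: the second 0 arrives 0.5 s after the first.
idx = pd.to_datetime(["2024-01-01 00:00:00.0", "2024-01-01 00:00:00.5", "2024-01-01 01:00:00.0"])
blocks_df = pd.DataFrame({"pellet_ct": [0, 0, 0]}, index=idx)
double_0s = blocks_df.index.to_series().diff().dt.total_seconds() < 1  # True only for the second of the pair
blocks_df = blocks_df[~double_0s]  # keeps 00:00:00.0 and 01:00:00.0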
@@ -144,8 +145,8 @@ class Patch(dj.Part):
wheel_timestamps: longblob
patch_threshold: longblob
patch_threshold_timestamps: longblob
patch_rate: float
patch_offset: float
patch_rate=null: float
patch_offset=null: float
"""

class Subject(dj.Part):
@@ -181,17 +182,27 @@ def make(self, key):
streams.UndergroundFeederDepletionState,
streams.UndergroundFeederDeliverPellet,
streams.UndergroundFeederEncoder,
tracking.SLEAPTracking,
)
for streams_table in streams_tables:
if len(streams_table & chunk_keys) < len(streams_table.key_source & chunk_keys):
raise ValueError(
f"BlockAnalysis Not Ready - {streams_table.__name__} not yet fully ingested for block: {key}. Skipping (to retry later)..."
)

# Check if SLEAPTracking is ready, if not, see if BlobPosition can be used instead
use_blob_position = False
if len(tracking.SLEAPTracking & chunk_keys) < len(tracking.SLEAPTracking.key_source & chunk_keys):
if len(tracking.BlobPosition & chunk_keys) < len(tracking.BlobPosition.key_source & chunk_keys):
raise ValueError(
f"BlockAnalysis Not Ready - SLEAPTracking (and BlobPosition) not yet fully ingested for block: {key}. Skipping (to retry later)..."
)
else:
use_blob_position = True

# Patch data - TriggerPellet, DepletionState, Encoder (distancetravelled)
# For wheel data, downsample to 10Hz
final_encoder_fs = 10
# For wheel data, downsample to 50Hz
final_encoder_hz = 50
freq = 1 / final_encoder_hz * 1e3 # in ms

maintenance_period = get_maintenance_periods(key["experiment_name"], block_start, block_end)

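The wheel stream is now resampled onto a fixed 50 Hz grid (1/50 s = 20 ms bins) rather than strided by an integer factor derived from the measured encoder rate. A toy sketch of the resample idiom on synthetic data:

import numpy as np
import pandas as pd

final_encoder_hz = 50
freq = 1 / final_encoder_hz * 1e3  # 20.0 ms per bin at 50 Hz

# Synthetic ~500 Hz encoder readings over one second, downsampled by keeping the
# first reading in each 20 ms bin (mirrors the .resample(...).first() call further down).
idx = pd.date_range("2024-01-01", periods=500, freq="2ms")
encoder_df = pd.DataFrame({"angle": np.linspace(0, 2 * np.pi, 500)}, index=idx)
encoder_df = encoder_df.resample(f"{int(freq)}ms").first()  # 50 rows, one per 20 ms bin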
@@ -233,51 +244,52 @@ def make(self, key):
encoder_df, maintenance_period, block_end, dropna=True
)

if depletion_state_df.empty:
raise ValueError(f"No depletion state data found for block {key} - patch: {patch_name}")

encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled(encoder_df.angle)

if len(depletion_state_df.rate.unique()) > 1:
# multiple patch rates per block is unexpected, log a note and pick the first rate to move forward
AnalysisNote.insert1(
{
"note_timestamp": datetime.utcnow(),
"note_type": "Multiple patch rates",
"note": f"Found multiple patch rates for block {key} - patch: {patch_name} - rates: {depletion_state_df.rate.unique()}",
}
)
# if all dataframes are empty, skip
if pellet_ts_threshold_df.empty and depletion_state_df.empty and encoder_df.empty:
continue

patch_rate = depletion_state_df.rate.iloc[0]
patch_offset = depletion_state_df.offset.iloc[0]
# handles patch rate value being INF
patch_rate = 999999999 if np.isinf(patch_rate) else patch_rate
if encoder_df.empty:
encoder_df["distance_travelled"] = 0
else:
# -1 is for placement of magnetic encoder, where wheel movement actually decreases encoder
encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled(encoder_df.angle)
encoder_df = encoder_df.resample(f"{freq}ms").first()

if not depletion_state_df.empty:
if len(depletion_state_df.rate.unique()) > 1:
# multiple patch rates per block is unexpected, log a note and pick the first rate to move forward
AnalysisNote.insert1(
{
"note_timestamp": datetime.utcnow(),
"note_type": "Multiple patch rates",
"note": f"Found multiple patch rates for block {key} - patch: {patch_name} - rates: {depletion_state_df.rate.unique()}",
}
)

encoder_fs = (
1 / encoder_df.index.to_series().diff().dt.total_seconds().median()
) # mean or median?
wheel_downsampling_factor = int(encoder_fs / final_encoder_fs)
patch_rate = depletion_state_df.rate.iloc[0]
patch_offset = depletion_state_df.offset.iloc[0]
# handles patch rate value being INF
patch_rate = 999999999 if np.isinf(patch_rate) else patch_rate
else:
logger.warning(f"No depletion state data found for block {key} - patch: {patch_name}")
patch_rate = None
patch_offset = None

block_patch_entries.append(
{
**key,
"patch_name": patch_name,
"pellet_count": len(pellet_ts_threshold_df),
"pellet_timestamps": pellet_ts_threshold_df.pellet_timestamp.values,
"wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values[
::wheel_downsampling_factor
],
"wheel_timestamps": encoder_df.index.values[::wheel_downsampling_factor],
"wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values,
"wheel_timestamps": encoder_df.index.values,
"patch_threshold": pellet_ts_threshold_df.threshold.values,
"patch_threshold_timestamps": pellet_ts_threshold_df.index.values,
"patch_rate": patch_rate,
"patch_offset": patch_offset,
}
)

# update block_end if last timestamp of encoder_df is before the current block_end
block_end = min(encoder_df.index[-1], block_end)

# Subject data
# Get all unique subjects that visited the environment over the entire exp;
# For each subject, see 'type' of visit most recent to start of block
@@ -288,27 +300,50 @@
& f'chunk_start <= "{chunk_keys[-1]["chunk_start"]}"'
)[:block_start]
subject_visits_df = subject_visits_df[subject_visits_df.region == "Environment"]
subject_visits_df = subject_visits_df[~subject_visits_df.id.str.contains("Test", case=False)]
subject_names = []
for subject_name in set(subject_visits_df.id):
_df = subject_visits_df[subject_visits_df.id == subject_name]
if _df.type.iloc[-1] != "Exit":
subject_names.append(subject_name)

if use_blob_position and len(subject_names) > 1:
raise ValueError(
f"Without SLEAPTracking, BlobPosition can only handle single-subject block. Found {len(subject_names)} subjects."
)

block_subject_entries = []
for subject_name in subject_names:
# positions - query for CameraTop, identity_name matches subject_name,
pos_query = (
streams.SpinnakerVideoSource
* tracking.SLEAPTracking.PoseIdentity.proj("identity_name", part_name="anchor_part")
* tracking.SLEAPTracking.Part
& key
& {
"spinnaker_video_source_name": "CameraTop",
"identity_name": subject_name,
}
& chunk_restriction
)
pos_df = fetch_stream(pos_query)[block_start:block_end]
if use_blob_position:
pos_query = (
streams.SpinnakerVideoSource
* tracking.BlobPosition.Object
& key
& chunk_restriction
& {
"spinnaker_video_source_name": "CameraTop",
"identity_name": subject_name
}
)
pos_df = fetch_stream(pos_query)[block_start:block_end]
pos_df["likelihood"] = np.nan
# keep only rows with area between 0 and 1000 - likely artifacts otherwise
pos_df = pos_df[(pos_df.area > 0) & (pos_df.area < 1000)]
else:
pos_query = (
streams.SpinnakerVideoSource
* tracking.SLEAPTracking.PoseIdentity.proj("identity_name", part_name="anchor_part")
* tracking.SLEAPTracking.Part
& key
& {
"spinnaker_video_source_name": "CameraTop",
"identity_name": subject_name,
}
& chunk_restriction
)
pos_df = fetch_stream(pos_query)[block_start:block_end]

pos_df = filter_out_maintenance_periods(pos_df, maintenance_period, block_end)

if pos_df.empty:
Expand Down Expand Up @@ -345,8 +380,8 @@ def make(self, key):
{
**key,
"block_duration": (block_end - block_start).total_seconds() / 3600,
"patch_count": len(patch_keys),
"subject_count": len(subject_names),
"patch_count": len(block_patch_entries),
"subject_count": len(block_subject_entries),
}
)
self.Patch.insert(block_patch_entries)
@@ -423,6 +458,17 @@ def make(self, key):
)
subjects_positions_df.set_index("position_timestamps", inplace=True)

# Ensure wheel_timestamps are of the same length across all patches
wheel_lens = [len(p["wheel_timestamps"]) for p in block_patches]
if len(set(wheel_lens)) > 1:
max_diff = max(wheel_lens) - min(wheel_lens)
if max_diff > 10:
# a difference of more than 10 samples is unexpected (did some patches crash?) - raise an error
raise ValueError(f"Wheel data lengths are not consistent across patches ({max_diff} samples diff)")
for p in block_patches:
p["wheel_timestamps"] = p["wheel_timestamps"][: min(wheel_lens)]
p["wheel_cumsum_distance_travelled"] = p["wheel_cumsum_distance_travelled"][: min(wheel_lens)]

self.insert1(key)

in_patch_radius = 130 # pixels
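The wheel-length check added above tolerates a few trailing samples of jitter between patches (up to 10) and trims every patch to the shortest wheel array so they stay comparable sample-for-sample; a concrete sketch with hypothetical lengths:

import numpy as np

# Hypothetical wheel arrays: the second patch ends 3 samples short of the first.
block_patches = [
    {"patch_name": "PatchA", "wheel_timestamps": np.arange(1000)},
    {"patch_name": "PatchB", "wheel_timestamps": np.arange(997)},
]
wheel_lens = [len(p["wheel_timestamps"]) for p in block_patches]
assert max(wheel_lens) - min(wheel_lens) <= 10  # larger gaps abort the block as a data fault
for p in block_patches:
    p["wheel_timestamps"] = p["wheel_timestamps"][: min(wheel_lens)]  # both now 997 samples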
@@ -541,7 +587,7 @@ def make(self, key):
| {
"patch_name": patch["patch_name"],
"subject_name": subject_name,
"in_patch_timestamps": subject_in_patch.index.values,
"in_patch_timestamps": subject_in_patch[in_patch[subject_name]].index.values,
"in_patch_time": subject_in_patch_cum_time[-1],
"pellet_count": len(subj_pellets),
"pellet_timestamps": subj_pellets.index.values,
@@ -1521,10 +1567,10 @@ def make(self, key):
foraging_bout_df = get_foraging_bouts(key)
foraging_bout_df.rename(
columns={
"subject_name": "subject",
"bout_start": "start",
"bout_end": "end",
"pellet_count": "n_pellets",
"subject": "subject_name",
"start": "bout_start",
"end": "bout_end",
"n_pellets": "pellet_count",
"cum_wheel_dist": "cum_wheel_dist",
},
inplace=True,
@@ -1540,7 +1586,7 @@
@schema
class AnalysisNote(dj.Manual):
definition = """ # Generic table to catch all notes generated during analysis
note_timestamp: datetime
note_timestamp: datetime(6)
---
note_type='': varchar(64)
note: varchar(3000)
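Plain MySQL `datetime` columns truncate to whole seconds; `datetime(6)` keeps six fractional digits, so note timestamps now retain microseconds, consistent with the microsecond rounding added to `fetch_stream`. A sketch of an insert (the note contents are made up):

from datetime import datetime

# Hypothetical note; with datetime(6) the microseconds survive the round-trip.
AnalysisNote.insert1(
    {
        "note_timestamp": datetime.utcnow(),  # e.g. 2024-11-06 12:34:56.789012
        "note_type": "Multiple patch rates",
        "note": "example note recorded during block analysis",
    }
)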
24 changes: 22 additions & 2 deletions aeon/dj_pipeline/populate/worker.py
@@ -99,12 +99,32 @@ def ingest_epochs_chunks():
)

analysis_worker(block_analysis.BlockAnalysis, max_calls=6)
analysis_worker(block_analysis.BlockPlots, max_calls=6)
analysis_worker(block_analysis.BlockSubjectAnalysis, max_calls=6)
analysis_worker(block_analysis.BlockSubjectPlots, max_calls=6)
analysis_worker(block_analysis.BlockForaging, max_calls=6)
analysis_worker(block_analysis.BlockPatchPlots, max_calls=6)
analysis_worker(block_analysis.BlockSubjectPositionPlots, max_calls=6)


def get_workflow_operation_overview():
from datajoint_utilities.dj_worker.utils import get_workflow_operation_overview

return get_workflow_operation_overview(worker_schema_name=worker_schema_name, db_prefixes=[db_prefix])


def retrieve_schemas_sizes(schema_only=False, all_schemas=False):
schema_names = [n for n in dj.list_schemas() if n != "mysql"]
if not all_schemas:
schema_names = [n for n in schema_names
if n.startswith(db_prefix) and not n.startswith(f"{db_prefix}archived")]

if schema_only:
return {n: dj.Schema(n).size_on_disk / 1e9 for n in schema_names}

schema_sizes = {n: {} for n in schema_names}
for n in schema_names:
vm = dj.VirtualModule(n, n)
schema_sizes[n]["schema_gb"] = vm.schema.size_on_disk / 1e9
schema_sizes[n]["tables_gb"] = {n: t().size_on_disk / 1e9
for n, t in vm.__dict__.items()
if isinstance(t, dj.user_tables.TableMeta)}
return schema_sizes
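A usage sketch for the new size helper, assuming a configured DataJoint connection; the schema name and sizes in the comments are illustrative only:

from aeon.dj_pipeline.populate.worker import retrieve_schemas_sizes

per_schema_gb = retrieve_schemas_sizes(schema_only=True)  # {schema_name: size_in_GB}
full_breakdown = retrieve_schemas_sizes()  # adds a per-table breakdown
# e.g. {"aeon_acquisition": {"schema_gb": 12.3, "tables_gb": {"Chunk": 0.4, ...}}}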