Merge remote-tracking branch 'upstream/datajoint_pipeline' into dev_queriable_in_roi_time
ttngu207 committed Nov 20, 2024
2 parents 507a0bc + 9be1f8e commit 28e344c
Showing 30 changed files with 197 additions and 170 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -24,4 +24,7 @@ dj_local_conf.json
log*.txt
scratch/
scratch*.py
**/*.nfs*
**/*.nfs*

# Test
.coverage
4 changes: 3 additions & 1 deletion aeon/__init__.py
@@ -12,4 +12,6 @@
del version, PackageNotFoundError

# Set functions available directly under the 'aeon' top-level namespace
from aeon.io.api import load as load # noqa: PLC0414
from aeon.io.api import load

__all__ = ["load"]
7 changes: 6 additions & 1 deletion aeon/dj_pipeline/__init__.py
@@ -57,7 +57,12 @@ def fetch_stream(query, drop_pk=True, round_microseconds=True):
df.rename(columns={"timestamps": "time"}, inplace=True)
df.set_index("time", inplace=True)
df.sort_index(inplace=True)
df = df.convert_dtypes(convert_string=False, convert_integer=False, convert_boolean=False, convert_floating=False)
df = df.convert_dtypes(
convert_string=False,
convert_integer=False,
convert_boolean=False,
convert_floating=False
)
if not df.empty and round_microseconds:
logging.warning("Rounding timestamps to microseconds is now enabled by default."
" To disable, set round_microseconds=False.")
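A minimal usage sketch of the updated fetch_stream behaviour, assuming a DataJoint stream query whose rows carry a `timestamps` attribute (the table name and experiment name below are illustrative, not part of this diff):

from aeon.dj_pipeline import fetch_stream, streams

# Hypothetical stream query; any query with a `timestamps` attribute applies.
query = streams.UndergroundFeederDepletionState & {"experiment_name": "social0.2-aeon3"}

# With the defaults, fetch_stream now warns that timestamps are rounded to microseconds;
# passing round_microseconds=False keeps the raw timestamps and silences the warning.
df = fetch_stream(query, round_microseconds=False)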
61 changes: 32 additions & 29 deletions aeon/dj_pipeline/analysis/block_analysis.py
@@ -3,7 +3,7 @@
import itertools
import json
from collections import defaultdict
from datetime import datetime, timezone
from datetime import UTC, datetime

import datajoint as dj
import numpy as np
@@ -21,17 +21,8 @@
gen_subject_colors_dict,
subject_colors,
)
from aeon.dj_pipeline import (
acquisition,
fetch_stream,
get_schema_name,
streams,
tracking,
)
from aeon.dj_pipeline.analysis.visit import (
filter_out_maintenance_periods,
get_maintenance_periods,
)
from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name, streams, tracking
from aeon.dj_pipeline.analysis.visit import filter_out_maintenance_periods, get_maintenance_periods
from aeon.io import api as io_api

schema = dj.schema(get_schema_name("block_analysis"))
@@ -205,7 +196,9 @@ def make(self, key):
if len(tracking.SLEAPTracking & chunk_keys) < len(tracking.SLEAPTracking.key_source & chunk_keys):
if len(tracking.BlobPosition & chunk_keys) < len(tracking.BlobPosition.key_source & chunk_keys):
raise ValueError(
f"BlockAnalysis Not Ready - SLEAPTracking (and BlobPosition) not yet fully ingested for block: {key}. Skipping (to retry later)..."
"BlockAnalysis Not Ready - "
f"SLEAPTracking (and BlobPosition) not yet fully ingested for block: {key}. "
"Skipping (to retry later)..."
)
else:
use_blob_position = True
@@ -272,7 +265,7 @@ def make(self, key):
# log a note and pick the first rate to move forward
AnalysisNote.insert1(
{
"note_timestamp": datetime.now(timezone.utc),
"note_timestamp": datetime.now(UTC),
"note_type": "Multiple patch rates",
"note": (
f"Found multiple patch rates for block {key} "
@@ -325,7 +318,8 @@ def make(self, key):

if use_blob_position and len(subject_names) > 1:
raise ValueError(
f"Without SLEAPTracking, BlobPosition can only handle single-subject block. Found {len(subject_names)} subjects."
f"Without SLEAPTracking, BlobPosition can only handle a single-subject block. "
f"Found {len(subject_names)} subjects."
)

block_subject_entries = []
@@ -345,7 +339,9 @@ def make(self, key):
pos_df = fetch_stream(pos_query)[block_start:block_end]
pos_df["likelihood"] = np.nan
# keep only rows with area between 0 and 1000 - likely artifacts otherwise
pos_df = pos_df[(pos_df.area > 0) & (pos_df.area < 1000)]
MIN_AREA = 0
MAX_AREA = 1000
pos_df = pos_df[(pos_df.area > MIN_AREA) & (pos_df.area < MAX_AREA)]
else:
pos_query = (
streams.SpinnakerVideoSource
@@ -477,15 +473,19 @@ def make(self, key):

# Ensure wheel_timestamps are of the same length across all patches
wheel_lens = [len(p["wheel_timestamps"]) for p in block_patches]
MAX_WHEEL_DIFF = 10

if len(set(wheel_lens)) > 1:
max_diff = max(wheel_lens) - min(wheel_lens)
if max_diff > 10:
if max_diff > MAX_WHEEL_DIFF:
# A difference of more than MAX_WHEEL_DIFF samples is unexpected (possibly a patch crash); raise an error
raise ValueError(f"Wheel data lengths are not consistent across patches ({max_diff} samples diff)")
raise ValueError(
f"Inconsistent wheel data lengths across patches ({max_diff} samples diff)"
)
min_wheel_len = min(wheel_lens)
for p in block_patches:
p["wheel_timestamps"] = p["wheel_timestamps"][: min(wheel_lens)]
p["wheel_cumsum_distance_travelled"] = p["wheel_cumsum_distance_travelled"][: min(wheel_lens)]

p["wheel_timestamps"] = p["wheel_timestamps"][:min_wheel_len]
p["wheel_cumsum_distance_travelled"] = p["wheel_cumsum_distance_travelled"][:min_wheel_len]
self.insert1(key)

in_patch_radius = 130 # pixels
@@ -1637,18 +1637,20 @@ class AnalysisNote(dj.Manual):


def get_threshold_associated_pellets(patch_key, start, end):
"""Retrieve the pellet delivery timestamps associated with each patch threshold update within the specified start-end time.
"""Gets pellet delivery timestamps for each patch threshold update within the specified time range.
1. Get all patch state update timestamps (DepletionState): let's call these events "A"
- Remove all events within 1 second of each other
- Remove all events without threshold value (NaN)
- Remove all events within 1 second of each other
- Remove all events without threshold value (NaN)
2. Get all pellet delivery timestamps (DeliverPellet): let's call these events "B"
- Find matching beam break timestamps within 1.2s after each pellet delivery
- Find matching beam break timestamps within 1.2s after each pellet delivery
3. For each event "A", find the nearest event "B" within 100ms before or after the event "A"
- These are the pellet delivery events "B" associated with the previous threshold update
event "A"
- These are the pellet delivery events "B" associated with the previous threshold update event "A"
4. Shift back the pellet delivery timestamps by 1 to match the pellet delivery with the
previous threshold update
previous threshold update
5. Remove all threshold updates events "A" without a corresponding pellet delivery event "B"
Args:
@@ -1658,12 +1660,13 @@ def get_threshold_associated_pellets(patch_key, start, end):
Returns:
pd.DataFrame: DataFrame with the following columns:
- threshold_update_timestamp (index)
- pellet_timestamp
- beam_break_timestamp
- offset
- rate
""" # noqa 501
"""
chunk_restriction = acquisition.create_chunk_restriction(patch_key["experiment_name"], start, end)

# Step 1 - fetch data
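Steps 1-5 in the docstring above can be sketched with pandas.merge_asof; the frames, column names, and shift direction below are illustrative stand-ins, not the pipeline's actual implementation:

import pandas as pd


def match_pellets_to_thresholds(depletion_df: pd.DataFrame, pellet_df: pd.DataFrame) -> pd.DataFrame:
    """Illustrative sketch: pair each threshold update with the nearest pellet delivery (within 100 ms)."""
    # Step 1: drop updates without a threshold value and updates within 1 s of the previous one
    thresholds = depletion_df.dropna(subset=["threshold"]).sort_index()
    thresholds = thresholds[~thresholds.index.to_series().diff().lt(pd.Timedelta("1s"))]

    # Step 3: nearest pellet delivery within 100 ms before or after each threshold update
    matched = pd.merge_asof(
        thresholds.rename_axis("threshold_update_timestamp").reset_index(),
        pellet_df.sort_index().rename_axis("pellet_timestamp").reset_index(),
        left_on="threshold_update_timestamp",
        right_on="pellet_timestamp",
        direction="nearest",
        tolerance=pd.Timedelta("100ms"),
    )

    # Step 4: each matched pellet belongs to the previous threshold update, so shift it back by one row
    matched["pellet_timestamp"] = matched["pellet_timestamp"].shift(-1)

    # Step 5: keep only threshold update events with an associated pellet delivery
    return matched.dropna(subset=["pellet_timestamp"]).set_index("threshold_update_timestamp")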
18 changes: 4 additions & 14 deletions aeon/dj_pipeline/analysis/visit.py
@@ -1,21 +1,14 @@
"""Module for visit-related tables in the analysis schema."""

from collections import deque
from datetime import datetime, timezone
from datetime import UTC, datetime

import datajoint as dj
import numpy as np
import pandas as pd

from aeon.analysis import utils as analysis_utils
from aeon.dj_pipeline import (
acquisition,
fetch_stream,
get_schema_name,
lab,
qc,
tracking,
)
from aeon.dj_pipeline import acquisition, fetch_stream, get_schema_name

schema = dj.schema(get_schema_name("analysis"))

@@ -146,7 +139,7 @@ def ingest_environment_visits(experiment_names: list | None = None):
.fetch("last_visit")
)
start = min(subjects_last_visits) if len(subjects_last_visits) else "1900-01-01"
end = datetime.now(timezone.utc) if start else "2200-01-01"
end = datetime.now(UTC) if start else "2200-01-01"

enter_exit_query = (
acquisition.SubjectEnterExit.Time * acquisition.EventType
@@ -161,10 +154,7 @@
enter_exit_df = pd.DataFrame(
zip(
*enter_exit_query.fetch(
"subject",
"enter_exit_time",
"event_type",
order_by="enter_exit_time",
"subject", "enter_exit_time", "event_type", order_by="enter_exit_time"
),
strict=False,
)
11 changes: 3 additions & 8 deletions aeon/dj_pipeline/analysis/visit_analysis.py
@@ -219,15 +219,10 @@ def get_position(cls, visit_key=None, subject=None, start=None, end=None):
Visit.join(VisitEnd, left=True).proj(visit_end="IFNULL(visit_end, NOW())") & visit_key
).fetch1("visit_start", "visit_end")
subject = visit_key["subject"]
elif all((subject, start, end)):
start = start # noqa PLW0127
end = end # noqa PLW0127
subject = subject # noqa PLW0127
else:
elif not all((subject, start, end)):
raise ValueError(
'Either "visit_key" or all three "subject", "start" and "end" has to be specified'
)

'Either "visit_key" or all three "subject", "start", and "end" must be specified.'
)
return tracking._get_position(
cls.TimeSlice,
object_attr="subject",
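A hedged usage sketch of the calling convention enforced above; the enclosing class is not shown in this hunk, so VisitSubjectPosition below is an assumption, and the subject name and dates are placeholders:

from datetime import datetime

from aeon.dj_pipeline.analysis import visit_analysis

# Assumed table class exposing get_position(); adjust to the actual class in visit_analysis.
PositionTable = visit_analysis.VisitSubjectPosition

# Pass either `visit_key` alone, or all three of `subject`, `start`, and `end`;
# any other combination now raises the ValueError above.
pos_df = PositionTable.get_position(
    subject="<subject_name>",
    start=datetime(2024, 1, 1),
    end=datetime(2024, 1, 2),
)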
2 changes: 1 addition & 1 deletion aeon/dj_pipeline/create_experiments/create_octagon_1.py
@@ -1,4 +1,4 @@
"""Function to create new experiments for octagon1.0."""
"""Functions to create new experiments for octagon1.0."""

from aeon.dj_pipeline import acquisition, subject

2 changes: 1 addition & 1 deletion aeon/dj_pipeline/create_experiments/create_presocial.py
@@ -1,4 +1,4 @@
"""Function to create new experiments for presocial0.1."""
"""Functions to create new experiments for presocial0.1."""

from aeon.dj_pipeline import acquisition, lab, subject

@@ -1,4 +1,4 @@
"""Function to create new social experiments."""
"""Functions to create new social experiments."""

from datetime import datetime

@@ -1,4 +1,4 @@
"""Function to create new experiments for social0-r1."""
"""Functions to create new experiments for social0-r1."""

import pathlib

2 changes: 1 addition & 1 deletion aeon/dj_pipeline/qc.py
@@ -39,7 +39,7 @@ class QCRoutine(dj.Lookup):

@schema
class CameraQC(dj.Imported):
definition = """ # Quality controls performed on a particular camera for a particular acquisition chunk
definition = """ # Quality controls performed on a particular camera for one acquisition chunk
-> acquisition.Chunk
-> streams.SpinnakerVideoSource
---
15 changes: 9 additions & 6 deletions aeon/dj_pipeline/scripts/reingest_fullpose_sleap_data.py
@@ -1,4 +1,7 @@
"""Functions to find and delete orphaned epochs that have been ingested but are no longer valid."""

from datetime import datetime

from aeon.dj_pipeline import acquisition, tracking

aeon_schemas = acquisition.aeon_schemas
@@ -8,11 +11,10 @@


def find_chunks_to_reingest(exp_key, delete_not_fullpose=False):
"""
Find chunks with newly available full pose data to reingest.
"""Find chunks with newly available full pose data to reingest.
If available, fullpose data can be found in `processed` folder
"""

device_name = "CameraTop"

devices_schema = getattr(
@@ -21,13 +23,14 @@ def find_chunks_to_reingest(exp_key, delete_not_fullpose=False):
"devices_schema_name"
),
)
stream_reader = getattr(getattr(devices_schema, device_name), "Pose")
stream_reader = getattr(devices_schema, device_name).Pose

# special ingestion case for social0.2 full-pose data (using Pose reader from social03)
if exp_key["experiment_name"].startswith("social0.2"):
from aeon.io import reader as io_reader
stream_reader = getattr(getattr(devices_schema, device_name), "Pose03")
assert isinstance(stream_reader, io_reader.Pose), "Pose03 is not a Pose reader"
stream_reader = getattr(devices_schema, device_name).Pose03
if not isinstance(stream_reader, io_reader.Pose):
raise TypeError("Pose03 is not a Pose reader")

# find processed path for exp_key
processed_dir = acquisition.Experiment.get_data_directory(exp_key, "processed")
9 changes: 6 additions & 3 deletions aeon/dj_pipeline/scripts/sync_ingested_and_raw_epochs.py
@@ -1,6 +1,9 @@
import datajoint as dj
"""Functions to find and delete orphaned epochs that have been ingested but are no longer valid."""

from datetime import datetime

import datajoint as dj

from aeon.dj_pipeline import acquisition, streams
from aeon.dj_pipeline.analysis import block_analysis

@@ -11,8 +14,8 @@


def find_orphaned_ingested_epochs(exp_key, delete_invalid_epochs=False):
"""
Find ingested epochs that are no longer valid
"""Find ingested epochs that are no longer valid.
This is due to the raw epoch/chunk files/directories being deleted for whatever reason
(e.g. bad data, testing, etc.)
"""