diff --git a/aeon/dj_pipeline/__init__.py b/aeon/dj_pipeline/__init__.py
index 9bb1128e..1225020a 100644
--- a/aeon/dj_pipeline/__init__.py
+++ b/aeon/dj_pipeline/__init__.py
@@ -1,4 +1,5 @@
 import hashlib
+import logging
 import os
 import uuid
 
@@ -30,11 +31,17 @@ def dict_to_uuid(key) -> uuid.UUID:
     return uuid.UUID(hex=hashed.hexdigest())
 
 
-def fetch_stream(query, drop_pk=True):
+def fetch_stream(query, drop_pk=True, round_microseconds=True):
     """Fetches data from a Stream table based on a query and returns it as a DataFrame.
 
     Provided a query containing data from a Stream table, fetch and aggregate the data into one DataFrame
     indexed by "time"
+
+    Args:
+        query (datajoint.Query): A query object containing data from a Stream table.
+        drop_pk (bool, optional): Drop primary key columns. Defaults to True.
+        round_microseconds (bool, optional): Round timestamps to microseconds. Defaults to True.
+            (This is important, as timestamps in MySQL are only accurate to microseconds.)
     """
     df = (query & "sample_count > 0").fetch(format="frame").reset_index()
     cols2explode = [
@@ -47,6 +54,10 @@ def fetch_stream(query, drop_pk=True):
     df.set_index("time", inplace=True)
     df.sort_index(inplace=True)
     df = df.convert_dtypes(convert_string=False, convert_integer=False, convert_boolean=False, convert_floating=False)
+    if not df.empty and round_microseconds:
+        logging.warning("Rounding timestamps to microseconds is now enabled by default."
+                        " To disable, set round_microseconds=False.")
+        df.index = df.index.round("us")
     return df
diff --git a/aeon/dj_pipeline/acquisition.py b/aeon/dj_pipeline/acquisition.py
index 38499455..e943adc9 100644
--- a/aeon/dj_pipeline/acquisition.py
+++ b/aeon/dj_pipeline/acquisition.py
@@ -12,7 +12,7 @@ from aeon.dj_pipeline.utils import paths
 from aeon.io import api as io_api
 from aeon.io import reader as io_reader
-from aeon.schema import schemas as aeon_schemas
+from aeon.schema import ingestion_schemas as aeon_schemas
 
 logger = dj.logger
 schema = dj.schema(get_schema_name("acquisition"))
@@ -646,10 +646,14 @@ def _match_experiment_directory(experiment_name, path, directories):
 
 def create_chunk_restriction(experiment_name, start_time, end_time):
     """Create a time restriction string for the chunks between the specified "start" and "end" times."""
+    exp_key = {"experiment_name": experiment_name}
     start_restriction = f'"{start_time}" BETWEEN chunk_start AND chunk_end'
     end_restriction = f'"{end_time}" BETWEEN chunk_start AND chunk_end'
-    start_query = Chunk & {"experiment_name": experiment_name} & start_restriction
-    end_query = Chunk & {"experiment_name": experiment_name} & end_restriction
+    start_query = Chunk & exp_key & start_restriction
+    end_query = Chunk & exp_key & end_restriction
+    if not end_query:
+        # No chunk contains the end time, so we need to find the last chunk that starts before the end time
+        end_query = Chunk & exp_key & f'chunk_end BETWEEN "{start_time}" AND "{end_time}"'
     if not (start_query and end_query):
         raise ValueError(f"No Chunk found between {start_time} and {end_time}")
     time_restriction = (
diff --git a/aeon/dj_pipeline/analysis/block_analysis.py b/aeon/dj_pipeline/analysis/block_analysis.py
index 7e853a5b..2bc73823 100644
--- a/aeon/dj_pipeline/analysis/block_analysis.py
+++ b/aeon/dj_pipeline/analysis/block_analysis.py
@@ -53,6 +53,8 @@ class BlockDetection(dj.Computed):
     -> acquisition.Environment
     """
 
+    key_source = acquisition.Environment - {"experiment_name": "social0.1-aeon3"}
+
     def make(self, key):
         """On a per-chunk basis, check for the presence of new block,
insert into Block table. @@ -88,8 +90,7 @@ def make(self, key): blocks_df = block_state_df[block_state_df.pellet_ct == 0] # account for the double 0s - find any 0s that are within 1 second of each other, remove the 2nd one double_0s = blocks_df.index.to_series().diff().dt.total_seconds() < 1 - # find the indices of the 2nd 0s and remove - double_0s = double_0s.shift(-1).fillna(False) + # keep the first 0s blocks_df = blocks_df[~double_0s] block_entries = [] @@ -144,8 +145,8 @@ class Patch(dj.Part): wheel_timestamps: longblob patch_threshold: longblob patch_threshold_timestamps: longblob - patch_rate: float - patch_offset: float + patch_rate=null: float + patch_offset=null: float """ class Subject(dj.Part): @@ -181,7 +182,6 @@ def make(self, key): streams.UndergroundFeederDepletionState, streams.UndergroundFeederDeliverPellet, streams.UndergroundFeederEncoder, - tracking.SLEAPTracking, ) for streams_table in streams_tables: if len(streams_table & chunk_keys) < len(streams_table.key_source & chunk_keys): @@ -189,9 +189,20 @@ def make(self, key): f"BlockAnalysis Not Ready - {streams_table.__name__} not yet fully ingested for block: {key}. Skipping (to retry later)..." ) + # Check if SLEAPTracking is ready, if not, see if BlobPosition can be used instead + use_blob_position = False + if len(tracking.SLEAPTracking & chunk_keys) < len(tracking.SLEAPTracking.key_source & chunk_keys): + if len(tracking.BlobPosition & chunk_keys) < len(tracking.BlobPosition.key_source & chunk_keys): + raise ValueError( + f"BlockAnalysis Not Ready - SLEAPTracking (and BlobPosition) not yet fully ingested for block: {key}. Skipping (to retry later)..." + ) + else: + use_blob_position = True + # Patch data - TriggerPellet, DepletionState, Encoder (distancetravelled) - # For wheel data, downsample to 10Hz - final_encoder_fs = 10 + # For wheel data, downsample to 50Hz + final_encoder_hz = 50 + freq = 1 / final_encoder_hz * 1e3 # in ms maintenance_period = get_maintenance_periods(key["experiment_name"], block_start, block_end) @@ -233,30 +244,36 @@ def make(self, key): encoder_df, maintenance_period, block_end, dropna=True ) - if depletion_state_df.empty: - raise ValueError(f"No depletion state data found for block {key} - patch: {patch_name}") - - encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled(encoder_df.angle) - - if len(depletion_state_df.rate.unique()) > 1: - # multiple patch rates per block is unexpected, log a note and pick the first rate to move forward - AnalysisNote.insert1( - { - "note_timestamp": datetime.utcnow(), - "note_type": "Multiple patch rates", - "note": f"Found multiple patch rates for block {key} - patch: {patch_name} - rates: {depletion_state_df.rate.unique()}", - } - ) + # if all dataframes are empty, skip + if pellet_ts_threshold_df.empty and depletion_state_df.empty and encoder_df.empty: + continue - patch_rate = depletion_state_df.rate.iloc[0] - patch_offset = depletion_state_df.offset.iloc[0] - # handles patch rate value being INF - patch_rate = 999999999 if np.isinf(patch_rate) else patch_rate + if encoder_df.empty: + encoder_df["distance_travelled"] = 0 + else: + # -1 is for placement of magnetic encoder, where wheel movement actually decreases encoder + encoder_df["distance_travelled"] = -1 * analysis_utils.distancetravelled(encoder_df.angle) + encoder_df = encoder_df.resample(f"{freq}ms").first() + + if not depletion_state_df.empty: + if len(depletion_state_df.rate.unique()) > 1: + # multiple patch rates per block is unexpected, log a note and pick the first rate 
to move forward + AnalysisNote.insert1( + { + "note_timestamp": datetime.utcnow(), + "note_type": "Multiple patch rates", + "note": f"Found multiple patch rates for block {key} - patch: {patch_name} - rates: {depletion_state_df.rate.unique()}", + } + ) - encoder_fs = ( - 1 / encoder_df.index.to_series().diff().dt.total_seconds().median() - ) # mean or median? - wheel_downsampling_factor = int(encoder_fs / final_encoder_fs) + patch_rate = depletion_state_df.rate.iloc[0] + patch_offset = depletion_state_df.offset.iloc[0] + # handles patch rate value being INF + patch_rate = 999999999 if np.isinf(patch_rate) else patch_rate + else: + logger.warning(f"No depletion state data found for block {key} - patch: {patch_name}") + patch_rate = None + patch_offset = None block_patch_entries.append( { @@ -264,10 +281,8 @@ def make(self, key): "patch_name": patch_name, "pellet_count": len(pellet_ts_threshold_df), "pellet_timestamps": pellet_ts_threshold_df.pellet_timestamp.values, - "wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values[ - ::wheel_downsampling_factor - ], - "wheel_timestamps": encoder_df.index.values[::wheel_downsampling_factor], + "wheel_cumsum_distance_travelled": encoder_df.distance_travelled.values, + "wheel_timestamps": encoder_df.index.values, "patch_threshold": pellet_ts_threshold_df.threshold.values, "patch_threshold_timestamps": pellet_ts_threshold_df.index.values, "patch_rate": patch_rate, @@ -275,9 +290,6 @@ def make(self, key): } ) - # update block_end if last timestamp of encoder_df is before the current block_end - block_end = min(encoder_df.index[-1], block_end) - # Subject data # Get all unique subjects that visited the environment over the entire exp; # For each subject, see 'type' of visit most recent to start of block @@ -288,27 +300,50 @@ def make(self, key): & f'chunk_start <= "{chunk_keys[-1]["chunk_start"]}"' )[:block_start] subject_visits_df = subject_visits_df[subject_visits_df.region == "Environment"] + subject_visits_df = subject_visits_df[~subject_visits_df.id.str.contains("Test", case=False)] subject_names = [] for subject_name in set(subject_visits_df.id): _df = subject_visits_df[subject_visits_df.id == subject_name] if _df.type.iloc[-1] != "Exit": subject_names.append(subject_name) + if use_blob_position and len(subject_names) > 1: + raise ValueError( + f"Without SLEAPTracking, BlobPosition can only handle single-subject block. Found {len(subject_names)} subjects." 
+ ) + block_subject_entries = [] for subject_name in subject_names: # positions - query for CameraTop, identity_name matches subject_name, - pos_query = ( - streams.SpinnakerVideoSource - * tracking.SLEAPTracking.PoseIdentity.proj("identity_name", part_name="anchor_part") - * tracking.SLEAPTracking.Part - & key - & { - "spinnaker_video_source_name": "CameraTop", - "identity_name": subject_name, - } - & chunk_restriction - ) - pos_df = fetch_stream(pos_query)[block_start:block_end] + if use_blob_position: + pos_query = ( + streams.SpinnakerVideoSource + * tracking.BlobPosition.Object + & key + & chunk_restriction + & { + "spinnaker_video_source_name": "CameraTop", + "identity_name": subject_name + } + ) + pos_df = fetch_stream(pos_query)[block_start:block_end] + pos_df["likelihood"] = np.nan + # keep only rows with area between 0 and 1000 - likely artifacts otherwise + pos_df = pos_df[(pos_df.area > 0) & (pos_df.area < 1000)] + else: + pos_query = ( + streams.SpinnakerVideoSource + * tracking.SLEAPTracking.PoseIdentity.proj("identity_name", part_name="anchor_part") + * tracking.SLEAPTracking.Part + & key + & { + "spinnaker_video_source_name": "CameraTop", + "identity_name": subject_name, + } + & chunk_restriction + ) + pos_df = fetch_stream(pos_query)[block_start:block_end] + pos_df = filter_out_maintenance_periods(pos_df, maintenance_period, block_end) if pos_df.empty: @@ -345,8 +380,8 @@ def make(self, key): { **key, "block_duration": (block_end - block_start).total_seconds() / 3600, - "patch_count": len(patch_keys), - "subject_count": len(subject_names), + "patch_count": len(block_patch_entries), + "subject_count": len(block_subject_entries), } ) self.Patch.insert(block_patch_entries) @@ -423,6 +458,17 @@ def make(self, key): ) subjects_positions_df.set_index("position_timestamps", inplace=True) + # Ensure wheel_timestamps are of the same length across all patches + wheel_lens = [len(p["wheel_timestamps"]) for p in block_patches] + if len(set(wheel_lens)) > 1: + max_diff = max(wheel_lens) - min(wheel_lens) + if max_diff > 10: + # if diff is more than 10 samples, raise error, this is unexpected, some patches crash? 
+ raise ValueError(f"Wheel data lengths are not consistent across patches ({max_diff} samples diff)") + for p in block_patches: + p["wheel_timestamps"] = p["wheel_timestamps"][: min(wheel_lens)] + p["wheel_cumsum_distance_travelled"] = p["wheel_cumsum_distance_travelled"][: min(wheel_lens)] + self.insert1(key) in_patch_radius = 130 # pixels @@ -541,7 +587,7 @@ def make(self, key): | { "patch_name": patch["patch_name"], "subject_name": subject_name, - "in_patch_timestamps": subject_in_patch.index.values, + "in_patch_timestamps": subject_in_patch[in_patch[subject_name]].index.values, "in_patch_time": subject_in_patch_cum_time[-1], "pellet_count": len(subj_pellets), "pellet_timestamps": subj_pellets.index.values, @@ -1521,10 +1567,10 @@ def make(self, key): foraging_bout_df = get_foraging_bouts(key) foraging_bout_df.rename( columns={ - "subject_name": "subject", - "bout_start": "start", - "bout_end": "end", - "pellet_count": "n_pellets", + "subject": "subject_name", + "start": "bout_start", + "end": "bout_end", + "n_pellets": "pellet_count", "cum_wheel_dist": "cum_wheel_dist", }, inplace=True, @@ -1540,7 +1586,7 @@ def make(self, key): @schema class AnalysisNote(dj.Manual): definition = """ # Generic table to catch all notes generated during analysis - note_timestamp: datetime + note_timestamp: datetime(6) --- note_type='': varchar(64) note: varchar(3000) diff --git a/aeon/dj_pipeline/populate/worker.py b/aeon/dj_pipeline/populate/worker.py index 68f1803a..81e9cb18 100644 --- a/aeon/dj_pipeline/populate/worker.py +++ b/aeon/dj_pipeline/populate/worker.py @@ -99,12 +99,32 @@ def ingest_epochs_chunks(): ) analysis_worker(block_analysis.BlockAnalysis, max_calls=6) -analysis_worker(block_analysis.BlockPlots, max_calls=6) analysis_worker(block_analysis.BlockSubjectAnalysis, max_calls=6) -analysis_worker(block_analysis.BlockSubjectPlots, max_calls=6) +analysis_worker(block_analysis.BlockForaging, max_calls=6) +analysis_worker(block_analysis.BlockPatchPlots, max_calls=6) +analysis_worker(block_analysis.BlockSubjectPositionPlots, max_calls=6) def get_workflow_operation_overview(): from datajoint_utilities.dj_worker.utils import get_workflow_operation_overview return get_workflow_operation_overview(worker_schema_name=worker_schema_name, db_prefixes=[db_prefix]) + + +def retrieve_schemas_sizes(schema_only=False, all_schemas=False): + schema_names = [n for n in dj.list_schemas() if n != "mysql"] + if not all_schemas: + schema_names = [n for n in schema_names + if n.startswith(db_prefix) and not n.startswith(f"{db_prefix}archived")] + + if schema_only: + return {n: dj.Schema(n).size_on_disk / 1e9 for n in schema_names} + + schema_sizes = {n: {} for n in schema_names} + for n in schema_names: + vm = dj.VirtualModule(n, n) + schema_sizes[n]["schema_gb"] = vm.schema.size_on_disk / 1e9 + schema_sizes[n]["tables_gb"] = {n: t().size_on_disk / 1e9 + for n, t in vm.__dict__.items() + if isinstance(t, dj.user_tables.TableMeta)} + return schema_sizes diff --git a/aeon/dj_pipeline/scripts/reingest_fullpose_sleap_data.py b/aeon/dj_pipeline/scripts/reingest_fullpose_sleap_data.py new file mode 100644 index 00000000..b3586f82 --- /dev/null +++ b/aeon/dj_pipeline/scripts/reingest_fullpose_sleap_data.py @@ -0,0 +1,58 @@ +from datetime import datetime +from aeon.dj_pipeline import acquisition, tracking + +aeon_schemas = acquisition.aeon_schemas +logger = acquisition.logger + +exp_key = {"experiment_name": "social0.2-aeon4"} + + +def find_chunks_to_reingest(exp_key, delete_not_fullpose=False): + """ + Find chunks with 
newly available full pose data to reingest.
+    If available, the full-pose data can be found in the `processed` folder.
+    """
+
+    device_name = "CameraTop"
+
+    devices_schema = getattr(
+        aeon_schemas,
+        (acquisition.Experiment.DevicesSchema & {"experiment_name": exp_key["experiment_name"]}).fetch1(
+            "devices_schema_name"
+        ),
+    )
+    stream_reader = getattr(getattr(devices_schema, device_name), "Pose")
+
+    # special ingestion case for social0.2 full-pose data (using Pose reader from social03)
+    if exp_key["experiment_name"].startswith("social0.2"):
+        from aeon.io import reader as io_reader
+        stream_reader = getattr(getattr(devices_schema, device_name), "Pose03")
+        assert isinstance(stream_reader, io_reader.Pose), "Pose03 is not a Pose reader"
+
+    # find processed path for exp_key
+    processed_dir = acquisition.Experiment.get_data_directory(exp_key, "processed")
+
+    files = sorted(f.stem for f in processed_dir.rglob(f"{stream_reader.pattern}.bin") if f.is_file())
+    # extract timestamps from the file names & convert to datetime
+    file_times = [datetime.strptime(f.split("_")[-1], "%Y-%m-%dT%H-%M-%S") for f in files]
+
+    # sleap query with files in processed dir
+    query = acquisition.Chunk & exp_key & [{"chunk_start": t} for t in file_times]
+    epochs = acquisition.Epoch & query.proj("epoch_start")
+    sleap_query = tracking.SLEAPTracking & (acquisition.Chunk & epochs.proj("epoch_start"))
+
+    fullpose, not_fullpose = [], []
+    for key in sleap_query.fetch("KEY"):
+        identity_count = len(tracking.SLEAPTracking.PoseIdentity & key)
+        part_count = len(tracking.SLEAPTracking.Part & key)
+        if part_count <= identity_count:
+            not_fullpose.append(key)
+        else:
+            fullpose.append(key)
+
+    logger.info(f"Fullpose: {len(fullpose)} | Not fullpose: {len(not_fullpose)}")
+
+    if delete_not_fullpose:
+        (tracking.SLEAPTracking & not_fullpose).delete()
+
+    return fullpose, not_fullpose
diff --git a/aeon/dj_pipeline/scripts/sync_ingested_and_raw_epochs.py b/aeon/dj_pipeline/scripts/sync_ingested_and_raw_epochs.py
new file mode 100644
index 00000000..186355ce
--- /dev/null
+++ b/aeon/dj_pipeline/scripts/sync_ingested_and_raw_epochs.py
@@ -0,0 +1,72 @@
+import datajoint as dj
+from datetime import datetime
+
+from aeon.dj_pipeline import acquisition, streams
+from aeon.dj_pipeline.analysis import block_analysis
+
+aeon_schemas = acquisition.aeon_schemas
+logger = acquisition.logger
+
+exp_key = {"experiment_name": "social0.3-aeon3"}
+
+
+def find_orphaned_ingested_epochs(exp_key, delete_invalid_epochs=False):
+    """
+    Find ingested epochs that are no longer valid.
+    This is due to the raw epoch/chunk files/directories being deleted for whatever reason
+    (e.g. bad data, testing, etc.)
+ """ + raw_dir = acquisition.Experiment.get_data_directory(exp_key, "raw") + epoch_dirs = [d.name for d in raw_dir.glob("*T*") if d.is_dir()] + + epoch_query = acquisition.Epoch.join(acquisition.EpochEnd, left=True) & exp_key + + valid_epochs = epoch_query & f"epoch_dir in {tuple(epoch_dirs)}" + invalid_epochs = epoch_query - f"epoch_dir in {tuple(epoch_dirs)}" + + logger.info(f"Valid Epochs: {len(valid_epochs)} | Invalid Epochs: {len(invalid_epochs)}") + + if not invalid_epochs or not delete_invalid_epochs: + return + + # delete blocks + # delete streams device installations + # delete epochs + invalid_blocks = [] + for key in invalid_epochs.fetch("KEY"): + epoch_start, epoch_end = (invalid_epochs & key).fetch1("epoch_start", "epoch_end") + invalid_blocks.extend( + (block_analysis.Block + & exp_key + & f"block_start BETWEEN '{epoch_start}' AND '{epoch_end}'").fetch("KEY")) + + # devices + invalid_devices_query = acquisition.EpochConfig.DeviceType & invalid_epochs + device_types = set(invalid_devices_query.fetch("device_type")) + device_table_invalid_query = [] + for device_type in device_types: + device_table = getattr(streams, device_type) + install_time_attr_name = next(n for n in device_table.primary_key if n.endswith("_install_time")) + invalid_device_query = device_table & invalid_epochs.proj(**{install_time_attr_name: "epoch_start"}) + if invalid_device_query: + logger.warning("Invalid devices found - please run the rest manually to confirm deletion") + logger.warning(invalid_devices_query) + return + logger.debug(invalid_device_query) + device_table_invalid_query.append((device_table, invalid_device_query)) + + # delete + dj.conn().start_transaction() + + with dj.config(safemode=False): + (block_analysis.Block & invalid_blocks).delete() + for device_table, invalid_query in device_table_invalid_query: + (device_table & invalid_query.fetch("KEY")).delete() + (acquisition.Epoch & invalid_epochs.fetch("KEY")).delete() + + if dj.utils.user_choice("Commit deletes?", default="no") == "yes": + dj.conn().commit_transaction() + logger.info("Deletes committed.") + else: + dj.conn().cancel_transaction() + logger.info("Deletes cancelled") diff --git a/aeon/dj_pipeline/streams.py b/aeon/dj_pipeline/streams.py index e3d6ba12..6c3d01b0 100644 --- a/aeon/dj_pipeline/streams.py +++ b/aeon/dj_pipeline/streams.py @@ -9,27 +9,27 @@ import aeon from aeon.dj_pipeline import acquisition, get_schema_name from aeon.io import api as io_api -from aeon.schema import schemas as aeon_schemas + +aeon_schemas = acquisition.aeon_schemas schema = dj.Schema(get_schema_name("streams")) -@schema +@schema class StreamType(dj.Lookup): """Catalog of all steam types for the different device types used across Project Aeon. One StreamType corresponds to one reader class in `aeon.io.reader`. The combination of `stream_reader` and `stream_reader_kwargs` should fully specify the data loading routine for a particular device, using the `aeon.io.utils`.""" definition = """ # Catalog of all stream types used across Project Aeon - stream_type : varchar(20) + stream_type : varchar(36) --- stream_reader : varchar(256) # name of the reader class found in `aeon_mecha` package (e.g. 
aeon.io.reader.Video) stream_reader_kwargs : longblob # keyword arguments to instantiate the reader class stream_description='': varchar(256) stream_hash : uuid # hash of dict(stream_reader_kwargs, stream_reader=stream_reader) - unique index (stream_hash) """ -@schema +@schema class DeviceType(dj.Lookup): """Catalog of all device types used across Project Aeon.""" @@ -46,7 +46,7 @@ class Stream(dj.Part): """ -@schema +@schema class Device(dj.Lookup): definition = """ # Physical devices, of a particular type, identified by unique serial number device_serial_number: varchar(12) @@ -55,7 +55,7 @@ class Device(dj.Lookup): """ -@schema +@schema class RfidReader(dj.Manual): definition = f""" # rfid_reader placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-unknown) @@ -82,7 +82,7 @@ class RemovalTime(dj.Part): """ -@schema +@schema class SpinnakerVideoSource(dj.Manual): definition = f""" # spinnaker_video_source placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-unknown) @@ -109,7 +109,7 @@ class RemovalTime(dj.Part): """ -@schema +@schema class UndergroundFeeder(dj.Manual): definition = f""" # underground_feeder placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-unknown) @@ -136,7 +136,7 @@ class RemovalTime(dj.Part): """ -@schema +@schema class WeightScale(dj.Manual): definition = f""" # weight_scale placement and operation for a particular time period, at a certain location, for a given experiment (auto-generated with aeon_mecha-unknown) @@ -163,7 +163,7 @@ class RemovalTime(dj.Part): """ -@schema +@schema class RfidReaderRfidEvents(dj.Imported): definition = """ # Raw per-chunk RfidEvents data stream from RfidReader (auto-generated with aeon_mecha-unknown) -> RfidReader @@ -189,7 +189,6 @@ def key_source(self): def make(self, key): chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - data_dirs = acquisition.Experiment.get_data_directories(key) device_name = (RfidReader & key).fetch1('rfid_reader_name') @@ -224,7 +223,7 @@ def make(self, key): ) -@schema +@schema class SpinnakerVideoSourceVideo(dj.Imported): definition = """ # Raw per-chunk Video data stream from SpinnakerVideoSource (auto-generated with aeon_mecha-unknown) -> SpinnakerVideoSource @@ -232,7 +231,6 @@ class SpinnakerVideoSourceVideo(dj.Imported): --- sample_count: int # number of data points acquired from this stream for a given chunk timestamps: longblob # (datetime) timestamps of Video data - hw_counter: longblob hw_timestamp: longblob """ @@ -251,7 +249,6 @@ def key_source(self): def make(self, key): chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - data_dirs = acquisition.Experiment.get_data_directories(key) device_name = (SpinnakerVideoSource & key).fetch1('spinnaker_video_source_name') @@ -286,7 +283,7 @@ def make(self, key): ) -@schema +@schema class UndergroundFeederBeamBreak(dj.Imported): definition = """ # Raw per-chunk BeamBreak data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder @@ -312,7 +309,6 @@ def key_source(self): def make(self, key): chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - data_dirs = acquisition.Experiment.get_data_directories(key) device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') @@ 
-347,7 +343,7 @@ def make(self, key): ) -@schema +@schema class UndergroundFeederDeliverPellet(dj.Imported): definition = """ # Raw per-chunk DeliverPellet data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder @@ -373,7 +369,6 @@ def key_source(self): def make(self, key): chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - data_dirs = acquisition.Experiment.get_data_directories(key) device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') @@ -408,7 +403,7 @@ def make(self, key): ) -@schema +@schema class UndergroundFeederDepletionState(dj.Imported): definition = """ # Raw per-chunk DepletionState data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder @@ -436,7 +431,6 @@ def key_source(self): def make(self, key): chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - data_dirs = acquisition.Experiment.get_data_directories(key) device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') @@ -471,7 +465,7 @@ def make(self, key): ) -@schema +@schema class UndergroundFeederEncoder(dj.Imported): definition = """ # Raw per-chunk Encoder data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder @@ -498,7 +492,6 @@ def key_source(self): def make(self, key): chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - data_dirs = acquisition.Experiment.get_data_directories(key) device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') @@ -533,7 +526,7 @@ def make(self, key): ) -@schema +@schema class UndergroundFeederManualDelivery(dj.Imported): definition = """ # Raw per-chunk ManualDelivery data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder @@ -559,7 +552,6 @@ def key_source(self): def make(self, key): chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - data_dirs = acquisition.Experiment.get_data_directories(key) device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') @@ -594,7 +586,7 @@ def make(self, key): ) -@schema +@schema class UndergroundFeederMissedPellet(dj.Imported): definition = """ # Raw per-chunk MissedPellet data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder @@ -620,7 +612,6 @@ def key_source(self): def make(self, key): chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - data_dirs = acquisition.Experiment.get_data_directories(key) device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') @@ -655,7 +646,7 @@ def make(self, key): ) -@schema +@schema class UndergroundFeederRetriedDelivery(dj.Imported): definition = """ # Raw per-chunk RetriedDelivery data stream from UndergroundFeeder (auto-generated with aeon_mecha-unknown) -> UndergroundFeeder @@ -681,7 +672,6 @@ def key_source(self): def make(self, key): chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - data_dirs = acquisition.Experiment.get_data_directories(key) device_name = (UndergroundFeeder & key).fetch1('underground_feeder_name') @@ -716,7 +706,7 @@ def make(self, key): ) -@schema +@schema class WeightScaleWeightFiltered(dj.Imported): definition = """ # Raw per-chunk WeightFiltered data stream from WeightScale (auto-generated with aeon_mecha-unknown) -> WeightScale @@ -743,7 +733,6 @@ def key_source(self): def make(self, key): 
chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - data_dirs = acquisition.Experiment.get_data_directories(key) device_name = (WeightScale & key).fetch1('weight_scale_name') @@ -778,7 +767,7 @@ def make(self, key): ) -@schema +@schema class WeightScaleWeightRaw(dj.Imported): definition = """ # Raw per-chunk WeightRaw data stream from WeightScale (auto-generated with aeon_mecha-unknown) -> WeightScale @@ -805,7 +794,6 @@ def key_source(self): def make(self, key): chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") - data_dirs = acquisition.Experiment.get_data_directories(key) device_name = (WeightScale & key).fetch1('weight_scale_name') @@ -838,3 +826,5 @@ def make(self, key): }, ignore_extra_fields=True, ) + + diff --git a/aeon/dj_pipeline/subject.py b/aeon/dj_pipeline/subject.py index 20ad8bef..09027961 100644 --- a/aeon/dj_pipeline/subject.py +++ b/aeon/dj_pipeline/subject.py @@ -435,6 +435,10 @@ def make(self, key): def get_pyrat_data(endpoint: str, params: dict = None, **kwargs): + """ + Get data from PyRat API. + See docs at: https://swc.pyrat.cloud/api/v3/docs (production) + """ base_url = "https://swc.pyrat.cloud/api/v3/" pyrat_system_token = os.getenv("PYRAT_SYSTEM_TOKEN") pyrat_user_token = os.getenv("PYRAT_USER_TOKEN") diff --git a/aeon/dj_pipeline/tracking.py b/aeon/dj_pipeline/tracking.py index 01b0a039..0c2ff9c6 100644 --- a/aeon/dj_pipeline/tracking.py +++ b/aeon/dj_pipeline/tracking.py @@ -5,9 +5,10 @@ import numpy as np import pandas as pd -from aeon.dj_pipeline import acquisition, dict_to_uuid, get_schema_name, lab, qc, streams +from aeon.dj_pipeline import acquisition, dict_to_uuid, get_schema_name, lab, qc, streams, fetch_stream from aeon.io import api as io_api -from aeon.schema import schemas as aeon_schemas + +aeon_schemas = acquisition.aeon_schemas schema = dj.schema(get_schema_name("tracking")) @@ -162,7 +163,16 @@ def make(self, key): "devices_schema_name" ), ) - stream_reader = getattr(devices_schema, device_name).Pose + + stream_reader = getattr(getattr(devices_schema, device_name), "Pose") + + # special ingestion case for social0.2 full-pose data (using Pose reader from social03) + # fullpose for social0.2 has a different "pattern" for non-fullpose, hence the Pose03 reader + if key["experiment_name"].startswith("social0.2"): + from aeon.io import reader as io_reader + stream_reader = getattr(getattr(devices_schema, device_name), "Pose03") + assert isinstance(stream_reader, io_reader.Pose), "Pose03 is not a Pose reader" + data_dirs = [acquisition.Experiment.get_data_directory(key, "processed")] pose_data = io_api.load( root=data_dirs, @@ -186,6 +196,12 @@ def make(self, key): continue # get anchor part - always the first one of all the body parts + # FIXME: the logic below to get "anchor_part" is not robust, it relies on the ordering of the unique parts + # but if there are missing frames for the actual anchor part, it will be missed + # and another part will be incorrectly chosen as "anchor_part" + # (2024-10-31) - we recently discovered that the parts are not sorted in the same order across frames + # - further exacerbating the flaw in the logic below + # best is to find a robust way to get the anchor part info from the config file for this chunk anchor_part = np.unique(identity_position.part)[0] for part in set(identity_position.part.values): @@ -222,6 +238,121 @@ def make(self, key): self.Part.insert(part_entries) +# ---------- Blob Position Tracking ------------------ + + +@schema +class 
BlobPosition(dj.Imported): + definition = """ # Blob object position tracking from a particular camera, for a particular chunk + -> acquisition.Chunk + -> streams.SpinnakerVideoSource + --- + object_count: int # number of objects tracked in this chunk + subject_count: int # number of subjects present in the arena during this chunk + subject_names: varchar(256) # names of subjects present in arena during this chunk + """ + + class Object(dj.Part): + definition = """ # Position data of object tracked by a particular camera tracking + -> master + object_id: int # object with id = -1 means "unknown/not sure", could potentially be the same object as those with other id value + --- + identity_name='': varchar(16) + sample_count: int # number of data points acquired from this stream for a given chunk + x: longblob # (px) object's x-position, in the arena's coordinate frame + y: longblob # (px) object's y-position, in the arena's coordinate frame + timestamps: longblob # (datetime) timestamps of the position data + area=null: longblob # (px^2) object's size detected in the camera + """ + + @property + def key_source(self): + ks = ( + acquisition.Chunk + * ( + streams.SpinnakerVideoSource.join(streams.SpinnakerVideoSource.RemovalTime, left=True) + & "spinnaker_video_source_name='CameraTop'" + ) + & "chunk_start >= spinnaker_video_source_install_time" + & 'chunk_start < IFNULL(spinnaker_video_source_removal_time, "2200-01-01")' + ) + return ks - SLEAPTracking # do this only when SLEAPTracking is not available + + def make(self, key): + chunk_start, chunk_end = (acquisition.Chunk & key).fetch1("chunk_start", "chunk_end") + + data_dirs = acquisition.Experiment.get_data_directories(key) + + device_name = (streams.SpinnakerVideoSource & key).fetch1("spinnaker_video_source_name") + + devices_schema = getattr( + aeon_schemas, + (acquisition.Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1( + "devices_schema_name" + ), + ) + + stream_reader = devices_schema.CameraTop.Position + + positiondata = io_api.load( + root=data_dirs, + reader=stream_reader, + start=pd.Timestamp(chunk_start), + end=pd.Timestamp(chunk_end), + ) + + if not len(positiondata): + raise ValueError(f"No Blob position data found for {key['experiment_name']} - {device_name}") + + # replace id=NaN with -1 + positiondata.fillna({"id": -1}, inplace=True) + positiondata["identity_name"] = "" + + # Find animal(s) in the arena during the chunk + # Get all unique subjects that visited the environment over the entire exp; + # For each subject, see 'type' of visit most recent to start of block + # If "Exit", this animal was not in the block. 
+ subject_visits_df = fetch_stream( + acquisition.Environment.SubjectVisits + & {"experiment_name": key["experiment_name"]} + & f'chunk_start <= "{chunk_start}"' + )[:chunk_end] + subject_visits_df = subject_visits_df[subject_visits_df.region == "Environment"] + subject_visits_df = subject_visits_df[~subject_visits_df.id.str.contains("Test", case=False)] + subject_names = [] + for subject_name in set(subject_visits_df.id): + _df = subject_visits_df[subject_visits_df.id == subject_name] + if _df.type.iloc[-1] != "Exit": + subject_names.append(subject_name) + + if len(subject_names) == 1: + # if there is only one known subject, replace all object ids with the subject name + positiondata["id"] = [0] * len(positiondata) + positiondata["identity_name"] = subject_names[0] + + object_positions = [] + for obj_id in set(positiondata.id.values): + obj_position = positiondata[positiondata.id == obj_id] + + object_positions.append( + { + **key, + "object_id": obj_id, + "identity_name": obj_position.identity_name.values[0], + "sample_count": len(obj_position.index.values), + "timestamps": obj_position.index.values, + "x": obj_position.x.values, + "y": obj_position.y.values, + "area": obj_position.area.values, + } + ) + + self.insert1({**key, "object_count": len(object_positions), + "subject_count": len(subject_names), + "subject_names": ",".join(subject_names)}) + self.Object.insert(object_positions) + + # ---------- HELPER ------------------ diff --git a/aeon/dj_pipeline/utils/load_metadata.py b/aeon/dj_pipeline/utils/load_metadata.py index ce1c2775..ce0a248e 100644 --- a/aeon/dj_pipeline/utils/load_metadata.py +++ b/aeon/dj_pipeline/utils/load_metadata.py @@ -4,7 +4,6 @@ import pathlib from collections import defaultdict from pathlib import Path - import datajoint as dj import numpy as np from dotmap import DotMap @@ -16,31 +15,29 @@ logger = dj.logger _weight_scale_rate = 100 _weight_scale_nest = 1 -_aeon_schemas = ["social01", "social02"] def insert_stream_types(): """Insert into streams.streamType table all streams in the aeon schemas.""" - from aeon.schema import schemas as aeon_schemas + from aeon.schema import ingestion_schemas as aeon_schemas streams = dj.VirtualModule("streams", streams_maker.schema_name) - schemas = [getattr(aeon_schemas, aeon_schema) for aeon_schema in _aeon_schemas] - for schema in schemas: - stream_entries = get_stream_entries(schema) + for devices_schema_name in aeon_schemas.__all__: + devices_schema = getattr(aeon_schemas, devices_schema_name) + stream_entries = get_stream_entries(devices_schema) for entry in stream_entries: - q_param = streams.StreamType & {"stream_hash": entry["stream_hash"]} - if q_param: # If the specified stream type already exists - pname = q_param.fetch1("stream_type") - if pname == entry["stream_type"]: - continue - else: - # If the existed stream type does not have the same name: - # human error, trying to add the same content with different name - raise dj.DataJointError(f"The specified stream type already exists - name: {pname}") - else: + try: streams.StreamType.insert1(entry) + logger.info(f"New stream type created: {entry['stream_type']}") + except dj.errors.DuplicateError: + existing_stream = (streams.StreamType.proj( + "stream_reader", "stream_reader_kwargs") + & {"stream_type": entry["stream_type"]}).fetch1() + if existing_stream["stream_reader_kwargs"].get("columns") != entry["stream_reader_kwargs"].get( + "columns"): + logger.warning(f"Stream type already exists:\n\t{entry}\n\t{existing_stream}") def 
insert_device_types(devices_schema: DotMap, metadata_yml_filepath: Path): @@ -294,7 +291,7 @@ def ingest_epoch_metadata(experiment_name, devices_schema, metadata_yml_filepath return set(epoch_device_types) -# region Get stream & device information +# Get stream & device information def get_stream_entries(devices_schema: DotMap) -> list[dict]: """Returns a list of dictionaries containing the stream entries for a given device. @@ -366,31 +363,25 @@ def _get_class_path(obj): if isinstance(device, DotMap): for stream_type, stream_obj in device.items(): - if stream_obj.__class__.__module__ in [ - "aeon.io.reader", - "aeon.schema.foraging", - "aeon.schema.octagon", - "aeon.schema.social", - ]: - device_info[device_name]["stream_type"].append(stream_type) - device_info[device_name]["stream_reader"].append(_get_class_path(stream_obj)) - - required_args = [ - k for k in inspect.signature(stream_obj.__init__).parameters if k != "self" - ] - pattern = schema_dict[device_name][stream_type].get("pattern") - schema_dict[device_name][stream_type]["pattern"] = pattern.replace( - device_name, "{pattern}" - ) - - kwargs = { - k: v for k, v in schema_dict[device_name][stream_type].items() if k in required_args - } - device_info[device_name]["stream_reader_kwargs"].append(kwargs) - # Add hash - device_info[device_name]["stream_hash"].append( - dict_to_uuid({**kwargs, "stream_reader": _get_class_path(stream_obj)}) - ) + device_info[device_name]["stream_type"].append(stream_type) + device_info[device_name]["stream_reader"].append(_get_class_path(stream_obj)) + + required_args = [ + k for k in inspect.signature(stream_obj.__init__).parameters if k != "self" + ] + pattern = schema_dict[device_name][stream_type].get("pattern") + schema_dict[device_name][stream_type]["pattern"] = pattern.replace( + device_name, "{pattern}" + ) + + kwargs = { + k: v for k, v in schema_dict[device_name][stream_type].items() if k in required_args + } + device_info[device_name]["stream_reader_kwargs"].append(kwargs) + # Add hash + device_info[device_name]["stream_hash"].append( + dict_to_uuid({**kwargs, "stream_reader": _get_class_path(stream_obj)}) + ) else: stream_type = device.__class__.__name__ device_info[device_name]["stream_type"].append(stream_type) @@ -501,6 +492,3 @@ def ingest_epoch_metadata_octagon(experiment_name, metadata_yml_filepath): experiment_table = getattr(streams, f"Experiment{device_type}") if not (experiment_table & {"experiment_name": experiment_name, "device_serial_number": device_sn}): experiment_table.insert1((experiment_name, device_sn, epoch_start, device_name)) - - -# endregion diff --git a/aeon/dj_pipeline/utils/streams_maker.py b/aeon/dj_pipeline/utils/streams_maker.py index 78e5ebaf..bfd669e9 100644 --- a/aeon/dj_pipeline/utils/streams_maker.py +++ b/aeon/dj_pipeline/utils/streams_maker.py @@ -9,7 +9,8 @@ import aeon from aeon.dj_pipeline import acquisition, get_schema_name from aeon.io import api as io_api -from aeon.schema import schemas as aeon_schemas + +aeon_schemas = acquisition.aeon_schemas logger = dj.logger @@ -24,13 +25,12 @@ class StreamType(dj.Lookup): """Catalog of all steam types for the different device types used across Project Aeon. One StreamType corresponds to one reader class in `aeon.io.reader`. 
The combination of `stream_reader` and `stream_reader_kwargs` should fully specify the data loading routine for a particular device, using the `aeon.io.utils`.""" definition = """ # Catalog of all stream types used across Project Aeon - stream_type : varchar(20) + stream_type : varchar(36) --- stream_reader : varchar(256) # name of the reader class found in `aeon_mecha` package (e.g. aeon.io.reader.Video) stream_reader_kwargs : longblob # keyword arguments to instantiate the reader class stream_description='': varchar(256) stream_hash : uuid # hash of dict(stream_reader_kwargs, stream_reader=stream_reader) - unique index (stream_hash) """ @@ -200,8 +200,8 @@ def main(create_tables=True): "from uuid import UUID\n\n" "import aeon\n" "from aeon.dj_pipeline import acquisition, get_schema_name\n" - "from aeon.io import api as io_api\n" - "from aeon.schema import schemas as aeon_schemas\n\n" + "from aeon.io import api as io_api\n\n" + "aeon_schemas = acquisition.aeon_schemas\n\n" 'schema = dj.Schema(get_schema_name("streams"))\n\n\n' ) f.write(imports_str) diff --git a/aeon/io/reader.py b/aeon/io/reader.py index bf6e8c23..16af3096 100644 --- a/aeon/io/reader.py +++ b/aeon/io/reader.py @@ -11,8 +11,7 @@ from dotmap import DotMap from aeon import util -from aeon.io.api import aeon as aeon_time -from aeon.io.api import chunk, chunk_key +from aeon.io.api import chunk_key _SECONDS_PER_TICK = 32e-6 _payloadtypes = { @@ -76,9 +75,12 @@ def read(self, file): if self.columns is not None and payloadshape[1] < len(self.columns): data = pd.DataFrame(payload, index=seconds, columns=self.columns[: payloadshape[1]]) data[self.columns[payloadshape[1] :]] = math.nan - return data else: - return pd.DataFrame(payload, index=seconds, columns=self.columns) + data = pd.DataFrame(payload, index=seconds, columns=self.columns) + + # remove rows where the index is zero (why? corrupted data in harp files?) + data = data[data.index != 0] + return data class Chunk(Reader): @@ -207,24 +209,6 @@ class Encoder(Harp): def __init__(self, pattern): super().__init__(pattern, columns=["angle", "intensity"]) - def read(self, file, downsample=True): - """Reads encoder data from the specified Harp binary file. - - By default the encoder data is downsampled to 50Hz. Setting downsample to - False or None can be used to force the raw data to be returned. - """ - data = super().read(file) - if downsample is True: - # resample requires a DatetimeIndex so we convert early - data.index = aeon_time(data.index) - - first_index = data.first_valid_index() - if first_index is not None: - # since data is absolute angular position we decimate by taking first of each bin - chunk_origin = chunk(first_index) - data = data.resample("20ms", origin=chunk_origin).first() - return data - class Position(Harp): """Extract 2D position tracking data for a specific camera. @@ -324,18 +308,37 @@ class (int): Int ID of a subject in the environment. """ def __init__(self, pattern: str, model_root: str = "/ceph/aeon/aeon/data/processed"): - """Pose reader constructor.""" - # `pattern` for this reader should typically be '_*' + """Pose reader constructor. + + The pattern for this reader should typically be `__*`. + If a register prefix is required, the pattern should end with a trailing + underscore, e.g. `Camera_202_*`. Otherwise, the pattern should include a + common prefix for the pose model folder excluding the trailing underscore, + e.g. `Camera_model-dir*`. 
+ """ super().__init__(pattern, columns=None) self._model_root = model_root + self._pattern_offset = pattern.rfind("_") + 1 def read(self, file: Path) -> pd.DataFrame: """Reads data from the Harp-binarized tracking file.""" # Get config file from `file`, then bodyparts from config file. - model_dir = Path(*Path(file.stem.replace("_", "/")).parent.parts[-4:]) - config_file_dir = Path(self._model_root) / model_dir - if not config_file_dir.exists(): - raise FileNotFoundError(f"Cannot find model dir {config_file_dir}") + model_dir = Path(file.stem[self._pattern_offset :].replace("_", "/")).parent + + # Check if model directory exists in local or shared directories. + # Local directory is prioritized over shared directory. + local_config_file_dir = file.parent / model_dir + shared_config_file_dir = Path(self._model_root) / model_dir + if local_config_file_dir.exists(): + config_file_dir = local_config_file_dir + elif shared_config_file_dir.exists(): + config_file_dir = shared_config_file_dir + else: + raise FileNotFoundError( + f"""Cannot find model dir in either local ({local_config_file_dir}) \ + or shared ({shared_config_file_dir}) directories""" + ) + config_file = self.get_config_file(config_file_dir) identities = self.get_class_names(config_file) parts = self.get_bodyparts(config_file) @@ -370,7 +373,7 @@ def read(self, file: Path) -> pd.DataFrame: parts = unique_parts # Set new columns, and reformat `data`. - data = self.class_int2str(data, config_file) + data = self.class_int2str(data, identities) n_parts = len(parts) part_data_list = [pd.DataFrame()] * n_parts new_columns = pd.Series(["identity", "identity_likelihood", "part", "x", "y", "part_likelihood"]) @@ -427,18 +430,12 @@ def get_bodyparts(config_file: Path) -> list[str]: return parts @staticmethod - def class_int2str(data: pd.DataFrame, config_file: Path) -> pd.DataFrame: + def class_int2str(data: pd.DataFrame, classes: list[str]) -> pd.DataFrame: """Converts a class integer in a tracking data dataframe to its associated string (subject id).""" - if config_file.stem == "confmap_config": # SLEAP - with open(config_file) as f: - config = json.load(f) - try: - heads = config["model"]["heads"] - classes = util.find_nested_key(heads, "classes") - except KeyError as err: - raise KeyError(f"Cannot find classes in {config_file}.") from err - for i, subj in enumerate(classes): - data.loc[data["identity"] == i, "identity"] = subj + if not classes: + raise ValueError("Classes list cannot be None or empty.") + identity_mapping = dict(enumerate(classes)) + data["identity"] = data["identity"].replace(identity_mapping) return data @classmethod diff --git a/aeon/schema/dataset.py b/aeon/schema/dataset.py deleted file mode 100644 index 0facd64f..00000000 --- a/aeon/schema/dataset.py +++ /dev/null @@ -1,59 +0,0 @@ -from dotmap import DotMap - -import aeon.schema.core as stream -from aeon.schema import foraging, octagon -from aeon.schema.streams import Device - -exp02 = DotMap( - [ - Device("Metadata", stream.Metadata), - Device("ExperimentalMetadata", stream.Environment, stream.MessageLog), - Device("CameraTop", stream.Video, stream.Position, foraging.Region), - Device("CameraEast", stream.Video), - Device("CameraNest", stream.Video), - Device("CameraNorth", stream.Video), - Device("CameraPatch1", stream.Video), - Device("CameraPatch2", stream.Video), - Device("CameraSouth", stream.Video), - Device("CameraWest", stream.Video), - Device("Nest", foraging.Weight), - Device("Patch1", foraging.Patch), - Device("Patch2", foraging.Patch), - ] -) - 
-exp01 = DotMap( - [ - Device("SessionData", foraging.SessionData), - Device("FrameTop", stream.Video, stream.Position), - Device("FrameEast", stream.Video), - Device("FrameGate", stream.Video), - Device("FrameNorth", stream.Video), - Device("FramePatch1", stream.Video), - Device("FramePatch2", stream.Video), - Device("FrameSouth", stream.Video), - Device("FrameWest", stream.Video), - Device("Patch1", foraging.DepletionFunction, stream.Encoder, foraging.Feeder), - Device("Patch2", foraging.DepletionFunction, stream.Encoder, foraging.Feeder), - ] -) - -octagon01 = DotMap( - [ - Device("Metadata", stream.Metadata), - Device("CameraTop", stream.Video, stream.Position), - Device("CameraColorTop", stream.Video), - Device("ExperimentalMetadata", stream.SubjectState), - Device("Photodiode", octagon.Photodiode), - Device("OSC", octagon.OSC), - Device("TaskLogic", octagon.TaskLogic), - Device("Wall1", octagon.Wall), - Device("Wall2", octagon.Wall), - Device("Wall3", octagon.Wall), - Device("Wall4", octagon.Wall), - Device("Wall5", octagon.Wall), - Device("Wall6", octagon.Wall), - Device("Wall7", octagon.Wall), - Device("Wall8", octagon.Wall), - ] -) diff --git a/aeon/schema/ingestion_schemas.py b/aeon/schema/ingestion_schemas.py new file mode 100644 index 00000000..fe2ee3dd --- /dev/null +++ b/aeon/schema/ingestion_schemas.py @@ -0,0 +1,245 @@ +"""Aeon experiment schemas for DataJoint database ingestion.""" +from os import PathLike + +import pandas as pd +from dotmap import DotMap + +import aeon.schema.core as stream +from aeon.io import reader +from aeon.io.api import aeon as aeon_time, chunk as aeon_chunk +from aeon.schema import foraging, octagon, social_01, social_02, social_03 +from aeon.schema.streams import Device, Stream, StreamGroup + + +# Define new readers +class _Encoder(reader.Encoder): + """A version of the encoder reader that can downsample the data.""" + + def __init__(self, pattern): + super().__init__(pattern) + + def read(self, file: PathLike[str], sr_hz: int = 50) -> pd.DataFrame: + """Reads encoder data from the specified Harp binary file.""" + data = super().read(file) + data.index = aeon_time(data.index) + first_index = data.first_valid_index() + freq = 1 / sr_hz * 1e3 # convert to ms + if first_index is not None: + chunk_origin = aeon_chunk(first_index) + data = data.resample(f"{freq}ms", origin=chunk_origin).first() # take first sample in each resampled bin + return data + + +class _Video(reader.Csv): + """A version of the video reader that retains only the `hw_timestamp` column.""" + + def __init__(self, pattern): + super().__init__(pattern, columns=["hw_timestamp"]) + self._rawcolumns = ["time"] + ['hw_counter', 'hw_timestamp'] + + def read(self, file): + """Reads video metadata from the specified file.""" + data = pd.read_csv(file, header=0, names=self._rawcolumns) + drop_cols = [c for c in data.columns if c not in self.columns + ["time"]] + data.drop(columns=drop_cols, errors="ignore", inplace=True) + data.set_index("time", inplace=True) + return data + + +class Encoder(Stream): + """Wheel magnetic encoder data.""" + + def __init__(self, pattern): + super().__init__(_Encoder(f"{pattern}_90_*")) + + +# Define new streams and stream groups +class Video(Stream): + """Video frame metadata.""" + + def __init__(self, pattern): + super().__init__(_Video(f"{pattern}_*")) + + +class Patch(StreamGroup): + """Data streams for a patch.""" + + def __init__(self, path): + super().__init__(path) + + p = social_02.Patch + e = Encoder + + +# Define schemas +octagon01 = DotMap( + [ + 
Device("Metadata", stream.Metadata), + Device("CameraTop", Video, stream.Position), + Device("CameraColorTop", Video), + Device("ExperimentalMetadata", stream.SubjectState), + Device("Photodiode", octagon.Photodiode), + Device("OSC", octagon.OSC), + Device("TaskLogic", octagon.TaskLogic), + Device("Wall1", octagon.Wall), + Device("Wall2", octagon.Wall), + Device("Wall3", octagon.Wall), + Device("Wall4", octagon.Wall), + Device("Wall5", octagon.Wall), + Device("Wall6", octagon.Wall), + Device("Wall7", octagon.Wall), + Device("Wall8", octagon.Wall), + ] +) + +exp01 = DotMap( + [ + Device("SessionData", foraging.SessionData), + Device("FrameTop", Video, stream.Position), + Device("FrameEast", Video), + Device("FrameGate", Video), + Device("FrameNorth", Video), + Device("FramePatch1", Video), + Device("FramePatch2", Video), + Device("FrameSouth", Video), + Device("FrameWest", Video), + Device("Patch1", foraging.DepletionFunction, Encoder, foraging.Feeder), + Device("Patch2", foraging.DepletionFunction, Encoder, foraging.Feeder), + ] +) + +exp02 = DotMap( + [ + Device("Metadata", stream.Metadata), + Device("ExperimentalMetadata", stream.Environment, stream.MessageLog), + Device("CameraTop", Video, stream.Position, foraging.Region), + Device("CameraEast", Video), + Device("CameraNest", Video), + Device("CameraNorth", Video), + Device("CameraPatch1", Video), + Device("CameraPatch2", Video), + Device("CameraSouth", Video), + Device("CameraWest", Video), + Device("Nest", foraging.Weight), + Device("Patch1", Patch), + Device("Patch2", Patch), + ] +) + +social01 = DotMap( + [ + Device("Metadata", stream.Metadata), + Device("Environment", social_02.Environment, social_02.SubjectData), + Device("CameraTop", Video, stream.Position, social_01.Pose), + Device("CameraNorth", Video), + Device("CameraSouth", Video), + Device("CameraEast", Video), + Device("CameraWest", Video), + Device("CameraPatch1", Video), + Device("CameraPatch2", Video), + Device("CameraPatch3", Video), + Device("CameraNest", Video), + Device("Nest", social_02.WeightRaw, social_02.WeightFiltered), + Device("Patch1", Patch), + Device("Patch2", Patch), + Device("Patch3", Patch), + Device("RfidGate", social_01.RfidEvents), + Device("RfidNest1", social_01.RfidEvents), + Device("RfidNest2", social_01.RfidEvents), + Device("RfidPatch1", social_01.RfidEvents), + Device("RfidPatch2", social_01.RfidEvents), + Device("RfidPatch3", social_01.RfidEvents), + ] +) + + +social02 = DotMap( + [ + Device("Metadata", stream.Metadata), + Device("Environment", social_02.Environment, social_02.SubjectData), + Device("CameraTop", Video, stream.Position, social_02.Pose, social_02.Pose03), + Device("CameraNorth", Video), + Device("CameraSouth", Video), + Device("CameraEast", Video), + Device("CameraWest", Video), + Device("CameraPatch1", Video), + Device("CameraPatch2", Video), + Device("CameraPatch3", Video), + Device("CameraNest", Video), + Device("Nest", social_02.WeightRaw, social_02.WeightFiltered), + Device("Patch1", Patch), + Device("Patch2", Patch), + Device("Patch3", Patch), + Device("GateRfid", social_02.RfidEvents), + Device("NestRfid1", social_02.RfidEvents), + Device("NestRfid2", social_02.RfidEvents), + Device("Patch1Rfid", social_02.RfidEvents), + Device("Patch2Rfid", social_02.RfidEvents), + Device("Patch3Rfid", social_02.RfidEvents), + ] +) + + +social03 = DotMap( + [ + Device("Metadata", stream.Metadata), + Device("Environment", social_02.Environment, social_02.SubjectData), + Device("CameraTop", Video, stream.Position, social_03.Pose), + 
Device("CameraNorth", Video), + Device("CameraSouth", Video), + Device("CameraEast", Video), + Device("CameraWest", Video), + Device("CameraNest", Video), + Device("CameraPatch1", Video), + Device("CameraPatch2", Video), + Device("CameraPatch3", Video), + Device("Nest", social_02.WeightRaw, social_02.WeightFiltered), + Device("Patch1", Patch), + Device("Patch2", Patch), + Device("Patch3", Patch), + Device("PatchDummy1", Patch), + Device("NestRfid1", social_02.RfidEvents), + Device("NestRfid2", social_02.RfidEvents), + Device("GateRfid", social_02.RfidEvents), + Device("GateEastRfid", social_02.RfidEvents), + Device("GateWestRfid", social_02.RfidEvents), + Device("Patch1Rfid", social_02.RfidEvents), + Device("Patch2Rfid", social_02.RfidEvents), + Device("Patch3Rfid", social_02.RfidEvents), + Device("PatchDummy1Rfid", social_02.RfidEvents), + ] +) + + +social04 = DotMap( + [ + Device("Metadata", stream.Metadata), + Device("Environment", social_02.Environment, social_02.SubjectData, social_03.EnvironmentActiveConfiguration), + Device("CameraTop", Video, stream.Position, social_03.Pose), + Device("CameraNorth", Video), + Device("CameraSouth", Video), + Device("CameraEast", Video), + Device("CameraWest", Video), + Device("CameraNest", Video), + Device("CameraPatch1", Video), + Device("CameraPatch2", Video), + Device("CameraPatch3", Video), + Device("Nest", social_02.WeightRaw, social_02.WeightFiltered), + Device("Patch1", Patch), + Device("Patch2", Patch), + Device("Patch3", Patch), + Device("PatchDummy1", Patch), + Device("NestRfid1", social_02.RfidEvents), + Device("NestRfid2", social_02.RfidEvents), + Device("GateRfid", social_02.RfidEvents), + Device("GateEastRfid", social_02.RfidEvents), + Device("GateWestRfid", social_02.RfidEvents), + Device("Patch1Rfid", social_02.RfidEvents), + Device("Patch2Rfid", social_02.RfidEvents), + Device("Patch3Rfid", social_02.RfidEvents), + Device("PatchDummy1Rfid", social_02.RfidEvents), + ] +) + +# __all__ = ["octagon01", "exp01", "exp02", "social01", "social02", "social03", "social04"] +__all__ = ["social02", "social03", "social04"] diff --git a/aeon/schema/social_02.py b/aeon/schema/social_02.py index 04946679..9b50cf60 100644 --- a/aeon/schema/social_02.py +++ b/aeon/schema/social_02.py @@ -48,6 +48,12 @@ def __init__(self, path): super().__init__(_reader.Pose(f"{path}_test-node1*")) +class Pose03(Stream): + + def __init__(self, path): + super().__init__(_reader.Pose(f"{path}_202_*")) + + class WeightRaw(Stream): def __init__(self, path): super().__init__(_reader.Harp(f"{path}_200_*", ["weight(g)", "stability"])) diff --git a/aeon/util.py b/aeon/util.py index ceb0637a..4c2ad86c 100644 --- a/aeon/util.py +++ b/aeon/util.py @@ -14,7 +14,7 @@ def find_nested_key(obj: dict | list, key: str) -> Any: found = find_nested_key(v, key) if found: return found - elif isinstance(obj, list): + elif obj is not None: for item in obj: found = find_nested_key(item, key) if found: diff --git a/tests/data/pose/2024-03-01T16-46-12/CameraTop/CameraTop_202_test-node1_topdown-multianimal-id-133_2024-03-02T12-00-00.bin b/tests/data/pose/2024-03-01T16-46-12/CameraTop/CameraTop_202_test-node1_topdown-multianimal-id-133_2024-03-02T12-00-00.bin new file mode 100644 index 00000000..55f13c0f Binary files /dev/null and b/tests/data/pose/2024-03-01T16-46-12/CameraTop/CameraTop_202_test-node1_topdown-multianimal-id-133_2024-03-02T12-00-00.bin differ diff --git a/tests/data/pose/2024-03-01T16-46-12/CameraTop/CameraTop_test-node1_topdown-multianimal-id-133_2024-03-02T12-00-00.bin 
b/tests/data/pose/2024-03-01T16-46-12/CameraTop/CameraTop_test-node1_topdown-multianimal-id-133_2024-03-02T12-00-00.bin new file mode 100644 index 00000000..806424a8 Binary files /dev/null and b/tests/data/pose/2024-03-01T16-46-12/CameraTop/CameraTop_test-node1_topdown-multianimal-id-133_2024-03-02T12-00-00.bin differ diff --git a/tests/data/pose/2024-03-01T16-46-12/CameraTop/test-node1/topdown-multianimal-id-133/confmap_config.json b/tests/data/pose/2024-03-01T16-46-12/CameraTop/test-node1/topdown-multianimal-id-133/confmap_config.json new file mode 100644 index 00000000..5a2084b0 --- /dev/null +++ b/tests/data/pose/2024-03-01T16-46-12/CameraTop/test-node1/topdown-multianimal-id-133/confmap_config.json @@ -0,0 +1,202 @@ +{ + "data": { + "labels": { + "training_labels": "social_dev_b5350ff/aeon3_social_dev_b5350ff_ceph.slp", + "validation_labels": null, + "validation_fraction": 0.1, + "test_labels": null, + "split_by_inds": false, + "training_inds": null, + "validation_inds": null, + "test_inds": null, + "search_path_hints": [], + "skeletons": [ + { + "directed": true, + "graph": { + "name": "Skeleton-1", + "num_edges_inserted": 0 + }, + "links": [], + "multigraph": true, + "nodes": [ + { + "id": { + "py/object": "sleap.skeleton.Node", + "py/state": { + "py/tuple": [ + "centroid", + 1.0 + ] + } + } + } + ] + } + ] + }, + "preprocessing": { + "ensure_rgb": false, + "ensure_grayscale": false, + "imagenet_mode": null, + "input_scaling": 1.0, + "pad_to_stride": 16, + "resize_and_pad_to_target": true, + "target_height": 1080, + "target_width": 1440 + }, + "instance_cropping": { + "center_on_part": "centroid", + "crop_size": 96, + "crop_size_detection_padding": 16 + } + }, + "model": { + "backbone": { + "leap": null, + "unet": { + "stem_stride": null, + "max_stride": 16, + "output_stride": 2, + "filters": 16, + "filters_rate": 1.5, + "middle_block": true, + "up_interpolate": false, + "stacks": 1 + }, + "hourglass": null, + "resnet": null, + "pretrained_encoder": null + }, + "heads": { + "single_instance": null, + "centroid": null, + "centered_instance": null, + "multi_instance": null, + "multi_class_bottomup": null, + "multi_class_topdown": { + "confmaps": { + "anchor_part": "centroid", + "part_names": [ + "centroid" + ], + "sigma": 1.5, + "output_stride": 2, + "loss_weight": 1.0, + "offset_refinement": false + }, + "class_vectors": { + "classes": [ + "BAA-1104045", + "BAA-1104047" + ], + "num_fc_layers": 3, + "num_fc_units": 256, + "global_pool": true, + "output_stride": 2, + "loss_weight": 0.01 + } + } + }, + "base_checkpoint": null + }, + "optimization": { + "preload_data": true, + "augmentation_config": { + "rotate": true, + "rotation_min_angle": -180.0, + "rotation_max_angle": 180.0, + "translate": false, + "translate_min": -5, + "translate_max": 5, + "scale": false, + "scale_min": 0.9, + "scale_max": 1.1, + "uniform_noise": false, + "uniform_noise_min_val": 0.0, + "uniform_noise_max_val": 10.0, + "gaussian_noise": false, + "gaussian_noise_mean": 5.0, + "gaussian_noise_stddev": 1.0, + "contrast": false, + "contrast_min_gamma": 0.5, + "contrast_max_gamma": 2.0, + "brightness": false, + "brightness_min_val": 0.0, + "brightness_max_val": 10.0, + "random_crop": false, + "random_crop_height": 256, + "random_crop_width": 256, + "random_flip": false, + "flip_horizontal": true + }, + "online_shuffling": true, + "shuffle_buffer_size": 128, + "prefetch": true, + "batch_size": 4, + "batches_per_epoch": 469, + "min_batches_per_epoch": 200, + "val_batches_per_epoch": 54, + "min_val_batches_per_epoch": 
10, + "epochs": 200, + "optimizer": "adam", + "initial_learning_rate": 0.0001, + "learning_rate_schedule": { + "reduce_on_plateau": true, + "reduction_factor": 0.1, + "plateau_min_delta": 1e-08, + "plateau_patience": 20, + "plateau_cooldown": 3, + "min_learning_rate": 1e-08 + }, + "hard_keypoint_mining": { + "online_mining": false, + "hard_to_easy_ratio": 2.0, + "min_hard_keypoints": 2, + "max_hard_keypoints": null, + "loss_scale": 5.0 + }, + "early_stopping": { + "stop_training_on_plateau": true, + "plateau_min_delta": 1e-08, + "plateau_patience": 20 + } + }, + "outputs": { + "save_outputs": true, + "run_name": "aeon3_social_dev_b5350ff_ceph_topdown_top.centered_instance_multiclass", + "run_name_prefix": "", + "run_name_suffix": "", + "runs_folder": "social_dev_b5350ff/models", + "tags": [], + "save_visualizations": true, + "delete_viz_images": true, + "zip_outputs": false, + "log_to_csv": true, + "checkpointing": { + "initial_model": true, + "best_model": true, + "every_epoch": false, + "latest_model": false, + "final_model": false + }, + "tensorboard": { + "write_logs": false, + "loss_frequency": "epoch", + "architecture_graph": false, + "profile_graph": false, + "visualizations": true + }, + "zmq": { + "subscribe_to_controller": false, + "controller_address": "tcp://127.0.0.1:9000", + "controller_polling_timeout": 10, + "publish_updates": false, + "publish_address": "tcp://127.0.0.1:9001" + } + }, + "name": "", + "description": "", + "sleap_version": "1.3.1", + "filename": "Z:/aeon/data/processed/test-node1/4310907/2024-01-12T19-00-00/topdown-multianimal-id-133/confmap_config.json" +} \ No newline at end of file diff --git a/tests/io/test_api.py b/tests/io/test_api.py index 095439de..2a491c55 100644 --- a/tests/io/test_api.py +++ b/tests/io/test_api.py @@ -5,16 +5,17 @@ from pytest import mark import aeon +from aeon.schema.ingestion_schemas import social03 from aeon.schema.schemas import exp02 -nonmonotonic_path = Path(__file__).parent.parent / "data" / "nonmonotonic" monotonic_path = Path(__file__).parent.parent / "data" / "monotonic" +nonmonotonic_path = Path(__file__).parent.parent / "data" / "nonmonotonic" @mark.api def test_load_start_only(): data = aeon.load( - nonmonotonic_path, exp02.Patch2.Encoder, start=pd.Timestamp("2022-06-06T13:00:49"), downsample=None + nonmonotonic_path, exp02.Patch2.Encoder, start=pd.Timestamp("2022-06-06T13:00:49") ) assert len(data) > 0 @@ -22,7 +23,7 @@ def test_load_start_only(): @mark.api def test_load_end_only(): data = aeon.load( - nonmonotonic_path, exp02.Patch2.Encoder, end=pd.Timestamp("2022-06-06T13:00:49"), downsample=None + nonmonotonic_path, exp02.Patch2.Encoder, end=pd.Timestamp("2022-06-06T13:00:49") ) assert len(data) > 0 @@ -35,22 +36,22 @@ def test_load_filter_nonchunked(): @mark.api def test_load_monotonic(): - data = aeon.load(monotonic_path, exp02.Patch2.Encoder, downsample=None) + data = aeon.load(monotonic_path, exp02.Patch2.Encoder) assert len(data) > 0 assert data.index.is_monotonic_increasing @mark.api def test_load_nonmonotonic(): - data = aeon.load(nonmonotonic_path, exp02.Patch2.Encoder, downsample=None) + data = aeon.load(nonmonotonic_path, exp02.Patch2.Encoder) assert not data.index.is_monotonic_increasing @mark.api def test_load_encoder_with_downsampling(): DOWNSAMPLE_PERIOD = 0.02 - data = aeon.load(monotonic_path, exp02.Patch2.Encoder, downsample=True) - raw_data = aeon.load(monotonic_path, exp02.Patch2.Encoder, downsample=None) + data = aeon.load(monotonic_path, social03.Patch2.Encoder) + raw_data = 
aeon.load(monotonic_path, exp02.Patch2.Encoder) # Check that the length of the downsampled data is less than the raw data assert len(data) < len(raw_data) diff --git a/tests/io/test_reader.py b/tests/io/test_reader.py new file mode 100644 index 00000000..640768ab --- /dev/null +++ b/tests/io/test_reader.py @@ -0,0 +1,25 @@ +from pathlib import Path + +import pytest +from pytest import mark + +import aeon +from aeon.schema.schemas import social02, social03 + +pose_path = Path(__file__).parent.parent / "data" / "pose" + + +@mark.api +def test_Pose_read_local_model_dir(): + data = aeon.load(pose_path, social02.CameraTop.Pose) + assert len(data) > 0 + + +@mark.api +def test_Pose_read_local_model_dir_with_register_prefix(): + data = aeon.load(pose_path, social03.CameraTop.Pose) + assert len(data) > 0 + + +if __name__ == "__main__": + pytest.main()
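A minimal usage sketch of the pose registers exercised by the new tests above (not taken from this diff): the data root and time window below are placeholders, and the attribute access mirrors tests/io/test_reader.py.

from pathlib import Path

import pandas as pd

import aeon
from aeon.schema.schemas import social03

# Hypothetical local raw-data root; replace with your own acquisition directory.
root = Path("/path/to/aeon/raw")

# social03.CameraTop.Pose appears to resolve pose files written under the new
# "202" register prefix (cf. the Pose03 stream and the CameraTop_202_* test data),
# whereas social02.CameraTop.Pose matches the older "test-node1" file naming.
pose = aeon.load(
    root,
    social03.CameraTop.Pose,
    start=pd.Timestamp("2024-03-01T16:46:12"),
    end=pd.Timestamp("2024-03-02T12:00:00"),
)
print(pose.head())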