Merge pull request #309 from ttngu207/datajoint_social01
Block-level analysis and plots - draft 1
JaerongA authored Jan 30, 2024
2 parents 5220b0d + 4f65773 commit c0afa90
Showing 9 changed files with 939 additions and 107 deletions.
17 changes: 17 additions & 0 deletions aeon/dj_pipeline/__init__.py
@@ -30,6 +30,23 @@ def dict_to_uuid(key) -> uuid.UUID:
    return uuid.UUID(hex=hashed.hexdigest())


def fetch_stream(query, drop_pk=True):
    """
    Given a query containing data from a Stream table, fetch and aggregate
    the data into one DataFrame indexed by "time".
    """
    df = (query & "sample_count > 0").fetch(format="frame").reset_index()
    cols2explode = [
        c for c in query.heading.secondary_attributes if query.heading.attributes[c].type == "longblob"
    ]
    df = df.explode(column=cols2explode)
    cols2drop = ["sample_count"] + (query.primary_key if drop_pk else [])
    df.drop(columns=cols2drop, inplace=True, errors="ignore")
    df.rename(columns={"timestamps": "time"}, inplace=True)
    df.set_index("time", inplace=True)
    return df


try:
    from . import streams
except ImportError:
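Editor's note: a minimal usage sketch of the new fetch_stream helper (not part of this commit). The table used here, acquisition.Environment.SubjectWeight, is added later in this diff; the experiment name is a placeholder.

# Hypothetical usage, assuming the pipeline is configured and populated:
from aeon.dj_pipeline import acquisition, fetch_stream

query = acquisition.Environment.SubjectWeight & {"experiment_name": "social0.1-aeon3"}
weight_df = fetch_stream(query)  # time-indexed DataFrame, primary-key columns dropped
weight_df_pk = fetch_stream(query, drop_pk=False)  # keep primary-key columns (e.g. chunk_start)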
165 changes: 161 additions & 4 deletions aeon/dj_pipeline/acquisition.py
@@ -1,6 +1,6 @@
import datetime
import pathlib

import re
import datajoint as dj
import numpy as np
import pandas as pd
@@ -670,7 +670,7 @@ def make(self, key):
"devices_schema_name"
),
)
device = devices_schema.ExperimentalMetadata
device = devices_schema.Environment

try:
# handles corrupted files - issue: https://github.com/SainsburyWellcomeCentre/aeon_mecha/issues/153
@@ -684,12 +684,18 @@ def make(self, key):
logger.warning("Can't read from device.MessageLog")
log_messages = pd.DataFrame()

state_messages = io_api.load(
env_states = io_api.load(
root=raw_data_dir.as_posix(),
reader=device.EnvironmentState,
start=pd.Timestamp(chunk_start),
end=pd.Timestamp(chunk_end),
)
block_states = io_api.load(
root=raw_data_dir.as_posix(),
reader=device.BlockState,
start=pd.Timestamp(chunk_start),
end=pd.Timestamp(chunk_end),
)

self.insert1(key)
self.Message.insert(
@@ -712,13 +718,147 @@ def make(self, key):
"message": r.state,
"message_type": "EnvironmentState",
}
for _, r in state_messages.iterrows()
for _, r in env_states.iterrows()
),
skip_duplicates=True,
)


# ------------------- ENVIRONMENT --------------------


@schema
class Environment(dj.Imported):
definition = """ # Experiment environments
-> Chunk
"""

class EnvironmentState(dj.Part):
definition = """
-> master
---
sample_count: int # number of data points acquired from this stream for a given chunk
timestamps: longblob # (datetime) timestamps
state: longblob
"""

class BlockState(dj.Part):
definition = """
-> master
---
sample_count: int # number of data points acquired from this stream for a given chunk
timestamps: longblob # (datetime) timestamps
pellet_ct: longblob
pellet_ct_thresh: longblob
due_time: longblob
"""

class LightEvents(dj.Part):
definition = """
-> master
---
sample_count: int # number of data points acquired from this stream for a given chunk
timestamps: longblob # (datetime) timestamps
channel: longblob
value: longblob
"""

class MessageLog(dj.Part):
definition = """
-> master
---
sample_count: int # number of data points acquired from this stream for a given chunk
timestamps: longblob # (datetime)
priority: longblob
type: longblob
message: longblob
"""

class SubjectState(dj.Part):
definition = """
-> master
---
sample_count: int # number of data points acquired from this stream for a given chunk
timestamps: longblob # (datetime) timestamps
id: longblob
weight: longblob
type: longblob
"""

class SubjectVisits(dj.Part):
definition = """
-> master
---
sample_count: int # number of data points acquired from this stream for a given chunk
timestamps: longblob # (datetime) timestamps
id: longblob
type: longblob
region: longblob
"""

class SubjectWeight(dj.Part):
definition = """
-> master
---
sample_count: int # number of data points acquired from this stream for a given chunk
timestamps: longblob # (datetime) timestamps
weight: longblob
confidence: longblob
subject_id: longblob
int_id: longblob
"""

def make(self, key):
chunk_start, chunk_end = (Chunk & key).fetch1("chunk_start", "chunk_end")

# Populate the part table
raw_data_dir = Experiment.get_data_directory(key)
devices_schema = getattr(
aeon_schemas,
(Experiment.DevicesSchema & {"experiment_name": key["experiment_name"]}).fetch1(
"devices_schema_name"
),
)
device = devices_schema.Environment

self.insert1(key)

for stream_type, part_table in [
("EnvironmentState", self.EnvironmentState),
("BlockState", self.BlockState),
("LightEvents", self.LightEvents),
("MessageLog", self.MessageLog),
("SubjectState", self.SubjectState),
("SubjectVisits", self.SubjectVisits),
("SubjectWeight", self.SubjectWeight),
]:
stream_reader = getattr(device, stream_type)

stream_data = io_api.load(
root=raw_data_dir.as_posix(),
reader=stream_reader,
start=pd.Timestamp(chunk_start),
end=pd.Timestamp(chunk_end),
)

part_table.insert1(
{
**key,
"sample_count": len(stream_data),
"timestamps": stream_data.index.values,
**{
re.sub(r"\([^)]*\)", "", c): stream_data[c].values
for c in stream_reader.columns
if not c.startswith("_")
},
},
ignore_extra_fields=True,
)
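Editor's note: a minimal sketch (not part of this commit) of how the new Environment table might be populated and read back. It assumes a configured pipeline with ingested Chunk entries; the experiment name is a placeholder. The re.sub call in make() strips parenthesized unit suffixes from reader columns, so a hypothetical column such as "weight(g)" would map onto a "weight" attribute.

# Hypothetical usage, assuming a configured pipeline:
from aeon.dj_pipeline import acquisition, fetch_stream

acquisition.Environment.populate(display_progress=True)  # one master row per Chunk; streams land in the part tables
block_df = fetch_stream(
    acquisition.Environment.BlockState & {"experiment_name": "social0.1-aeon3"}
)  # time-indexed DataFrame with pellet_ct, pellet_ct_thresh, due_time columns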


# ------------------- EVENTS --------------------


@schema
class FoodPatchEvent(dj.Imported):
definition = """ # events associated with a given ExperimentFoodPatch
@@ -1220,3 +1360,20 @@ def _load_legacy_subjectdata(experiment_name, data_dir, start, end):
    subject_data.sort_index(inplace=True)

    return subject_data


def create_chunk_restriction(experiment_name, start_time, end_time):
    """
    Create a time restriction string for the chunks between the specified "start" and "end" times.
    """
    start_restriction = f'"{start_time}" BETWEEN chunk_start AND chunk_end'
    end_restriction = f'"{end_time}" BETWEEN chunk_start AND chunk_end'
    start_query = Chunk & {"experiment_name": experiment_name} & start_restriction
    end_query = Chunk & {"experiment_name": experiment_name} & end_restriction
    if not (start_query and end_query):
        raise ValueError(f"No Chunk found between {start_time} and {end_time}")
    time_restriction = (
        f'chunk_start >= "{min(start_query.fetch("chunk_start"))}"'
        f' AND chunk_start < "{max(end_query.fetch("chunk_end"))}"'
    )
    return time_restriction
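Editor's note: a minimal usage sketch (not part of the commit) of create_chunk_restriction combined with fetch_stream. The experiment name and timestamps are placeholders.

# Hypothetical usage: restrict chunk-aligned stream tables to a time span.
from aeon.dj_pipeline import acquisition, fetch_stream

chunk_restriction = acquisition.create_chunk_restriction(
    "social0.1-aeon3", "2024-01-25 10:00:00", "2024-01-25 14:00:00"
)
env_state_df = fetch_stream(
    acquisition.Environment.EnvironmentState
    & {"experiment_name": "social0.1-aeon3"}
    & chunk_restriction
)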
