Create a general function to assign events to subjects #335

Open
jkbhagatio opened this issue Feb 15, 2024 · 0 comments

Currently, I have some logic to assign pellet deliveries and wheel movement to subjects. This could be made into a general function so that any event of interest can be assigned to a subject when the subject is not known a priori.
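Here is a minimal sketch of what such a general function might look like. It generalizes the closest-subject logic used for pellets. For each subject, it matches each event to the first pose sample at or after it (`pd.merge_asof` with `direction="forward"`), computes that subject's distance to a target location, and picks the closest subject. The name `assign_events_to_subjects`, its parameters, and the assumed `pose_df` layout are hypothetical and not an existing aeon API. The assumed layout is a time-sorted index, `class`/`x`/`y` columns, and one row per subject per timestep (e.g. centroid rows only).

```python
import numpy as np
import pandas as pd


def assign_events_to_subjects(
    event_times: pd.DatetimeIndex,
    pose_df: pd.DataFrame,
    target_xy: tuple[float, float],
    subjects: list[str],
    tolerance: pd.Timedelta = pd.Timedelta("200ms"),
) -> pd.Series:
    """Hypothetical helper: assign each event to the subject closest to `target_xy`.

    Assumes `pose_df` is time-indexed with "class", "x", "y" columns (one row per
    subject per pose timestep). Returns a Series of subject ids indexed by event
    time; NaN where no subject has a pose sample within `tolerance` after the event.
    """
    event_times = pd.DatetimeIndex(event_times).sort_values()
    if len(subjects) == 1:  # nothing to disambiguate
        return pd.Series(subjects[0], index=event_times)
    # merge_asof needs a sorted left frame; the column is just a placeholder
    events = pd.DataFrame({"event_idx": np.arange(len(event_times))}, index=event_times)
    dists = pd.DataFrame(index=event_times, columns=subjects, dtype=float)
    for subject in subjects:
        subj_pose = pose_df[pose_df["class"] == subject].sort_index()
        dist = np.hypot(subj_pose["x"] - target_xy[0], subj_pose["y"] - target_xy[1])
        # For each event, take the first pose sample at or after it, within tolerance
        matched = pd.merge_asof(
            events,
            dist.rename("dist").to_frame(),
            left_index=True,
            right_index=True,
            direction="forward",
            tolerance=tolerance,
        )
        dists[subject] = matched["dist"].to_numpy()  # row order matches `events`
    # Only call idxmin on rows where at least one subject was matched
    has_match = dists.notna().any(axis=1).to_numpy()
    closest = pd.Series(np.nan, index=event_times, dtype=object)
    if has_match.any():
        closest.iloc[has_match] = dists.iloc[has_match].idxmin(axis=1).to_numpy()
    return closest
```

For pellets, `event_times` would be the pellet-delivery timestamps and `target_xy` the patch location. Events with no pose sample within `tolerance` come back as NaN rather than being forced onto a subject.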

Example of current logic:

```python
# Assumed import paths (aeon_mecha); these may differ between package versions
import numpy as np
import pandas as pd
from dotmap import DotMap

import aeon.io.api as aeon
from aeon.io import reader
from aeon.analysis.utils import distancetravelled
from aeon.schema.schemas import social02  # assumed location of the social0.2 schema

# `block`, `subjects`, `patches`, `patch_locs`, and `arena` are defined earlier in the analysis

# <s Get pose-tracking info in order to do subject-specific assignments
pose_df = aeon.load(block.root, social02.CameraTop.Pose, block.start, block.end)
pose_df = reader.Pose.class_int2str(pose_df, block.sleap_model_dir)
if len(subjects) == 1:  # fix mistaken SLEAP identity assignments for single-subject blocks
    pose_df["class"] = subjects[0]
# /s>
# <s Get per-patch data (fill in `patch_info`, `cum_wheel_dist`, `pellet_info` cols of `blocks_df`)
patch_stats_df = pd.DataFrame(index=patches, columns=["mean", "offset"])  # -> patch_info
cum_wheel_dist_dm = DotMap()  # -> cum_wheel_dist
pellets_stats_df = pd.DataFrame(columns=["time", "patch", "threshold", "id"])  # -> pellet_info
for patch in patches:
    # <ss Get wheel data (downsampled 50x)
    r = getattr(social02, patch).Encoder  # same as eval(f"social02.{patch}.Encoder"), without eval
    wheel_df = aeon.load(block.root, r, block.start, block.end)[::50].round(1).astype(np.float32)
    cum_wheel_dist = -distancetravelled(wheel_df.angle)
    # /ss>
    # <ss Get pellets data
    r = getattr(social02, patch).DepletionState
    patch_df = aeon.load(block.root, r, block.start, block.end)
    rate, offset = patch_df[["rate", "offset"]].iloc[0]
    # Mean is 1 / rate, rounded down to a multiple of 100
    patch_stats_df.loc[patch, ["mean", "offset"]] = (1 / rate // 100 * 100, offset)
    # Keep only the last row of any run of DepletionState rows less than 1 s apart
    patch_df_good_indxs = np.concatenate((np.diff(patch_df.index) > pd.Timedelta("1s"), (True,)))
    patch_df_for_pellets_df = patch_df[patch_df_good_indxs].reset_index()[["time", "threshold"]]
    patch_df_for_pellets_df["patch"] = patch
    patch_df_for_pellets_df["id"] = None
    patch_df_for_pellets_df.dropna(subset=["threshold"], inplace=True)
    # Drop the first row, as it comes from the block start
    patch_df_for_pellets_df = patch_df_for_pellets_df.iloc[1:].reset_index(drop=True)
    # /ss>
    # <ss Assign data to subjects
    if len(subjects) == 1:
        cum_wheel_dist_dm[patch] = cum_wheel_dist.to_frame(name=subjects[0])
        patch_df_for_pellets_df["id"] = subjects[0]
    else:
        # <sss Assign id based on which subject was closest to patch at time of event
        # <ssss Get distance-to-patch at each pose data timestep (centroid rows only)
        patch_xy = np.array(patch_locs[patch][arena]).astype(np.uint32)
        centroid_df = pose_df[pose_df["part"] == "centroid"]
        subjects_xy = centroid_df[["x", "y"]].values
        dist_to_patch = np.sqrt(np.sum((subjects_xy - patch_xy) ** 2, axis=1))
        # Build from the centroid rows so the row count matches `dist_to_patch`
        dist_to_patch_df = centroid_df[["class"]].copy()
        dist_to_patch_df["dist_to_patch"] = dist_to_patch
        # /ssss>
        # <ssss Get distance-to-patch at each wheel ts and pellet delivery ts, organized by subject
        dist_to_patch_wheel_ts_id_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subjects)
        dist_to_patch_pel_ts_id_df = pd.DataFrame(
            index=patch_df_for_pellets_df["time"], columns=subjects
        )
        for subject in subjects:
            # For each wheel ts, take the first pose sample at or after it (within 100 ms)
            dist_to_patch_wheel_ts_subj = pd.merge_asof(
                left=dist_to_patch_wheel_ts_id_df[subject],
                right=dist_to_patch_df[dist_to_patch_df["class"] == subject],
                left_index=True,
                right_index=True,
                direction="forward",
                tolerance=pd.Timedelta("100ms"),
            )
            dist_to_patch_wheel_ts_id_df[subject] = dist_to_patch_wheel_ts_subj["dist_to_patch"]
            # For each pellet ts, take the first pose sample at or after it (within 200 ms)
            dist_to_patch_pel_ts_subj = pd.merge_asof(
                left=dist_to_patch_pel_ts_id_df[subject],
                right=dist_to_patch_df[dist_to_patch_df["class"] == subject],
                left_index=True,
                right_index=True,
                direction="forward",
                tolerance=pd.Timedelta("200ms"),
            )
            dist_to_patch_pel_ts_id_df[subject] = dist_to_patch_pel_ts_subj["dist_to_patch"]
        # /ssss>
        # <ssss Get closest subject to patch at each pellet delivery timestep
        # Rows with no pose match are all-NaN; `idxmin` on these is deprecated in recent
        # pandas, so skip them and leave their id as NaN
        pel_has_match = dist_to_patch_pel_ts_id_df.notna().any(axis=1).to_numpy()
        pel_ids = np.full(len(dist_to_patch_pel_ts_id_df), np.nan, dtype=object)
        pel_ids[pel_has_match] = (
            dist_to_patch_pel_ts_id_df[pel_has_match].idxmin(axis=1).to_numpy()
        )
        patch_df_for_pellets_df["id"] = pel_ids
        # /ssss>
        # <ssss Get closest subject to patch at each wheel timestep
        cum_wheel_dist_subj_df = pd.DataFrame(index=cum_wheel_dist.index, columns=subjects, data=0.0)
        wheel_has_match = dist_to_patch_wheel_ts_id_df.notna().any(axis=1)
        closest_subjects = (
            dist_to_patch_wheel_ts_id_df[wheel_has_match]
            .idxmin(axis=1)
            .reindex(dist_to_patch_wheel_ts_id_df.index)
        )
        # Per-timestep wheel displacement (the first value is the first cumulative value)
        wheel_dist = cum_wheel_dist.diff().fillna(cum_wheel_dist.iloc[0])
        # Assign wheel dist to closest subject for each wheel timestep
        for subject in subjects:
            subj_idxs = cum_wheel_dist_subj_df[closest_subjects == subject].index
            cum_wheel_dist_subj_df.loc[subj_idxs, subject] = wheel_dist.loc[subj_idxs]
        # Re-accumulate each subject's share into a cumulative distance
        cum_wheel_dist_dm[patch] = cum_wheel_dist_subj_df.cumsum(axis=0)
        # /ssss> #/sss> #/ss>
    pellets_stats_df = pd.concat([pellets_stats_df, patch_df_for_pellets_df], ignore_index=True)
```
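The wheel branch follows a different pattern. Wheel distance is a continuous cumulative signal rather than a set of discrete events, so its per-timestep increments are credited to whichever subject is closest and then re-accumulated per subject. Below is a minimal sketch of that step as a standalone function. The name `split_cumulative_by_subject` is hypothetical. `closest_subject` is assumed to be a per-timestep Series of subject ids (NaN where no subject could be matched), such as the `idxmin` result in the logic above.

```python
import numpy as np
import pandas as pd


def split_cumulative_by_subject(
    cum_signal: pd.Series, closest_subject: pd.Series, subjects: list[str]
) -> pd.DataFrame:
    """Hypothetical helper: split a cumulative signal into per-subject cumulative signals.

    Each timestep's increment is credited to the subject in `closest_subject` at that
    timestep; timesteps where `closest_subject` is NaN are credited to no one.
    `closest_subject` is assumed to be aligned row-for-row with `cum_signal`.
    """
    # Per-timestep increments; the first increment is the first cumulative value
    increments = cum_signal.diff().fillna(cum_signal.iloc[0]).to_numpy()
    per_subject = pd.DataFrame(0.0, index=cum_signal.index, columns=subjects)
    for subject in subjects:
        mask = (closest_subject == subject).to_numpy()  # NaN compares False
        per_subject.loc[mask, subject] = increments[mask]
    # Re-accumulate each subject's increments
    return per_subject.cumsum()
```

Increments at unmatched timesteps are credited to no one. As a result, the per-subject totals can sum to less than the original cumulative distance.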
jkbhagatio added this to the Social0.2 Ongoing milestone on Feb 26, 2024
jkbhagatio self-assigned this on Feb 26, 2024