Skip to content

Commit

Permalink
Merge pull request #253 from SainsburyWellcomeCentre/update_pose_reader
Browse files Browse the repository at this point in the history
Update pose reader
  • Loading branch information
jkbhagatio authored Sep 14, 2023
2 parents eb306bd + aaacd53 commit 3b53953
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 19 deletions.
52 changes: 34 additions & 18 deletions aeon/schema/social.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
"""Readers for data relevant to Social experiments."""

from pathlib import Path
from typing import List, Union
import json
from pathlib import Path

import numpy as np
import pandas as pd

from aeon import util
import aeon.io.reader as _reader
from aeon import util


class Pose(_reader.Harp):
Expand All @@ -25,63 +25,79 @@ def __init__(self, pattern: str, extension: str="bin"):
# `pattern` for this reader should typically be '<hpcnode>_<jobid>*'
super().__init__(pattern, columns=None, extension=extension)

def read(self, file: Path, ceph_proc_dir: Path=Path("/ceph/aeon/aeon/data/processed")) -> pd.DataFrame:
def read(
self, file: Path, ceph_proc_dir: str | Path = "/ceph/aeon/aeon/data/processed"
) -> pd.DataFrame:
"""Reads data from the Harp-binarized tracking file."""
# Get config file from `file`, then bodyparts from config file.
model_dir = Path(file.stem.replace("_", "/")).parent
config_file_dir = ceph_proc_dir / model_dir
assert config_file_dir.exists(), f"Cannot find model dir {config_file_dir}"
if not config_file_dir.exists():
raise FileNotFoundError(f"Cannot find model dir {config_file_dir}")
config_file = get_config_file(config_file_dir)
parts = self.get_bodyparts(config_file)

# Using bodyparts, assign column names to Harp register values, and read data in default format.
columns = ["class", "class_likelihood"]
for part in parts:
columns.extend([f"{part}_x", f"{part}_y", f"{part}_likelihood"])
self.columns = columns
data = super().read(file)


# Drop any repeat parts.
unique_parts, unique_idxs = np.unique(parts, return_index=True)
repeat_idxs = np.setdiff1d(np.arange(len(parts)), unique_idxs)
if repeat_idxs: # drop x, y, and likelihood cols for repeat parts (skip first 5 cols)
init_rep_part_col_idx = (repeat_idxs - 1) * 3 + 5
rep_part_col_idxs = np.concatenate([np.arange(i, i + 3) for i in init_rep_part_col_idx])
keep_part_col_idxs = np.setdiff1d(np.arange(len(data.columns)), rep_part_col_idxs)
data = data.iloc[:, keep_part_col_idxs]
parts = unique_parts

# Set new columns, and reformat `data`.
n_parts = len(parts)
part_data_list = [None] * n_parts
part_data_list = [pd.DataFrame()] * n_parts
new_columns = ["class", "class_likelihood", "part", "x", "y", "part_likelihood"]
new_data = pd.DataFrame(columns=new_columns)
for i, part in enumerate(parts):
part_columns = ["class", "class_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"]
part_data = data[part_columns]
part_data = pd.DataFrame(data[part_columns])
part_data.insert(2, "part", part)
part_data.columns = new_columns
part_data_list[i] = part_data
new_data = pd.concat(part_data_list)
return new_data.sort_index()

def get_bodyparts(self, file: Path) -> Union[None, List[str]]:
def get_bodyparts(self, file: Path) -> list[str]:
"""Returns a list of bodyparts from a model's config file."""
parts = None
parts = []
with open(file) as f:
config = json.load(f)
if file.stem == "confmap_config": # SLEAP
try:
heads = config["model"]["heads"]
parts = util.find_nested_key(heads, "part_names")
parts = [util.find_nested_key(heads, "anchor_part")]
parts += util.find_nested_key(heads, "part_names")
except KeyError as err:
raise KeyError(f"Cannot find bodyparts in {file}.") from err
if not parts:
raise KeyError(f"Cannot find bodyparts in {file}.") from err
return parts


def get_config_file(
config_file_dir: Path,
config_file_names: List[str]=[
"confmap_config.json", # SLEAP (add others for other trackers to this list)
],
):
config_file_names: None | list[str] = None,
) -> Path:
"""Returns the config file from a model's config directory."""
if config_file_names is None:
config_file_names = ["confmap_config.json"] # SLEAP (add for other trackers to this list)
config_file = None
for f in config_file_names:
if (config_file_dir / f).exists():
config_file = config_file_dir / f
break
assert config_file is not None, f"Cannot find config file in {config_file_dir}"
if config_file is None:
raise FileNotFoundError(f"Cannot find config file in {config_file_dir}")
return config_file


Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,6 @@ reportAssertAlwaysTrue = "error"
reportSelfClsParameterName = "error"
reportUnusedExpression = "error"
reportMatchNotExhaustive = "error"
reportImplicitOverride = "error"
reportShadowedImports = "error"
# *Note*: we may want to set all 'ReportOptional*' rules to "none", but leaving 'em default for now
venvPath = "."
Expand Down

0 comments on commit 3b53953

Please sign in to comment.