
Commit

Passed pre-commit checks
jkbhagatio committed Sep 12, 2023
1 parent 298faaf commit 7458fcf
Showing 2 changed files with 15 additions and 15 deletions.
29 changes: 15 additions & 14 deletions aeon/schema/social.py
@@ -1,14 +1,13 @@
 """Readers for data relevant to Social experiments."""
 
-from pathlib import Path
-from typing import List, Union
+import json
+from pathlib import Path
 
 import numpy as np
 import pandas as pd
 
-from aeon import util
 import aeon.io.reader as _reader
+from aeon import util
 
 
 class Pose(_reader.Harp):
@@ -26,15 +25,17 @@ def __init__(self, pattern: str, extension: str="bin"):
         # `pattern` for this reader should typically be '<hpcnode>_<jobid>*'
         super().__init__(pattern, columns=None, extension=extension)
 
-    def read(self, file: Path, ceph_proc_dir: Path=Path("/ceph/aeon/aeon/data/processed")) -> pd.DataFrame:
+    def read(
+        self, file: Path, ceph_proc_dir: str | Path = "/ceph/aeon/aeon/data/processed"
+    ) -> pd.DataFrame:
         """Reads data from the Harp-binarized tracking file."""
         # Get config file from `file`, then bodyparts from config file.
         model_dir = Path(file.stem.replace("_", "/")).parent
         config_file_dir = ceph_proc_dir / model_dir
         assert config_file_dir.exists(), f"Cannot find model dir {config_file_dir}"
         config_file = get_config_file(config_file_dir)
         parts = self.get_bodyparts(config_file)
 
         # Using bodyparts, assign column names to Harp register values, and read data in default format.
         columns = ["class", "class_likelihood"]
         for part in parts:
@@ -51,24 +52,24 @@ def read(self, file: Path, ceph_proc_dir: Path=Path("/ceph/aeon/aeon/data/proces
             keep_part_col_idxs = np.setdiff1d(np.arange(len(data.columns)), rep_part_col_idxs)
             data = data.iloc[:, keep_part_col_idxs]
             parts = unique_parts
 
         # Set new columns, and reformat `data`.
         n_parts = len(parts)
-        part_data_list = [None] * n_parts
+        part_data_list = [pd.DataFrame()] * n_parts
         new_columns = ["class", "class_likelihood", "part", "x", "y", "part_likelihood"]
         new_data = pd.DataFrame(columns=new_columns)
         for i, part in enumerate(parts):
             part_columns = ["class", "class_likelihood", f"{part}_x", f"{part}_y", f"{part}_likelihood"]
-            part_data = data[part_columns]
+            part_data = pd.DataFrame(data[part_columns])
             part_data.insert(2, "part", part)
             part_data.columns = new_columns
             part_data_list[i] = part_data
         new_data = pd.concat(part_data_list)
         return new_data.sort_index()
 
-    def get_bodyparts(self, file: Path) -> Union[None, List[str]]:
+    def get_bodyparts(self, file: Path) -> list[str]:
         """Returns a list of bodyparts from a model's config file."""
-        parts = None
+        parts = []
        with open(file) as f:
             config = json.load(f)
        if file.stem == "confmap_config": # SLEAP
@@ -84,11 +85,11 @@ def get_bodyparts(self, file: Path) -> Union[None, List[str]]:
 
 def get_config_file(
     config_file_dir: Path,
-    config_file_names: List[str]=[
-        "confmap_config.json", # SLEAP (add others for other trackers to this list)
-    ],
-):
+    config_file_names: None | list[str] = None,
+) -> Path:
     """Returns the config file from a model's config directory."""
+    if config_file_names is None:
+        config_file_names = ["confmap_config.json"]
     config_file = None
     for f in config_file_names:
         if (config_file_dir / f).exists():
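For orientation, here is a minimal usage sketch of the updated reader and helper. It is not part of the commit: the node/job IDs, file name, and directory layout below are hypothetical examples, and the sketch assumes the module is importable as aeon.schema.social as in this repository.

from pathlib import Path

from aeon.schema.social import Pose, get_config_file

# Hypothetical Harp-binarized tracking file; its stem ('<hpcnode>_<jobid>_...')
# encodes the model directory under the processed-data root.
tracking_file = Path("/ceph/aeon/aeon/data/processed/hpc001_12345_000000.bin")

# `pattern` for this reader should typically be '<hpcnode>_<jobid>*'.
reader = Pose(pattern="hpc001_12345*")

# Returns a long-format DataFrame with one row per tracked bodypart:
# class, class_likelihood, part, x, y, part_likelihood.
pose_df = reader.read(tracking_file, ceph_proc_dir=Path("/ceph/aeon/aeon/data/processed"))

# With the new signature, get_config_file defaults to searching for
# "confmap_config.json" (SLEAP) when no explicit name list is passed.
config_file = get_config_file(Path("/ceph/aeon/aeon/data/processed/hpc001/12345"))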
1 change: 0 additions & 1 deletion pyproject.toml
@@ -130,7 +130,6 @@ reportAssertAlwaysTrue = "error"
 reportSelfClsParameterName = "error"
 reportUnusedExpression = "error"
 reportMatchNotExhaustive = "error"
-reportImplicitOverride = "error"
 reportShadowedImports = "error"
 # *Note*: we may want to set all 'ReportOptional*' rules to "none", but leaving 'em default for now
 venvPath = "."
