Fix tracking evaluation #243

Draft · wants to merge 7 commits into main
1 change: 1 addition & 0 deletions conftest.py
@@ -1,5 +1,6 @@
 """Pytest configuration file."""
 
 pytest_plugins = [
+    "tests.fixtures.integration",
     "tests.fixtures.frame_extraction",
 ]
23 changes: 18 additions & 5 deletions crabs/tracker/evaluate_tracker.py
@@ -18,6 +18,7 @@ class TrackerEvaluate:
 
     def __init__(
         self,
+        input_video_file_root: str,
         gt_dir: str,  # annotations_file
         predicted_boxes_dict: dict,
         iou_threshold: float,
@@ -30,6 +31,8 @@ def __init__(
 
         Parameters
         ----------
+        input_video_file_root : str
+            Filename of the input video file, without the extension.
         gt_dir : str
             Directory path of the ground truth CSV file.
         predicted_boxes_dict : dict
@@ -45,6 +48,7 @@
             Path to the directory where the tracking output will be saved.
 
         """
+        self.input_video_file_root = input_video_file_root
         self.gt_dir = gt_dir
         self.predicted_boxes_dict = predicted_boxes_dict
         self.iou_threshold = iou_threshold
@@ -384,11 +388,16 @@ def evaluate_tracking(
             "MOTA": [],
         }
 
-        for frame_number in sorted(ground_truth_dict.keys()):
+        for frame_index, frame_number in enumerate(
+            sorted(ground_truth_dict.keys())
+        ):
+            # assuming all frames have GT data
             gt_data_frame = ground_truth_dict[frame_number]
 
-            if frame_number < len(predicted_dict):
-                pred_data_frame = predicted_dict[frame_number]
+            if frame_number <= len(predicted_dict):
+                pred_data_frame = predicted_dict[
+                    frame_index
+                ]  # 0-based indexing
 
                 (
                     mota,
@@ -405,15 +414,19 @@
                     prev_frame_id_map,
                 )
                 mota_values.append(mota)
-                results["Frame Number"].append(frame_number)
+                results["Frame Number"].append(
+                    frame_number
+                )  # TODO: change to index!
                 results["Total Ground Truth"].append(total_gt)
                 results["True Positives"].append(true_positives)
                 results["Missed Detections"].append(missed_detections)
                 results["False Positives"].append(false_positives)
                 results["Number of Switches"].append(num_switches)
                 results["MOTA"].append(mota)
 
-        save_tracking_mota_metrics(self.tracking_output_dir, results)
+        save_tracking_mota_metrics(
+            self.tracking_output_dir, self.input_video_file_root, results
+        )
 
         return mota_values
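A note on the hunk above: `frame_number` values come from the ground-truth CSV while `frame_index` is a 0-based loop position, and the guard was relaxed from `<` to `<=` so the last frame is no longer dropped. A minimal sketch of that off-by-one, assuming 1-based ground-truth frame numbers and one prediction per frame (both assumptions, not confirmed by this diff):

```python
# Hypothetical 3-frame example of the off-by-one addressed above.
ground_truth_dict = {1: "gt_1", 2: "gt_2", 3: "gt_3"}  # 1-based frame numbers
predicted_dict = {0: "pred_1", 1: "pred_2", 2: "pred_3"}  # 0-based positions

for frame_index, frame_number in enumerate(sorted(ground_truth_dict)):
    # Old guard: `frame_number < len(predicted_dict)` is False for
    # frame 3 (3 < 3), so the last frame was silently skipped.
    # New guard: `<=` admits it, and predictions are read with the
    # 0-based frame_index rather than the 1-based frame_number.
    if frame_number <= len(predicted_dict):
        assert predicted_dict[frame_index] == f"pred_{frame_number}"
```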
1 change: 1 addition & 0 deletions crabs/tracker/track_video.py
@@ -305,6 +305,7 @@ def detect_and_track_video(self) -> None:
         # Evaluate tracker if ground truth is passed
         if self.args.annotations_file:
             evaluation = TrackerEvaluate(
+                self.input_video_file_root,
                 self.args.annotations_file,
                 tracked_bboxes_dict,
                 self.config["iou_threshold"],
5 changes: 4 additions & 1 deletion crabs/tracker/utils/tracking.py
@@ -82,9 +82,12 @@ def extract_bounding_box_info(row: list[str]) -> dict[str, Any]:
 
 def save_tracking_mota_metrics(
     tracking_output_dir: Path,
+    input_video_file_root: str,
     track_results: dict[str, Any],
 ) -> None:
     """Save tracking metrics to a CSV file."""
     track_df = pd.DataFrame(track_results)
-    output_filename = f"{tracking_output_dir}/tracking_metrics_output.csv"
+    output_filename = (
+        f"{tracking_output_dir}/{input_video_file_root}_tracking_metrics.csv"
+    )
     track_df.to_csv(output_filename, index=False)
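Since `tracking_output_dir` is already a `Path`, the same filename could be built with the `/` operator rather than an f-string over the whole path; a small alternative sketch (a suggestion, not what this diff does):

```python
# Equivalent spelling with the Path API; to_csv accepts a Path directly.
output_filename = (
    tracking_output_dir / f"{input_video_file_root}_tracking_metrics.csv"
)
track_df.to_csv(output_filename, index=False)
```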
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -46,6 +46,8 @@ dev = [
   "ruff",
   "setuptools_scm",
   "check-manifest",
+  "pooch",
+  "tqdm",
   # "codespell",
   # "pandas-stubs",
   # "types-attrs",
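Both additions support the new integration tests: `pooch` drives the GIN test-data downloads in the fixtures below, and `tqdm` is the backend `pooch` needs when `fetch(..., progressbar=True)` is used.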
81 changes: 81 additions & 0 deletions tests/fixtures/integration.py
@@ -0,0 +1,81 @@
"""Pytest fixtures for integration tests."""

from datetime import datetime
from pathlib import Path

import pooch
import pytest

GIN_TEST_DATA_REPO = "https://gin.g-node.org/SainsburyWellcomeCentre/crabs-exploration-test-data"


# @pytest.fixture(autouse=True)
# def mock_home_directory(monkeypatch: pytest.MonkeyPatch):
# """Monkeypatch pathlib.Path.home().

# Instead of returning the usual home path, the
# monkeypatched version returns the path to
# Path.home() / ".mock-home". This
# is to avoid local tests interfering with the
# potentially existing user data on the same machine.

# Parameters
# ----------
# monkeypatch : pytest.MonkeyPatch
# a monkeypatch fixture

# """
# # define mock home path
# home_path = Path.home() # actual home path
# mock_home_path = home_path / ".mock-home"

# # create mock home directory if it doesn't exist
# if not mock_home_path.exists():
# mock_home_path.mkdir()

# # monkeypatch Path.home() to point to the mock home
# def mock_home():
# return mock_home_path

# monkeypatch.setattr(Path, "home", mock_home)


@pytest.fixture(scope="session")
def pooch_registry() -> pooch.Pooch:
    """Pooch registry for the test data.

    This fixture is shared across the whole test session. The
    file registry is downloaded fresh at the start of every session.

    Returns
    -------
    pooch.Pooch
        Pooch registry pointing at the GIN repository with the test data.

    """
# Use pytest fixture? Should it be wiped out after a session?
# Initialise a pooch registry for the test data
registry = pooch.create(
Path.home() / ".crabs-exploration-test-data",
base_url=f"{GIN_TEST_DATA_REPO}/raw/master/test_data",
)

    # Download the registry file from GIN to the pooch cache;
    # a timestamped filename forces a fresh download every time
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
file_registry = pooch.retrieve(
url=f"{GIN_TEST_DATA_REPO}/raw/master/files-registry.txt",
known_hash=None,
fname=f"files-registry_{timestamp}.txt",
path=Path.home() / ".crabs-exploration-test-data",
)

# Load registry file onto pooch registry
registry.load_registry(
file_registry,
)

# Remove registry file
Path(file_registry).unlink()

return registry
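For orientation, a test consumes this session-scoped fixture by declaring it as an argument and calling `fetch` on it; a minimal sketch with a hypothetical registry entry (the real usage is in the new integration test below):

```python
from pathlib import Path

import pooch


def test_example(pooch_registry: pooch.Pooch) -> None:
    # First call downloads the file into the pooch cache; later calls
    # reuse it. "some_dir/some_file.mp4" is a hypothetical entry.
    local_path = pooch_registry.fetch("some_dir/some_file.mp4")
    assert Path(local_path).exists()
```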
166 changes: 166 additions & 0 deletions tests/test_integration/test_detect_and_track.py
@@ -0,0 +1,166 @@
import re
import subprocess
from pathlib import Path

import cv2
import pooch
import pytest

from crabs.tracker.utils.io import open_video


@pytest.fixture()
def input_data_paths(pooch_registry: pooch.Pooch):
"""Fixture to get the input data for a detector+tracking run.

Returns
-------
    dict
        Dictionary with the video root name and the paths to the input
        video, annotations, tracking config, and trained model checkpoint.

"""
input_data_paths = {}
video_root_name = "04.09.2023-04-Right_RE_test_3_frames"
input_data_paths["video_root_name"] = video_root_name

# get trained model from pooch registry
list_files_ml_runs = pooch_registry.fetch(
"ml-runs.zip",
processor=pooch.Unzip(
extract_dir="",
),
progressbar=True,
)
input_data_paths["ckpt"] = [
file for file in list_files_ml_runs if file.endswith("last.ckpt")
][0]

# get input video, annotations and config
input_data_paths["video"] = pooch_registry.fetch(
f"{video_root_name}/{video_root_name}.mp4"
)
input_data_paths["annotations"] = pooch_registry.fetch(
f"{video_root_name}/{video_root_name}_ground_truth.csv"
)
input_data_paths["tracking_config"] = pooch_registry.fetch(
f"{video_root_name}/tracking_config.yaml"
)

return input_data_paths


@pytest.mark.parametrize(
"flags_to_append",
[
[],
["--save_video"],
["--save_frames"],
["--save_video", "--save_frames"],
],
)
def test_detect_and_track_video(
input_data_paths: dict, tmp_path: Path, flags_to_append: list
):
"""Test the detect-and-track-video entry point.

Checks:
- status code of the command
- existence of csv file with predictions
- existence of csv file with tracking metrics
- existence of video file if requested
- existence of exported frames if requested
- MOTA score is as expected

"""
# # get expected output
# path_to_tracked_boxes = pooch_registry.fetch(
# f"{sample_video_dir}/04.09.2023-04-Right_RE_test_3_frames_tracks.csv"
# )
# path_to_tracking_metrics = pooch_registry.fetch(
# f"{video_root_name}/tracking_metrics_output.csv"
# )

# run detect-and-track-video with the test data
main_command = [
"detect-and-track-video",
f"--trained_model_path={input_data_paths['ckpt']}",
f"--video_path={input_data_paths['video']}",
f"--config_file={input_data_paths['tracking_config']}",
f"--annotations_file={input_data_paths['annotations']}",
"--accelerator=cpu",
# f"--output_dir={tmp_path}",
]
main_command.extend(flags_to_append)
completed_process = subprocess.run(
main_command,
check=True,
cwd=tmp_path, # set cwd to pytest tmpdir if no output_dir is passed
)

# check the command runs successfully
assert completed_process.returncode == 0

# check the tracking output directory is created
pattern = re.compile(r"tracking_output_\d{8}_\d{6}")
    list_subdirs = [x for x in tmp_path.iterdir() if x.is_dir()]
    assert len(list_subdirs) == 1
    tracking_output_dir = list_subdirs[0]
    assert pattern.match(tracking_output_dir.stem)

# check csv with predictions exists
predictions_csv = (
tmp_path
/ tracking_output_dir
/ f"{input_data_paths['video_root_name']}_tracks.csv"
)
assert (predictions_csv).exists()

# check csv with tracking metrics exists
    tracking_metrics_csv = (
        tmp_path
        / tracking_output_dir
        / f"{input_data_paths['video_root_name']}_tracking_metrics.csv"
    )
assert (tracking_metrics_csv).exists()

# check content of tracking metrics csv is as expected
# # read the csv
# tracking_metrics_df = pd.read_csv(tracking_metrics_csv)
# expected_tracking_metrics_df = pd.read_csv(path_to_tracking_metrics)

# # assert dataframes are the same
# pd.testing.assert_frame_equal(
# tracking_metrics_df, expected_tracking_metrics_df
# )

# if the video is requested: check it exists
if "--save_video" in flags_to_append:
assert (
tmp_path
/ tracking_output_dir
/ f"{input_data_paths['video_root_name']}_tracks.mp4"
).exists()

# if the frames are requested: check they exist
if "--save_frames" in flags_to_append:
input_video_object = open_video(input_data_paths["video"])
total_n_frames = int(input_video_object.get(cv2.CAP_PROP_FRAME_COUNT))

# check subdirectory exists
frames_subdir = (
tmp_path
/ tracking_output_dir
/ f"{input_data_paths['video_root_name']}_frames"
)
assert frames_subdir.exists()

# check files
        pattern = re.compile(r"frame_\d{8}\.png")
list_files = [x for x in frames_subdir.iterdir() if x.is_file()]

assert len(list_files) == total_n_frames
assert all(pattern.match(x.name) for x in list_files)

# check the MOTA score is as expected
# capture logs
# INFO:root:All 3 frames processed
# INFO:root:Overall MOTA: 0.860465
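One possible way to implement the pending MOTA check sketched in the comments above is to capture the command's output in the existing `subprocess.run` call and parse the logged score. This assumes the CLI emits `INFO:root:Overall MOTA: <value>` through `logging`, which writes to stderr by default (an assumption based only on the sample log lines above):

```python
    # Hypothetical MOTA check, assuming the CLI logs
    # "Overall MOTA: <value>" via the logging module (stderr).
    completed_process = subprocess.run(
        main_command,
        check=True,
        cwd=tmp_path,
        capture_output=True,
        text=True,
    )
    match = re.search(
        r"Overall MOTA:\s*([0-9.]+)",
        completed_process.stderr + completed_process.stdout,
    )
    assert match is not None
    assert float(match.group(1)) == pytest.approx(0.860465)
```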
3 changes: 2 additions & 1 deletion tests/test_unit/test_evaluate_tracker.py
@@ -10,7 +10,8 @@
 def tracker_evaluate_interface():
     annotations_file_csv = Path(__file__).parents[1] / "data" / "gt_test.csv"
     return TrackerEvaluate(
-        annotations_file_csv,
+        input_video_file_root="/path/to/video.mp4",
+        gt_dir=annotations_file_csv,
         predicted_boxes_dict={},
         iou_threshold=0.1,
         tracking_output_dir="/path/output",
3 changes: 2 additions & 1 deletion tests/test_unit/test_track_video.py
@@ -16,8 +16,9 @@ def mock_args():
         config_file="/path/to/config.yaml",
         video_path="/path/to/video.mp4",
         trained_model_path="path/to/model.ckpt",
-        output_dir=tmp_dir,
         accelerator="gpu",
+        output_dir=tmp_dir,
+        output_dir_no_timestamp=None,
         annotations_file=None,
         save_video=None,
         save_frames=None,