diff --git a/conftest.py b/conftest.py
index f762c948..5296bee6 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,5 +1,6 @@
 """Pytest configuration file."""
 
 pytest_plugins = [
+    "tests.fixtures.integration",
     "tests.fixtures.frame_extraction",
 ]
diff --git a/crabs/tracker/evaluate_tracker.py b/crabs/tracker/evaluate_tracker.py
index 088658ef..ed5ac5a8 100644
--- a/crabs/tracker/evaluate_tracker.py
+++ b/crabs/tracker/evaluate_tracker.py
@@ -18,6 +18,7 @@ class TrackerEvaluate:
 
     def __init__(
         self,
+        input_video_file_root: str,
         gt_dir: str,  # annotations_file
         predicted_boxes_dict: dict,
         iou_threshold: float,
@@ -30,6 +31,8 @@ def __init__(
 
         Parameters
         ----------
+        input_video_file_root : str
+            Filename of the input video file, without the extension.
         gt_dir : str
             Directory path of the ground truth CSV file.
         predicted_boxes_dict : dict
@@ -45,6 +48,7 @@ def __init__(
             Path to the directory where the tracking output will be saved.
 
         """
+        self.input_video_file_root = input_video_file_root
        self.gt_dir = gt_dir
         self.predicted_boxes_dict = predicted_boxes_dict
         self.iou_threshold = iou_threshold
@@ -384,11 +388,16 @@ def evaluate_tracking(
             "MOTA": [],
         }
 
-        for frame_number in sorted(ground_truth_dict.keys()):
+        for frame_index, frame_number in enumerate(
+            sorted(ground_truth_dict.keys())
+        ):
+            # assuming all frames have GT data
             gt_data_frame = ground_truth_dict[frame_number]
 
-            if frame_number < len(predicted_dict):
-                pred_data_frame = predicted_dict[frame_number]
+            if frame_number <= len(predicted_dict):
+                pred_data_frame = predicted_dict[
+                    frame_index
+                ]  # 0-based indexing
 
                 (
                     mota,
@@ -405,7 +414,9 @@ def evaluate_tracking(
                     prev_frame_id_map,
                 )
                 mota_values.append(mota)
-                results["Frame Number"].append(frame_number)
+                results["Frame Number"].append(
+                    frame_number
+                )  # TODO: change to index!
                 results["Total Ground Truth"].append(total_gt)
                 results["True Positives"].append(true_positives)
                 results["Missed Detections"].append(missed_detections)
@@ -413,7 +424,9 @@
             results["Number of Switches"].append(num_switches)
             results["MOTA"].append(mota)
 
-        save_tracking_mota_metrics(self.tracking_output_dir, results)
+        save_tracking_mota_metrics(
+            self.tracking_output_dir, self.input_video_file_root, results
+        )
 
         return mota_values
diff --git a/crabs/tracker/track_video.py b/crabs/tracker/track_video.py
index 81ad2cef..54febcda 100644
--- a/crabs/tracker/track_video.py
+++ b/crabs/tracker/track_video.py
@@ -305,6 +305,7 @@ def detect_and_track_video(self) -> None:
         # Evaluate tracker if ground truth is passed
         if self.args.annotations_file:
             evaluation = TrackerEvaluate(
+                self.input_video_file_root,
                 self.args.annotations_file,
                 tracked_bboxes_dict,
                 self.config["iou_threshold"],
diff --git a/crabs/tracker/utils/tracking.py b/crabs/tracker/utils/tracking.py
index 514e65ad..3513c7e8 100644
--- a/crabs/tracker/utils/tracking.py
+++ b/crabs/tracker/utils/tracking.py
@@ -82,9 +82,12 @@ def extract_bounding_box_info(row: list[str]) -> dict[str, Any]:
 
 def save_tracking_mota_metrics(
     tracking_output_dir: Path,
+    input_video_file_root: str,
     track_results: dict[str, Any],
 ) -> None:
     """Save tracking metrics to a CSV file."""
     track_df = pd.DataFrame(track_results)
-    output_filename = f"{tracking_output_dir}/tracking_metrics_output.csv"
+    output_filename = (
+        f"{tracking_output_dir}/{input_video_file_root}_tracking_metrics.csv"
+    )
     track_df.to_csv(output_filename, index=False)
diff --git a/pyproject.toml b/pyproject.toml
index 300419d6..6e9d803d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,8 @@ dev = [
   "ruff",
   "setuptools_scm",
   "check-manifest",
+  "pooch",
+  "tqdm",
   # "codespell",
   # "pandas-stubs",
   # "types-attrs",
diff --git a/tests/fixtures/integration.py b/tests/fixtures/integration.py
new file mode 100644
index 00000000..9b951b91
--- /dev/null
+++ b/tests/fixtures/integration.py
@@ -0,0 +1,81 @@
+"""Pytest fixtures for integration tests."""
+
+from datetime import datetime
+from pathlib import Path
+
+import pooch
+import pytest
+
+GIN_TEST_DATA_REPO = "https://gin.g-node.org/SainsburyWellcomeCentre/crabs-exploration-test-data"
+
+
+# @pytest.fixture(autouse=True)
+# def mock_home_directory(monkeypatch: pytest.MonkeyPatch):
+#     """Monkeypatch pathlib.Path.home().
+
+#     Instead of returning the usual home path, the
+#     monkeypatched version returns the path to
+#     Path.home() / ".mock-home". This
+#     is to avoid local tests interfering with the
+#     potentially existing user data on the same machine.
+
+#     Parameters
+#     ----------
+#     monkeypatch : pytest.MonkeyPatch
+#         a monkeypatch fixture
+
+#     """
+#     # define mock home path
+#     home_path = Path.home()  # actual home path
+#     mock_home_path = home_path / ".mock-home"
+
+#     # create mock home directory if it doesn't exist
+#     if not mock_home_path.exists():
+#         mock_home_path.mkdir()
+
+#     # monkeypatch Path.home() to point to the mock home
+#     def mock_home():
+#         return mock_home_path
+
+#     monkeypatch.setattr(Path, "home", mock_home)
+
+
+@pytest.fixture(scope="session")
+def pooch_registry() -> pooch.Pooch:
+    """Pooch registry for the test data.
+
+    This fixture is shared across the whole test session. The
+    file registry is downloaded afresh for every test session.
+
+    Returns
+    -------
+    pooch.Pooch
+        Pooch registry pointing to the test data in the GIN repository
+
+    """
+    # Use pytest fixture? Should it be wiped out after a session?
+    # Initialise a pooch registry for the test data
+    registry = pooch.create(
+        Path.home() / ".crabs-exploration-test-data",
+        base_url=f"{GIN_TEST_DATA_REPO}/raw/master/test_data",
+    )
+
+    # Download the registry file from GIN to the pooch cache
+    # force to download it every time by using a timestamped file name
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    file_registry = pooch.retrieve(
+        url=f"{GIN_TEST_DATA_REPO}/raw/master/files-registry.txt",
+        known_hash=None,
+        fname=f"files-registry_{timestamp}.txt",
+        path=Path.home() / ".crabs-exploration-test-data",
+    )
+
+    # Load registry file onto pooch registry
+    registry.load_registry(
+        file_registry,
+    )
+
+    # Remove registry file
+    Path(file_registry).unlink()
+
+    return registry
diff --git a/tests/test_integration/test_detect_and_track.py b/tests/test_integration/test_detect_and_track.py
new file mode 100644
index 00000000..35e12da9
--- /dev/null
+++ b/tests/test_integration/test_detect_and_track.py
@@ -0,0 +1,166 @@
+import re
+import subprocess
+from pathlib import Path
+
+import cv2
+import pooch
+import pytest
+
+from crabs.tracker.utils.io import open_video
+
+
+@pytest.fixture()
+def input_data_paths(pooch_registry: pooch.Pooch):
+    """Fixture to get the input data for a detector+tracking run.
+
+    Returns
+    -------
+    dict
+        Dictionary with the paths to the input video, annotations,
+        config.
+
+    """
+    input_data_paths = {}
+    video_root_name = "04.09.2023-04-Right_RE_test_3_frames"
+    input_data_paths["video_root_name"] = video_root_name
+
+    # get trained model from pooch registry
+    list_files_ml_runs = pooch_registry.fetch(
+        "ml-runs.zip",
+        processor=pooch.Unzip(
+            extract_dir="",
+        ),
+        progressbar=True,
+    )
+    input_data_paths["ckpt"] = [
+        file for file in list_files_ml_runs if file.endswith("last.ckpt")
+    ][0]
+
+    # get input video, annotations and config
+    input_data_paths["video"] = pooch_registry.fetch(
+        f"{video_root_name}/{video_root_name}.mp4"
+    )
+    input_data_paths["annotations"] = pooch_registry.fetch(
+        f"{video_root_name}/{video_root_name}_ground_truth.csv"
+    )
+    input_data_paths["tracking_config"] = pooch_registry.fetch(
+        f"{video_root_name}/tracking_config.yaml"
+    )
+
+    return input_data_paths
+
+
+@pytest.mark.parametrize(
+    "flags_to_append",
+    [
+        [],
+        ["--save_video"],
+        ["--save_frames"],
+        ["--save_video", "--save_frames"],
+    ],
+)
+def test_detect_and_track_video(
+    input_data_paths: dict, tmp_path: Path, flags_to_append: list
+):
+    """Test the detect-and-track-video entry point.
+
+    Checks:
+    - status code of the command
+    - existence of csv file with predictions
+    - existence of csv file with tracking metrics
+    - existence of video file if requested
+    - existence of exported frames if requested
+    - MOTA score is as expected
+
+    """
+    # # get expected output
+    # path_to_tracked_boxes = pooch_registry.fetch(
+    #     f"{sample_video_dir}/04.09.2023-04-Right_RE_test_3_frames_tracks.csv"
+    # )
+    # path_to_tracking_metrics = pooch_registry.fetch(
+    #     f"{video_root_name}/tracking_metrics_output.csv"
+    # )
+
+    # run detect-and-track-video with the test data
+    main_command = [
+        "detect-and-track-video",
+        f"--trained_model_path={input_data_paths['ckpt']}",
+        f"--video_path={input_data_paths['video']}",
+        f"--config_file={input_data_paths['tracking_config']}",
+        f"--annotations_file={input_data_paths['annotations']}",
+        "--accelerator=cpu",
+        # f"--output_dir={tmp_path}",
+    ]
+    main_command.extend(flags_to_append)
+    completed_process = subprocess.run(
+        main_command,
+        check=True,
+        cwd=tmp_path,  # set cwd to pytest tmpdir if no output_dir is passed
+    )
+
+    # check the command runs successfully
+    assert completed_process.returncode == 0
+
+    # check the tracking output directory is created
+    pattern = re.compile(r"tracking_output_\d{8}_\d{6}")
+    list_subdirs = [x for x in tmp_path.iterdir() if x.is_dir()]
+    tracking_output_dir = list_subdirs[0]
+    assert len(list_subdirs) == 1
+    assert pattern.match(tracking_output_dir.stem)
+
+    # check csv with predictions exists
+    predictions_csv = (
+        tmp_path
+        / tracking_output_dir
+        / f"{input_data_paths['video_root_name']}_tracks.csv"
+    )
+    assert (predictions_csv).exists()
+
+    # check csv with tracking metrics exists
+    tracking_metrics_csv = (
+        tmp_path
+        / tracking_output_dir
+        / f"{input_data_paths['video_root_name']}_tracking_metrics.csv"
+    )
+    assert (tracking_metrics_csv).exists()
+
+    # check content of tracking metrics csv is as expected
+    # # read the csv
+    # tracking_metrics_df = pd.read_csv(tracking_metrics_csv)
+    # expected_tracking_metrics_df = pd.read_csv(path_to_tracking_metrics)
+
+    # # assert dataframes are the same
+    # pd.testing.assert_frame_equal(
+    #     tracking_metrics_df, expected_tracking_metrics_df
+    # )
+
+    # if the video is requested: check it exists
+    if "--save_video" in flags_to_append:
+        assert (
+            tmp_path
+            / tracking_output_dir
+            / f"{input_data_paths['video_root_name']}_tracks.mp4"
+        ).exists()
+
+    # if the frames are requested: check they exist
+    if "--save_frames" in flags_to_append:
+        input_video_object = open_video(input_data_paths["video"])
+        total_n_frames = int(input_video_object.get(cv2.CAP_PROP_FRAME_COUNT))
+
+        # check subdirectory exists
+        frames_subdir = (
+            tmp_path
+            / tracking_output_dir
+            / f"{input_data_paths['video_root_name']}_frames"
+        )
+        assert frames_subdir.exists()
+
+        # check files
+        pattern = re.compile(r"frame_\d{8}.png")
+        list_files = [x for x in frames_subdir.iterdir() if x.is_file()]
+
+        assert len(list_files) == total_n_frames
+        assert all(pattern.match(x.name) for x in list_files)
+
+    # check the MOTA score is as expected
+    # capture logs
+    # INFO:root:All 3 frames processed
+    # INFO:root:Overall MOTA: 0.860465
diff --git a/tests/test_unit/test_evaluate_tracker.py b/tests/test_unit/test_evaluate_tracker.py
index aa6162e6..e5a79d5c 100644
--- a/tests/test_unit/test_evaluate_tracker.py
+++ b/tests/test_unit/test_evaluate_tracker.py
@@ -10,7 +10,8 @@ def tracker_evaluate_interface():
     annotations_file_csv = Path(__file__).parents[1] / "data" / "gt_test.csv"
 
     return TrackerEvaluate(
-        annotations_file_csv,
+        input_video_file_root="/path/to/video.mp4",
+        gt_dir=annotations_file_csv,
         predicted_boxes_dict={},
         iou_threshold=0.1,
         tracking_output_dir="/path/output",
diff --git a/tests/test_unit/test_track_video.py b/tests/test_unit/test_track_video.py
index 3614d6bf..21a36d66 100644
--- a/tests/test_unit/test_track_video.py
+++ b/tests/test_unit/test_track_video.py
@@ -16,8 +16,9 @@ def mock_args():
         config_file="/path/to/config.yaml",
         video_path="/path/to/video.mp4",
         trained_model_path="path/to/model.ckpt",
-        output_dir=tmp_dir,
         accelerator="gpu",
+        output_dir=tmp_dir,
+        output_dir_no_timestamp=None,
         annotations_file=None,
         save_video=None,
         save_frames=None,
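
Illustrative sketch (not part of the patch, based only on the signatures shown in the diff above): how the updated TrackerEvaluate interface is constructed and where the per-video metrics CSV now ends up. All paths, the video root name, and the empty predictions dict below are placeholder values.

from crabs.tracker.evaluate_tracker import TrackerEvaluate

# The evaluator now receives the input video file root (filename without
# extension) as its first argument and forwards it to
# save_tracking_mota_metrics() so the metrics CSV is named per video.
evaluation = TrackerEvaluate(
    input_video_file_root="04.09.2023-04-Right_RE_test_3_frames",
    gt_dir="path/to/ground_truth.csv",  # placeholder ground-truth CSV
    predicted_boxes_dict={},  # placeholder: per-frame predictions from the tracker
    iou_threshold=0.1,
    tracking_output_dir="path/to/tracking_output",  # placeholder output directory
)

# Tracking metrics are now written to
#   <tracking_output_dir>/<input_video_file_root>_tracking_metrics.csv
# e.g. path/to/tracking_output/04.09.2023-04-Right_RE_test_3_frames_tracking_metrics.csv
# instead of the previous fixed tracking_metrics_output.csv.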