Refactored detection+tracking (#241)
* Refactored detection+tracking. Same MOTA on gt clip as reference (but tracker eval needs review separately)

* Adapt tests

* Fix test format SORT

* Factor out video opening

* Make predictions dict and tracked predictions dict more consistent to avoid reformatting (see the format sketch after this list)

* Fix tests
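
A minimal sketch of the consolidated predictions dict format, as documented in the evaluator's constructor below (keys taken from the diff; values are illustrative):

import numpy as np

# One entry per frame index; shapes follow the constructor docstring.
predicted_boxes_dict = {
    0: {
        "tracked_boxes": np.array([[10.0, 20.0, 50.0, 60.0]]),  # (n, 4): xmin, ymin, xmax, ymax
        "ids": np.array([1.0]),  # one tracked ID per box
        "scores": np.array([0.9]),  # one detection score per box
    },
}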
sfmig authored Nov 8, 2024
1 parent 9b4722e commit e371bf9
Showing 8 changed files with 832 additions and 541 deletions.
117 changes: 72 additions & 45 deletions crabs/tracker/evaluate_tracker.py
@@ -2,21 +2,26 @@

import csv
import logging
from pathlib import Path
from typing import Any, Optional

import numpy as np

from crabs.tracker.utils.tracking import extract_bounding_box_info
from crabs.tracker.utils.tracking import (
extract_bounding_box_info,
save_tracking_mota_metrics,
)


class TrackerEvaluate:
"""Interface to evaluate tracker."""

def __init__(
self,
gt_dir: str,
predicted_boxes_id: list[np.ndarray],
gt_dir: str, # annotations_file
predicted_boxes_dict: dict,
iou_threshold: float,
tracking_output_dir: Path,
):
"""Initialize the TrackerEvaluate class.
@@ -27,47 +32,25 @@ def __init__(
----------
gt_dir : str
Path to the ground truth CSV file.
predicted_boxes_id : list[np.ndarray]
List of numpy arrays containing predicted bounding boxes and IDs.
predicted_boxes_dict : dict
Dictionary mapping frame indices to arrays of bounding boxes
(under "tracked_boxes"), IDs (under "ids") and detection scores
(under "scores"). The bounding boxes array has shape (n, 4), where
n is the number of boxes in the frame and the 4 columns are (xmin,
ymin, xmax, ymax).
iou_threshold : float
Intersection over Union (IoU) threshold for evaluating
tracking performance.
tracking_output_dir : Path
Path to the directory where the tracking output will be saved.
"""
self.gt_dir = gt_dir
self.predicted_boxes_id = predicted_boxes_id
self.predicted_boxes_dict = predicted_boxes_dict
self.iou_threshold = iou_threshold
self.tracking_output_dir = tracking_output_dir
self.last_known_predicted_ids: dict = {}

def get_predicted_data(self) -> dict[int, dict[str, Any]]:
"""Format predicted bounding box and ID as dictionary.
Dictionary keys are frame numbers.
Returns
-------
dict[int, dict[str, Any]]:
A dictionary where the key is the frame number and the value is
another dictionary containing:
- 'bbox': A numpy array with shape (N, 4) containing coordinates
of the bounding boxes [x, y, x + width, y + height] for every
object in the frame.
- 'id': A numpy array containing the IDs of the tracked objects.
"""
predicted_dict: dict[int, dict[str, Any]] = {}

for frame_number, frame_data in enumerate(self.predicted_boxes_id):
if frame_data.size == 0:
continue

bboxes = frame_data[:, :4]
ids = frame_data[:, 4]

predicted_dict[frame_number] = {"bbox": bboxes, "id": ids}

return predicted_dict

def get_ground_truth_data(self) -> dict[int, dict[str, Any]]:
"""Fromat ground truth bounding box data as dict with key frame number.
@@ -82,6 +65,8 @@ def get_ground_truth_data(self) -> dict[int, dict[str, Any]]:
- 'id': The ground truth ID
"""
# TODO: refactor with pandas

with open(self.gt_dir) as csvfile:
csvreader = csv.reader(csvfile)
next(csvreader) # Skip the header row
@@ -91,7 +76,10 @@ def get_ground_truth_data(self) -> dict[int, dict[str, Any]]:

# Format as a dictionary with key = frame number
ground_truth_dict: dict = {}

# Loop through the annotations
for data in ground_truth_data:
# Get frame, bbox, id
frame_number = data["frame_number"]
bbox = np.array(
[
@@ -104,9 +92,11 @@
)
track_id = int(float(data["id"]))

# If frame does not exist in dict: initialise
if frame_number not in ground_truth_dict:
ground_truth_dict[frame_number] = {"bbox": [], "id": []}

# Append bbox and id to the dictionary
ground_truth_dict[frame_number]["bbox"].append(bbox)
ground_truth_dict[frame_number]["id"].append(track_id)
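# Illustrative shape of the result, mirroring the docstring above
# (values assumed; this comment is not part of the commit):
# ground_truth_dict = {
#     0: {"bbox": [np.array([x_min, y_min, x_max, y_max]), ...],
#         "id": [1, ...]},
# }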

@@ -264,13 +254,13 @@ def count_identity_switches(  # noqa: C901

return switch_counter

def evaluate_mota(
def compute_mota_one_frame(
self,
gt_data: dict[str, np.ndarray],
pred_data: dict[str, np.ndarray],
iou_threshold: float,
gt_to_tracked_id_previous_frame: Optional[dict[int, int]],
) -> tuple[float, dict[int, int]]:
) -> tuple[float, int, int, int, int, int, dict[int, int]]:
"""Evaluate MOTA (Multiple Object Tracking Accuracy).
Parameters
@@ -301,11 +291,12 @@
"""
total_gt = len(gt_data["bbox"])
false_positive = 0
true_positive = 0
indices_of_matched_gt_boxes = set()
gt_to_tracked_id_current_frame = {}

pred_boxes = pred_data["bbox"]
pred_ids = pred_data["id"]
pred_boxes = pred_data["tracked_boxes"]
pred_ids = pred_data["ids"]

gt_boxes = gt_data["bbox"]
gt_ids = gt_data["id"]
@@ -325,6 +316,7 @@
index_gt_not_match = j

if index_gt_best_match is not None:
true_positive += 1
# Successfully found a matching ground truth box for the
# tracked box.
indices_of_matched_gt_boxes.add(index_gt_best_match)
@@ -347,8 +339,15 @@
mota = (
1 - (missed_detections + false_positive + num_switches) / total_gt
)

return mota, gt_to_tracked_id_current_frame
return (
mota,
true_positive,
missed_detections,
false_positive,
num_switches,
total_gt,
gt_to_tracked_id_current_frame,
)
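
# Worked example for the MOTA formula above (illustrative numbers, not
# part of the commit): with total_gt = 10, missed_detections = 2,
# false_positive = 1 and num_switches = 1,
# MOTA = 1 - (2 + 1 + 1) / 10 = 0.6.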

def evaluate_tracking(
self,
@@ -364,7 +363,7 @@
frame, organized by frame number.
predicted_dict : dict
Dictionary containing predicted bounding boxes and IDs for each
frame, organized by frame number.
frame, organized by frame _index_.
Returns
-------
@@ -375,27 +374,55 @@
"""
mota_values = []
prev_frame_id_map: Optional[dict] = None
results: dict[str, Any] = {
"Frame Number": [],
"Total Ground Truth": [],
"True Positives": [],
"Missed Detections": [],
"False Positives": [],
"Number of Switches": [],
"MOTA": [],
}

for frame_number in sorted(ground_truth_dict.keys()):
gt_data_frame = ground_truth_dict[frame_number]

if frame_number < len(predicted_dict):
pred_data_frame = predicted_dict[frame_number]
mota, prev_frame_id_map = self.evaluate_mota(

(
mota,
true_positives,
missed_detections,
false_positives,
num_switches,
total_gt,
prev_frame_id_map,
) = self.compute_mota_one_frame(
gt_data_frame,
pred_data_frame,
self.iou_threshold,
prev_frame_id_map,
)
mota_values.append(mota)
results["Frame Number"].append(frame_number)
results["Total Ground Truth"].append(total_gt)
results["True Positives"].append(true_positives)
results["Missed Detections"].append(missed_detections)
results["False Positives"].append(false_positives)
results["Number of Switches"].append(num_switches)
results["MOTA"].append(mota)

save_tracking_mota_metrics(self.tracking_output_dir, results)

return mota_values

def run_evaluation(self) -> None:
"""Run evaluation of tracking based on tracking ground truth."""
predicted_dict = self.get_predicted_data()
ground_truth_dict = self.get_ground_truth_data()
mota_values = self.evaluate_tracking(ground_truth_dict, predicted_dict)
mota_values = self.evaluate_tracking(
ground_truth_dict, self.predicted_boxes_dict
)

overall_mota = np.mean(mota_values)
logging.info("Overall MOTA: %f" % overall_mota) # noqa: UP031
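
For context, a minimal usage sketch of the refactored evaluator (not part of the commit), assuming a ground truth CSV in the format read by extract_bounding_box_info and the predictions dict sketched above; paths and values are illustrative:

from pathlib import Path

import numpy as np

from crabs.tracker.evaluate_tracker import TrackerEvaluate

predicted_boxes_dict = {
    0: {
        "tracked_boxes": np.array([[10.0, 20.0, 50.0, 60.0]]),
        "ids": np.array([1.0]),
        "scores": np.array([0.9]),
    },
}

evaluator = TrackerEvaluate(
    gt_dir="path/to/gt_annotations.csv",  # illustrative path
    predicted_boxes_dict=predicted_boxes_dict,
    iou_threshold=0.1,  # illustrative threshold
    tracking_output_dir=Path("tracking_output"),  # illustrative directory
)

# Computes per-frame MOTA against the ground truth, saves the metrics
# via save_tracking_mota_metrics, and logs the overall (mean) MOTA.
evaluator.run_evaluation()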