Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MOTA revisited #181

Merged
merged 57 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
73ea420
modify id switches
nikk-nikaznan Jun 5, 2024
768b81e
change the mota and test
nikk-nikaznan Jun 5, 2024
3036b06
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 5, 2024
f18a44c
change the variable
nikk-nikaznan Jun 14, 2024
6c0b0ed
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 18, 2024
a851e9f
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 21, 2024
611bd60
some bug fixed, load from checkpoint
nikk-nikaznan Jun 24, 2024
11bded7
Merge branch 'nikkna/id_switches' of github.com:SainsburyWellcomeCent…
nikk-nikaznan Jun 24, 2024
4b9a794
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 24, 2024
cfe67ca
change list type, add gt_ids
nikk-nikaznan Jun 25, 2024
7de7561
fixed the error id switches
nikk-nikaznan Jun 25, 2024
25f36c7
changes id switches
nikk-nikaznan Jun 25, 2024
de0b9cd
fixing some test and type hint
nikk-nikaznan Jun 27, 2024
5e56e1a
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 27, 2024
6aa1b63
fixing test, parametrize the test with additional test
nikk-nikaznan Jun 28, 2024
05faa18
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 28, 2024
a592c1a
cleane dup
nikk-nikaznan Jun 28, 2024
eefcf33
checking some test
nikk-nikaznan Jun 28, 2024
f0bcf65
rebase
nikk-nikaznan Jun 28, 2024
da0c62b
cleaned up
nikk-nikaznan Jun 28, 2024
01cb0f4
test works
nikk-nikaznan Jun 28, 2024
3ca1872
test works
nikk-nikaznan Jun 28, 2024
4d32f77
aded specific example
nikk-nikaznan Jun 28, 2024
ff059ea
some more test
nikk-nikaznan Jun 28, 2024
1195854
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 3, 2024
67fe295
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 3, 2024
9bf3a73
combine gt functions, fix test
nikk-nikaznan Jul 3, 2024
93915e1
rename test
nikk-nikaznan Jul 3, 2024
e1a8537
cleaned up linting
nikk-nikaznan Jul 3, 2024
d6401e1
adding some more description
nikk-nikaznan Jul 3, 2024
64ae8e8
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jul 3, 2024
73ae064
change the nested folder structure for output
nikk-nikaznan Jul 3, 2024
163cc06
adding device to cli
nikk-nikaznan Jul 4, 2024
cc6bbcf
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jul 5, 2024
f042b2b
attempt yesterday
nikk-nikaznan Jul 5, 2024
56ff81d
small changes in docstring
nikk-nikaznan Jul 5, 2024
73607b4
Update crabs/tracker/utils/io.py
nikk-nikaznan Jul 5, 2024
f8c91a9
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 5, 2024
21930ac
changes for gt dict
nikk-nikaznan Jul 5, 2024
0e8a687
predicted as dict
nikk-nikaznan Jul 5, 2024
6e12530
rename varibale, fix test
nikk-nikaznan Jul 5, 2024
36757a5
reviewing id switch
nikk-nikaznan Jul 5, 2024
5b5e1de
commented out the test that fail
nikk-nikaznan Jul 5, 2024
e8f4446
commented out the test that fail
nikk-nikaznan Jul 5, 2024
a464d15
seems working
nikk-nikaznan Jul 5, 2024
98e77c3
small modification for the test
nikk-nikaznan Jul 5, 2024
41ab8cc
cleaned up
nikk-nikaznan Jul 5, 2024
491ae68
cleaned up
nikk-nikaznan Jul 8, 2024
43e3c86
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jul 8, 2024
5e11991
Update crabs/tracker/track_video.py
nikk-nikaznan Jul 9, 2024
8b8b9c9
Update crabs/tracker/track_video.py
nikk-nikaznan Jul 9, 2024
cc1e8c6
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 9, 2024
fcf6e66
Update crabs/tracker/track_video.py
nikk-nikaznan Jul 9, 2024
7a2ba9f
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 9, 2024
fa7fa08
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 9, 2024
24935ae
fixed frame_number vs frame_idx
nikk-nikaznan Jul 9, 2024
1d45dd6
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crabs/detection_tracking/evaluate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def compute_confusion_matrix_elements(

def get_mlflow_parameters_from_ckpt(trained_model_path: str) -> dict:
"""Get MLflow client from ckpt path and associated params."""
import mlflow
from mlflow.tracking import MlflowClient

# roughly assert the format of the path is correct
# Note: to check if this is an MLflow chekcpoint,
Expand All @@ -129,7 +129,7 @@ def get_mlflow_parameters_from_ckpt(trained_model_path: str) -> dict:
ckpt_runID = Path(trained_model_path).parents[1].stem

# create an Mlflow client to interface with mlflow runs
mlrun_client = mlflow.tracking.MlflowClient(
mlrun_client = MlflowClient(
tracking_uri=ckpt_mlruns_path,
)

Expand Down
51 changes: 34 additions & 17 deletions crabs/detection_tracking/inference_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torchvision.transforms.v2 as transforms
from sort import Sort

from crabs.detection_tracking.models import FasterRCNN
from crabs.detection_tracking.tracking_utils import (
evaluate_mota,
get_ground_truth_data,
Expand All @@ -19,6 +20,8 @@
draw_bbox,
)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
nikk-nikaznan marked this conversation as resolved.
Show resolved Hide resolved


class DetectorInference:
"""
Expand Down Expand Up @@ -65,12 +68,11 @@ def load_trained_model(self) -> torch.nn.Module:
-------
torch.nn.Module
"""
model = torch.load(
self.args.model_dir,
map_location=torch.device(self.args.accelerator),
)
model.eval()
return model
# Get trained model
nikk-nikaznan marked this conversation as resolved.
Show resolved Hide resolved
trained_model = FasterRCNN.load_from_checkpoint(self.args.model_dir)
trained_model.eval()
trained_model.to(DEVICE)
return trained_model

def prep_sort(self, prediction: dict) -> np.ndarray:
"""
Expand Down Expand Up @@ -156,6 +158,7 @@ def prep_csv_writer(self) -> Tuple[Any, TextIO]:
def evaluate_tracking(
self,
gt_boxes_list: list,
gt_ids_list: list,
tracked_boxes_list: list,
iou_threshold: float,
) -> list[float]:
Expand All @@ -164,9 +167,11 @@ def evaluate_tracking(

Parameters
----------
gt_boxes_list : list[list[float]]
gt_boxes_list : list[float]
List of ground truth bounding boxes for each frame.
tracked_boxes_list : list[list[float]]
gt_id_list : list[float]
List of ground truth ID for each frame.
tracked_boxes_list : list[float]
List of tracked bounding boxes for each frame.
iou_threshold : float
The IoU threshold used to determine matches between ground truth and tracked boxes.
Expand All @@ -177,15 +182,21 @@ def evaluate_tracking(
The computed MOTA (Multi-Object Tracking Accuracy) score for the tracking performance.
"""
mota_values = []
prev_frame_ids: Optional[list[list[int]]] = None
# prev_frame_ids = None
for gt_boxes, tracked_boxes in zip(gt_boxes_list, tracked_boxes_list):
mota = evaluate_mota(
gt_boxes, tracked_boxes, iou_threshold, prev_frame_ids
# prev_frame_ids: Optional[list[int]] = None
prev_frame_id_map: Optional[dict] = None
for gt_boxes, gt_ids, tracked_boxes in zip(
gt_boxes_list, gt_ids_list, tracked_boxes_list
):
mota, prev_frame_id_map = evaluate_mota(
gt_boxes,
gt_ids,
tracked_boxes,
iou_threshold,
prev_frame_id_map,
)
mota_values.append(mota)
# Update previous frame IDs for the next iteration
prev_frame_ids = [[box[-1] for box in tracked_boxes]]
# # Update previous frame IDs for the next iteration
# prev_frame_ids = [box[-1] for box in tracked_boxes]

return mota_values

Expand Down Expand Up @@ -304,14 +315,20 @@ def run_inference(self):
self.prep_sort(prediction)
tracked_boxes = self.update_tracking(prediction)
self.save_required_output(tracked_boxes, frame, frame_number)
self.tracked_list.append(tracked_boxes)

# update frame
frame_number += 1

if self.args.gt_dir:
gt_boxes_list = get_ground_truth_data(self.args.gt_dir)
gt_boxes_list, gt_ids_list = get_ground_truth_data(
self.args.gt_dir
)
mota_values = self.evaluate_tracking(
gt_boxes_list, self.tracked_list, self.iou_threshold
gt_boxes_list,
gt_ids_list,
self.tracked_list,
self.iou_threshold,
)
overall_mota = np.mean(mota_values)
print("Overall MOTA:", overall_mota)
Expand Down
128 changes: 78 additions & 50 deletions crabs/detection_tracking/tracking_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple

import cv2
import numpy as np
Expand Down Expand Up @@ -50,46 +50,50 @@ def calculate_iou(box1: np.ndarray, box2: np.ndarray) -> float:


def count_identity_switches(
prev_frame_ids: Optional[list[list[int]]],
current_frame_ids: Optional[list[list[int]]],
prev_frame_id_map: Optional[Dict[int, int]],
current_frame_id_map: Dict[int, int],
) -> int:
"""
Count the number of identity switches between two sets of object IDs.

Parameters
----------
prev_frame_ids : Optional[list[list[int]]]
List of object IDs in the previous frame.
current_frame_ids : Optional[list[list[int]]]
List of object IDs in the current frame.
prev_frame_id_map : Optional[Dict[int, int]]
A dictionary mapping ground truth IDs to predicted IDs from the previous frame.
gt_to_tracked_map : Dict[int, int]
A dictionary mapping ground truth IDs to predicted IDs for the current frame.


Returns
-------
int
The number of identity switches between the two sets of object IDs.
"""

if prev_frame_ids is None or current_frame_ids is None:
if prev_frame_id_map is None:
return 0
nikk-nikaznan marked this conversation as resolved.
Show resolved Hide resolved

# Initialize count of identity switches
num_switches = 0

prev_ids = set(prev_frame_ids[0])
current_ids = set(current_frame_ids[0])
switch_count = 0

# Calculate the number of switches by finding the difference in IDs
num_switches = len(prev_ids.symmetric_difference(current_ids))
for current_gt_id, current_tracked_id in current_frame_id_map.items():
prev_tracked_id = prev_frame_id_map.get(current_gt_id)
nikk-nikaznan marked this conversation as resolved.
Show resolved Hide resolved
if prev_tracked_id is not None:
if prev_tracked_id != current_tracked_id:
switch_count += 1
else:
if current_tracked_id != current_gt_id:
switch_count += 1

return num_switches
return switch_count


def evaluate_mota(
gt_boxes: np.ndarray,
gt_ids: np.ndarray,
tracked_boxes: np.ndarray,
iou_threshold: float,
prev_frame_ids: Optional[list[list[int]]],
) -> float:
prev_frame_id_map: Optional[Dict[int, int]],
) -> Tuple[float, Dict[int, int]]:
"""
Evaluate MOTA (Multiple Object Tracking Accuracy).

Expand All @@ -99,17 +103,21 @@ def evaluate_mota(
----------
gt_boxes : np.ndarray
Ground truth bounding boxes of objects.
gt_ids : np.ndarray
Ground truth IDs corresponding to the bounding boxes.
tracked_boxes : np.ndarray
Tracked bounding boxes of objects.
iou_threshold : float
Intersection over Union (IoU) threshold for considering a match.
prev_frame_ids : Optional[list[list[int]]]
IDs from the previous frame for identity switch detection.
prev_frame_id_map : Optional[Dict[int, int]]
A dictionary mapping ground truth IDs to predicted IDs from the previous frame.

Returns
-------
float
The computed MOTA (Multi-Object Tracking Accuracy) score for the tracking performance.
Dict[int, int]
A dictionary mapping ground truth IDs to predicted IDs for the current frame.

Notes
-----
Expand All @@ -122,40 +130,41 @@ def evaluate_mota(
- Identity Switches: Instances where the tracking algorithm assigns a different ID to an object compared to its ID in the previous frame.
- Total Ground Truth: The total number of ground truth objects in the scene.

The MOTA score ranges from 0 to 1, with higher values indicating better tracking performance.
The MOTA score ranges from -inf to 1, with higher values indicating better tracking performance.
A MOTA score of 1 indicates perfect tracking, where there are no missed detections, false positives, or identity switches.
"""
total_gt = len(gt_boxes)
false_positive = 0
matched_gt_boxes = set()
gt_to_tracked_map = {}

for i, tracked_box in enumerate(tracked_boxes):
best_iou = 0.0
best_match = None

for j, gt_box in enumerate(gt_boxes):
iou = calculate_iou(gt_box[:4], tracked_box[:4])
if iou > iou_threshold and iou > best_iou:
best_iou = iou
best_match = j
if j not in matched_gt_boxes:
iou = calculate_iou(gt_box[:4], tracked_box[:4])
if iou > iou_threshold and iou > best_iou:
best_iou = iou
best_match = j

if best_match is not None:
# successfully found a matching ground truth box for the tracked box.
# set the corresponding ground truth box to None.
gt_boxes[best_match] = None
matched_gt_boxes.add(best_match)
# Map ground truth ID to tracked ID
gt_to_tracked_map[int(gt_ids[best_match])] = int(tracked_box[-1])
else:
false_positive += 1

missed_detections = 0
for box in gt_boxes:
if box is not None and not np.all(np.isnan(box)):
# if true ground truth box was not matched with any tracked box
missed_detections += 1

tracked_ids = [[box[-1] for box in tracked_boxes]]
missed_detections = total_gt - len(matched_gt_boxes)

num_switches = count_identity_switches(prev_frame_ids, tracked_ids)
num_switches = count_identity_switches(
prev_frame_id_map, gt_to_tracked_map
)

mota = 1 - (missed_detections + false_positive + num_switches) / total_gt
return mota
return mota, gt_to_tracked_map


def extract_bounding_box_info(row: list[str]) -> Dict[str, Any]:
Expand Down Expand Up @@ -194,8 +203,10 @@ def extract_bounding_box_info(row: list[str]) -> Dict[str, Any]:


def create_gt_list(
ground_truth_data: list[Dict[str, Any]], gt_boxes_list: list[np.ndarray]
) -> list[np.ndarray]:
ground_truth_data: list[Dict[str, Any]],
gt_boxes_list: list[np.ndarray],
gt_ids_list: list[np.ndarray],
) -> Tuple[list[np.ndarray], list[np.ndarray]]:
"""
Creates a list of ground truth bounding boxes organized by frame number.

Expand All @@ -208,8 +219,10 @@ def create_gt_list(

Returns
-------
list[np.ndarray]:
A list containing ground truth bounding boxes organized by frame number.
Tuple[List[np.ndarray], List[np.ndarray]]:
A tuple containing two lists:
- A list of numpy arrays with ground truth bounding box data organized by frame number.
- A list of numpy arrays with ground truth IDs organized by frame number.
"""
for data in ground_truth_data:
frame_number = data["frame_number"]
Expand All @@ -219,22 +232,30 @@ def create_gt_list(
data["y"],
data["x"] + data["width"],
data["y"] + data["height"],
data["id"],
],
dtype=np.float32,
)
track_id = np.array([data["id"]], dtype=np.float32)

if gt_boxes_list[frame_number].size == 0:
gt_boxes_list[frame_number] = bbox.reshape(
1, -1
) # Initialize as a 2D array
gt_ids_list[frame_number] = track_id
else:
gt_boxes_list[frame_number] = np.vstack(
[gt_boxes_list[frame_number], bbox]
)
return gt_boxes_list
gt_ids_list[frame_number] = np.hstack(
[gt_ids_list[frame_number], track_id]
)

return gt_boxes_list, gt_ids_list


def get_ground_truth_data(gt_dir: str) -> list[np.ndarray]:
def get_ground_truth_data(
gt_dir: str,
) -> Tuple[list[np.ndarray], list[np.ndarray]]:
"""
Extract ground truth bounding box data from a CSV file.

Expand All @@ -245,10 +266,12 @@ def get_ground_truth_data(gt_dir: str) -> list[np.ndarray]:

Returns
-------
list[np.ndarray]:
A list containing ground truth bounding box data organized by frame number.
The numpy array represent the coordinates and ID of the bounding box in the order:
x, y, x + width, y + height, ID
Tuple[List[np.ndarray], List[np.ndarray]]:
A tuple containing two lists:
- A list of numpy arrays with ground truth bounding box data organized by frame number.
Each numpy array represents the coordinates of the bounding boxes in the order:
x, y, x + width, y + height
- A list of numpy arrays with ground truth IDs organized by frame number.
"""
ground_truth_data = []
max_frame_number = 0
Expand All @@ -262,11 +285,16 @@ def get_ground_truth_data(gt_dir: str) -> list[np.ndarray]:
ground_truth_data.append(data)
max_frame_number = max(max_frame_number, data["frame_number"])

# Initialize a list to store the ground truth bounding boxes for each frame
# Initialize lists to store the ground truth bounding boxes and IDs for each frame
gt_boxes_list = [np.array([]) for _ in range(max_frame_number + 1)]
gt_ids_list = [np.array([]) for _ in range(max_frame_number + 1)]

# Populate the gt_boxes_list and gt_id_list
gt_boxes_list, gt_ids_list = create_gt_list(
ground_truth_data, gt_boxes_list, gt_ids_list
)

gt_boxes_list = create_gt_list(ground_truth_data, gt_boxes_list)
return gt_boxes_list
return gt_boxes_list, gt_ids_list


def write_tracked_bbox_to_csv(
Expand Down
Loading