-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Saving mota output #180
base: main
Are you sure you want to change the base?
Saving mota output #180
Changes from all commits
bb148b6
2ed0673
1592236
bdf477f
fb87dec
409ef4a
5e910b4
51c7459
4ec7825
2e25910
0d36020
29da996
d6291d1
c8b033f
f54fdf0
1a140cf
444c915
90d7376
a90f9c4
44d8062
95b06de
5595135
829f7a2
c3dce1f
10f9512
dec2a03
64d583c
0aa5040
7d80256
ca779ad
1bf8735
a960fa0
aaf0c48
a1cf6d3
95d3d47
e74ff93
92cddc9
167f79b
3dad25e
11ca37f
305a599
028de49
1e3ef67
a8f3942
68a8c2b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,10 +1,14 @@ | ||||||
import csv | ||||||
import logging | ||||||
from pathlib import Path | ||||||
from typing import Any, Dict, Optional, Tuple | ||||||
|
||||||
import numpy as np | ||||||
|
||||||
from crabs.tracker.utils.tracking import extract_bounding_box_info | ||||||
from crabs.tracker.utils.tracking import ( | ||||||
extract_bounding_box_info, | ||||||
save_tracking_mota_metrics, | ||||||
) | ||||||
|
||||||
|
||||||
class TrackerEvaluate: | ||||||
|
@@ -13,6 +17,7 @@ def __init__( | |||||
gt_dir: str, | ||||||
predicted_boxes_id: list[np.ndarray], | ||||||
iou_threshold: float, | ||||||
tracking_output_dir: Path, | ||||||
): | ||||||
""" | ||||||
Initialize the TrackerEvaluate class with ground truth directory, tracked list, and IoU threshold. | ||||||
|
@@ -32,6 +37,7 @@ def __init__( | |||||
self.gt_dir = gt_dir | ||||||
self.predicted_boxes_id = predicted_boxes_id | ||||||
self.iou_threshold = iou_threshold | ||||||
self.tracking_output_dir = tracking_output_dir | ||||||
|
||||||
def get_predicted_data(self) -> Dict[int, Dict[str, Any]]: | ||||||
""" | ||||||
|
@@ -226,7 +232,7 @@ def evaluate_mota( | |||||
pred_data: Dict[str, np.ndarray], | ||||||
iou_threshold: float, | ||||||
gt_to_tracked_id_previous_frame: Optional[Dict[int, int]], | ||||||
) -> Tuple[float, Dict[int, int]]: | ||||||
) -> Tuple[float, int, int, int, int, int, Dict[int, int]]: | ||||||
""" | ||||||
Evaluate MOTA (Multiple Object Tracking Accuracy). | ||||||
|
||||||
|
@@ -254,6 +260,7 @@ def evaluate_mota( | |||||
""" | ||||||
total_gt = len(gt_data["bbox"]) | ||||||
false_positive = 0 | ||||||
true_positive = 0 | ||||||
indices_of_matched_gt_boxes = set() | ||||||
gt_to_tracked_id_current_frame = {} | ||||||
|
||||||
|
@@ -278,6 +285,7 @@ def evaluate_mota( | |||||
index_gt_not_match = j | ||||||
|
||||||
if index_gt_best_match is not None: | ||||||
true_positive += 1 | ||||||
# Successfully found a matching ground truth box for the tracked box. | ||||||
indices_of_matched_gt_boxes.add(index_gt_best_match) | ||||||
# Map ground truth ID to tracked ID | ||||||
|
@@ -299,7 +307,15 @@ def evaluate_mota( | |||||
mota = ( | ||||||
1 - (missed_detections + false_positive + num_switches) / total_gt | ||||||
) | ||||||
return mota, gt_to_tracked_id_current_frame | ||||||
return ( | ||||||
mota, | ||||||
true_positive, | ||||||
missed_detections, | ||||||
false_positive, | ||||||
num_switches, | ||||||
total_gt, | ||||||
gt_to_tracked_id_current_frame, | ||||||
) | ||||||
|
||||||
def evaluate_tracking( | ||||||
self, | ||||||
|
@@ -323,19 +339,46 @@ def evaluate_tracking( | |||||
""" | ||||||
mota_values = [] | ||||||
prev_frame_id_map: Optional[dict] = None | ||||||
results: dict[str, Any] = { | ||||||
"Frame Number": [], | ||||||
"Total Ground Truth": [], | ||||||
"True Positives": [], | ||||||
"Missed Detections": [], | ||||||
"False Positives": [], | ||||||
"Number of Switches": [], | ||||||
"Mota": [], | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
} | ||||||
|
||||||
for frame_number in sorted(ground_truth_dict.keys()): | ||||||
gt_data_frame = ground_truth_dict[frame_number] | ||||||
|
||||||
if frame_number < len(predicted_dict): | ||||||
pred_data_frame = predicted_dict[frame_number] | ||||||
mota, prev_frame_id_map = self.evaluate_mota( | ||||||
|
||||||
( | ||||||
mota, | ||||||
true_positives, | ||||||
missed_detections, | ||||||
false_positives, | ||||||
num_switches, | ||||||
total_gt, | ||||||
prev_frame_id_map, | ||||||
) = self.evaluate_mota( | ||||||
gt_data_frame, | ||||||
pred_data_frame, | ||||||
self.iou_threshold, | ||||||
prev_frame_id_map, | ||||||
) | ||||||
mota_values.append(mota) | ||||||
results["Frame Number"].append(frame_number) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we make for key in results.keys():
results[key].append(mota_dict[key]) |
||||||
results["Total Ground Truth"].append(total_gt) | ||||||
results["True Positives"].append(true_positives) | ||||||
results["Missed Detections"].append(missed_detections) | ||||||
results["False Positives"].append(false_positives) | ||||||
results["Number of Switches"].append(num_switches) | ||||||
results["Mota"].append(mota) | ||||||
|
||||||
save_tracking_mota_metrics(self.tracking_output_dir, results) | ||||||
|
||||||
return mota_values | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,9 +1,11 @@ | ||||||
import argparse | ||||||
import csv | ||||||
import os | ||||||
from datetime import datetime | ||||||
from pathlib import Path | ||||||
|
||||||
import cv2 | ||||||
import matplotlib.pyplot as plt | ||||||
import numpy as np | ||||||
|
||||||
from crabs.detector.utils.visualization import draw_bbox | ||||||
|
@@ -154,6 +156,7 @@ def save_required_output( | |||||
frame_copy = frame.copy() | ||||||
for bbox in tracked_boxes: | ||||||
xmin, ymin, xmax, ymax, id = bbox | ||||||
print(f"Calling draw_bbox with {bbox}") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
draw_bbox( | ||||||
frame_copy, | ||||||
(xmin, ymin), | ||||||
|
@@ -178,3 +181,149 @@ def release_video(video_output) -> None: | |||||
""" | ||||||
if video_output: | ||||||
video_output.release() | ||||||
|
||||||
|
||||||
def read_metrics_from_csv(filename): | ||||||
""" | ||||||
Read the tracking output metrics from a CSV file. | ||||||
To be called by plot_output_histogram. | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
filename : str | ||||||
Name of the CSV file to read. | ||||||
|
||||||
Returns | ||||||
------- | ||||||
tuple: | ||||||
Tuple containing lists of true positives, missed detections, | ||||||
false positives, number of switches, and total ground truth for each frame. | ||||||
""" | ||||||
true_positives_list = [] | ||||||
missed_detections_list = [] | ||||||
false_positives_list = [] | ||||||
num_switches_list = [] | ||||||
total_ground_truth_list = [] | ||||||
mota_value_list = [] | ||||||
|
||||||
with open(filename, mode="r") as file: | ||||||
reader = csv.DictReader(file) | ||||||
for row in reader: | ||||||
true_positives_list.append(int(row["True Positives"])) | ||||||
missed_detections_list.append(int(row["Missed Detections"])) | ||||||
false_positives_list.append(int(row["False Positives"])) | ||||||
num_switches_list.append(int(row["Number of Switches"])) | ||||||
total_ground_truth_list.append(int(row["Total Ground Truth"])) | ||||||
mota_value_list.append(float(row["Mota"])) | ||||||
|
||||||
return ( | ||||||
true_positives_list, | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe this tuple can be a dict instead? It's a bit less of a code smell There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think if we read the csv as a pandas dataframe instead we can extract the columns more efficiently (that is, without explicit looping). There is also a dataframe |
||||||
missed_detections_list, | ||||||
false_positives_list, | ||||||
num_switches_list, | ||||||
total_ground_truth_list, | ||||||
mota_value_list, | ||||||
) | ||||||
|
||||||
|
||||||
def plot_output_histogram(filename): | ||||||
""" | ||||||
Plot metrics along with the total ground truth for each frame. | ||||||
|
||||||
Example usage: | ||||||
> filename = <video_name>/tracking_metrics_output.csv | ||||||
> python crabs/tracker/utils/io.py filename | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
true_positives_list : list[int] | ||||||
List of counts of true positives for each frame. | ||||||
missed_detections_list : list[int] | ||||||
List of counts of missed detections for each frame. | ||||||
false_positives_list : list[int] | ||||||
List of counts of false positives for each frame. | ||||||
num_switches_list : list[int] | ||||||
List of counts of identity switches for each frame. | ||||||
total_ground_truth_list : list[int] | ||||||
List of total ground truth objects for each frame. | ||||||
""" | ||||||
( | ||||||
true_positives_list, | ||||||
missed_detections_list, | ||||||
false_positives_list, | ||||||
num_switches_list, | ||||||
total_ground_truth_list, | ||||||
mota_value_list, | ||||||
) = read_metrics_from_csv(filename) | ||||||
filepath = Path(filename) | ||||||
plot_name = filepath.name | ||||||
|
||||||
num_frames = len(true_positives_list) | ||||||
frames = range(1, num_frames + 1) | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
||||||
plt.figure(figsize=(10, 6)) | ||||||
|
||||||
overall_mota = sum(mota_value_list) / len(mota_value_list) | ||||||
|
||||||
# Calculate percentages | ||||||
true_positives_percentage = [ | ||||||
tp / gt * 100 if gt > 0 else 0 | ||||||
for tp, gt in zip(true_positives_list, total_ground_truth_list) | ||||||
] | ||||||
missed_detections_percentage = [ | ||||||
md / gt * 100 if gt > 0 else 0 | ||||||
for md, gt in zip(missed_detections_list, total_ground_truth_list) | ||||||
] | ||||||
false_positives_percentage = [ | ||||||
fp / gt * 100 if gt > 0 else 0 | ||||||
for fp, gt in zip(false_positives_list, total_ground_truth_list) | ||||||
] | ||||||
num_switches_percentage = [ | ||||||
ns / gt * 100 if gt > 0 else 0 | ||||||
for ns, gt in zip(num_switches_list, total_ground_truth_list) | ||||||
] | ||||||
|
||||||
# Plot metrics | ||||||
plt.plot( | ||||||
frames, | ||||||
true_positives_percentage, | ||||||
label=f"True Positives ({sum(true_positives_list)})", | ||||||
color="g", | ||||||
) | ||||||
plt.plot( | ||||||
frames, | ||||||
missed_detections_percentage, | ||||||
label=f"Missed Detections ({sum(missed_detections_list)})", | ||||||
color="r", | ||||||
) | ||||||
plt.plot( | ||||||
frames, | ||||||
false_positives_percentage, | ||||||
label=f"False Positives ({sum(false_positives_list)})", | ||||||
color="b", | ||||||
) | ||||||
plt.plot( | ||||||
frames, | ||||||
num_switches_percentage, | ||||||
label=f"Number of Switches ({sum(num_switches_list)})", | ||||||
color="y", | ||||||
) | ||||||
|
||||||
plt.xlabel("Frame Number") | ||||||
plt.ylabel("Percentage of Total Ground Truth (%)") | ||||||
plt.title(f"{plot_name}_mota:{overall_mota:.2f}") | ||||||
|
||||||
plt.legend() | ||||||
plt.savefig(f"{plot_name}.pdf") | ||||||
plt.show() | ||||||
|
||||||
|
||||||
if __name__ == "__main__": | ||||||
parser = argparse.ArgumentParser(description="Plot output histogram.") | ||||||
parser.add_argument( | ||||||
"filename", | ||||||
type=str, | ||||||
help="Path to the CSV file containing the metrics", | ||||||
) | ||||||
args = parser.parse_args() | ||||||
plot_output_histogram(args.filename) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
|
||
import cv2 | ||
import numpy as np | ||
import pandas as pd | ||
|
||
|
||
def extract_bounding_box_info(row: list[str]) -> Dict[str, Any]: | ||
|
@@ -152,3 +153,12 @@ def prep_sort(prediction: dict, score_threshold: float) -> np.ndarray: | |
pred_sort.append(bbox) | ||
|
||
return np.asarray(pred_sort) | ||
|
||
|
||
def save_tracking_mota_metrics( | ||
tracking_output_dir: Path, | ||
track_results: dict[str, Any], | ||
) -> None: | ||
track_df = pd.DataFrame(track_results) | ||
output_filename = f"{tracking_output_dir}/tracking_metrics_output.csv" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to follow our usual convention, could we timestamp the output directory? We would need to ensure it goes in the same directory as the video output if requested (or other outputs that may be requested) |
||
track_df.to_csv(output_filename, index=False) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Returning a long tuple is sometimes considered a code smell.
Maybe we can pass mota and its components as a dict to reduce this?