Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Review MOTA histogram plotting script #239

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions crabs/tracker/utils/io.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""Utility functions for handling input and output operations."""

import argparse
import csv
import os
from datetime import datetime
from pathlib import Path

import cv2
import matplotlib.pyplot as plt
import numpy as np

from crabs.detector.utils.visualization import draw_bbox
Expand Down Expand Up @@ -163,6 +165,7 @@ def save_required_output(
frame_copy = frame.copy()
for bbox in tracked_boxes:
xmin, ymin, xmax, ymax, id = bbox

draw_bbox(
frame_copy,
(xmin, ymin),
Expand All @@ -183,3 +186,153 @@ def release_video(video_output) -> None:
"""Release the video file if it's open."""
if video_output:
video_output.release()


def read_metrics_from_csv(filename):
"""Read the tracking output metrics from a CSV file.

To be called by plot_output_histogram.

Parameters
----------
filename : str
Name of the CSV file to read.

Returns
-------
tuple:
Tuple containing lists of true positives, missed detections,
false positives, number of switches, and total ground truth
for each frame.

"""
true_positives_list = []
missed_detections_list = []
false_positives_list = []
num_switches_list = []
total_ground_truth_list = []
mota_value_list = []

with open(filename) as file:
reader = csv.DictReader(file)
for row in reader:
true_positives_list.append(int(row["True Positives"]))
missed_detections_list.append(int(row["Missed Detections"]))
false_positives_list.append(int(row["False Positives"]))
num_switches_list.append(int(row["Number of Switches"]))
total_ground_truth_list.append(int(row["Total Ground Truth"]))
mota_value_list.append(float(row["Mota"]))

return (
true_positives_list,
missed_detections_list,
false_positives_list,
num_switches_list,
total_ground_truth_list,
mota_value_list,
)


def plot_output_histogram(filename):
"""Plot metrics along with the total ground truth for each frame.

Example usage:
> filename = <video_name>/tracking_metrics_output.csv
> python crabs/tracker/utils/io.py filename

Parameters
----------
filename : str
Name of the CSV file to read.
true_positives_list : list[int]
List of counts of true positives for each frame.
missed_detections_list : list[int]
List of counts of missed detections for each frame.
false_positives_list : list[int]
List of counts of false positives for each frame.
num_switches_list : list[int]
List of counts of identity switches for each frame.
total_ground_truth_list : list[int]
List of total ground truth objects for each frame.

"""
(
true_positives_list,
missed_detections_list,
false_positives_list,
num_switches_list,
total_ground_truth_list,
mota_value_list,
) = read_metrics_from_csv(filename)
filepath = Path(filename)
plot_name = filepath.name

num_frames = len(true_positives_list)
frame_numbers = range(1, num_frames + 1)

plt.figure(figsize=(10, 6))

overall_mota = sum(mota_value_list) / len(mota_value_list)

# Calculate percentages
true_positives_percentage = [
tp / gt * 100 if gt > 0 else 0
for tp, gt in zip(true_positives_list, total_ground_truth_list)
]
missed_detections_percentage = [
md / gt * 100 if gt > 0 else 0
for md, gt in zip(missed_detections_list, total_ground_truth_list)
]
false_positives_percentage = [
fp / gt * 100 if gt > 0 else 0
for fp, gt in zip(false_positives_list, total_ground_truth_list)
]
num_switches_percentage = [
ns / gt * 100 if gt > 0 else 0
for ns, gt in zip(num_switches_list, total_ground_truth_list)
]

# Plot metrics
plt.plot(
frame_numbers,
true_positives_percentage,
label=f"True Positives ({sum(true_positives_list)})",
color="g",
)
plt.plot(
frame_numbers,
missed_detections_percentage,
label=f"Missed Detections ({sum(missed_detections_list)})",
color="r",
)
plt.plot(
frame_numbers,
false_positives_percentage,
label=f"False Positives ({sum(false_positives_list)})",
color="b",
)
plt.plot(
frame_numbers,
num_switches_percentage,
label=f"Number of Switches ({sum(num_switches_list)})",
color="y",
)

plt.xlabel("Frame Number")
plt.ylabel("Percentage of Total Ground Truth (%)")
plt.title(f"{plot_name}_mota:{overall_mota:.2f}")

plt.legend()
plt.savefig(f"{plot_name}.pdf")
plt.show()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Plot output histogram.")
parser.add_argument(
"filename",
type=str,
help="Path to the CSV file containing the metrics",
)
args = parser.parse_args()
plot_output_histogram(args.filename)