Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ground truth tracking video #216

Open
wants to merge 61 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
61 commits
Select commit Hold shift + click to select a range
73ea420
modify id switches
nikk-nikaznan Jun 5, 2024
768b81e
change the mota and test
nikk-nikaznan Jun 5, 2024
3036b06
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 5, 2024
f18a44c
change the variable
nikk-nikaznan Jun 14, 2024
6c0b0ed
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 18, 2024
a851e9f
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 21, 2024
611bd60
some bug fixed, load from checkpoint
nikk-nikaznan Jun 24, 2024
11bded7
Merge branch 'nikkna/id_switches' of github.com:SainsburyWellcomeCent…
nikk-nikaznan Jun 24, 2024
4b9a794
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 24, 2024
cfe67ca
change list type, add gt_ids
nikk-nikaznan Jun 25, 2024
7de7561
fixed the error id switches
nikk-nikaznan Jun 25, 2024
25f36c7
changes id switches
nikk-nikaznan Jun 25, 2024
de0b9cd
fixing some test and type hint
nikk-nikaznan Jun 27, 2024
5e56e1a
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 27, 2024
6aa1b63
fixing test, parametrize the test with additional test
nikk-nikaznan Jun 28, 2024
05faa18
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 28, 2024
a592c1a
cleane dup
nikk-nikaznan Jun 28, 2024
eefcf33
checking some test
nikk-nikaznan Jun 28, 2024
f0bcf65
rebase
nikk-nikaznan Jun 28, 2024
da0c62b
cleaned up
nikk-nikaznan Jun 28, 2024
01cb0f4
test works
nikk-nikaznan Jun 28, 2024
3ca1872
test works
nikk-nikaznan Jun 28, 2024
4d32f77
aded specific example
nikk-nikaznan Jun 28, 2024
ff059ea
some more test
nikk-nikaznan Jun 28, 2024
1195854
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 3, 2024
67fe295
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 3, 2024
9bf3a73
combine gt functions, fix test
nikk-nikaznan Jul 3, 2024
93915e1
rename test
nikk-nikaznan Jul 3, 2024
e1a8537
cleaned up linting
nikk-nikaznan Jul 3, 2024
d6401e1
adding some more description
nikk-nikaznan Jul 3, 2024
64ae8e8
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jul 3, 2024
73ae064
change the nested folder structure for output
nikk-nikaznan Jul 3, 2024
163cc06
adding device to cli
nikk-nikaznan Jul 4, 2024
cc6bbcf
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jul 5, 2024
f042b2b
attempt yesterday
nikk-nikaznan Jul 5, 2024
56ff81d
small changes in docstring
nikk-nikaznan Jul 5, 2024
73607b4
Update crabs/tracker/utils/io.py
nikk-nikaznan Jul 5, 2024
f8c91a9
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 5, 2024
21930ac
changes for gt dict
nikk-nikaznan Jul 5, 2024
0e8a687
predicted as dict
nikk-nikaznan Jul 5, 2024
6e12530
rename varibale, fix test
nikk-nikaznan Jul 5, 2024
36757a5
reviewing id switch
nikk-nikaznan Jul 5, 2024
5b5e1de
commented out the test that fail
nikk-nikaznan Jul 5, 2024
e8f4446
commented out the test that fail
nikk-nikaznan Jul 5, 2024
a464d15
seems working
nikk-nikaznan Jul 5, 2024
98e77c3
small modification for the test
nikk-nikaznan Jul 5, 2024
41ab8cc
cleaned up
nikk-nikaznan Jul 5, 2024
491ae68
cleaned up
nikk-nikaznan Jul 8, 2024
ad8f83e
add gt tracking video
nikk-nikaznan Jul 8, 2024
5a7f32f
Merge branch 'main' into nikkna/gt_tracking_video
nikk-nikaznan Jul 8, 2024
749dc18
cleaned up
nikk-nikaznan Jul 8, 2024
0ee9bfe
adding test
nikk-nikaznan Jul 8, 2024
4db520d
test for predicted dict
nikk-nikaznan Jul 8, 2024
810a536
Merge branch 'main' into nikkna/gt_tracking_video
nikk-nikaznan Jul 9, 2024
7c06e2e
cleaned up from the latest merge in main
nikk-nikaznan Jul 9, 2024
9d0c185
fixed test
nikk-nikaznan Jul 9, 2024
14379e7
cleaned up import
nikk-nikaznan Jul 9, 2024
07c99d4
the revised version
nikk-nikaznan Jul 11, 2024
85b5b98
moved iou function to utils
nikk-nikaznan Jul 11, 2024
29ed15a
Merge branch 'main' into nikkna/gt_tracking_video
nikk-nikaznan Jul 19, 2024
2e1553d
fixed the gt to not save video by default
nikk-nikaznan Jul 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 3 additions & 128 deletions crabs/tracker/evaluate_tracker.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import csv
import logging
from typing import Any, Dict, Optional, Tuple

import numpy as np

from crabs.tracker.utils.tracking import extract_bounding_box_info
from crabs.tracker.utils.tracking import calculate_iou


class TrackerEvaluate:
def __init__(
self,
gt_dir: str,
predicted_boxes_id: list[np.ndarray],
iou_threshold: float,
):
"""
Expand All @@ -21,134 +19,13 @@ def __init__(
----------
gt_dir : str
Directory path of the ground truth CSV file.
tracked_list : List[np.ndarray]
A list where each element is a numpy array representing tracked objects in a frame.
Each numpy array has shape (N, 5), where N is the number of objects.
The columns are [x1, y1, x2, y2, id], where (x1, y1) and (x2, y2)
define the bounding box and id is the object ID.
iou_threshold : float
Intersection over Union (IoU) threshold for evaluating tracking performance.
"""
self.gt_dir = gt_dir
self.predicted_boxes_id = predicted_boxes_id
self.iou_threshold = iou_threshold
self.last_known_predicted_ids: Dict = {}

def get_predicted_data(self) -> Dict[int, Dict[str, Any]]:
"""
Convert predicted bounding box and ID into a dictionary organized by frame number.

Returns
-------
Dict[int, Dict[str, Any]]:
A dictionary where the key is the frame number and the value is another dictionary containing:
- 'bbox': A numpy array with shape (N, 4) containing coordinates of the bounding boxes
[x, y, x + width, y + height] for every object in the frame.
- 'id': A numpy array containing the IDs of the tracked objects.
"""
predicted_dict: Dict[int, Dict[str, Any]] = {}

for frame_number, frame_data in enumerate(self.predicted_boxes_id):
if frame_data.size == 0:
continue

bboxes = frame_data[:, :4]
ids = frame_data[:, 4]

predicted_dict[frame_number] = {"bbox": bboxes, "id": ids}

return predicted_dict

def get_ground_truth_data(self) -> Dict[int, Dict[str, Any]]:
"""
Extract ground truth bounding box data from a CSV file and organize it by frame number.

Returns
-------
Dict[int, Dict[str, Any]]:
A dictionary where the key is the frame number and the value is another dictionary containing:
- 'bbox': A numpy arrays with shape of (N, 4) containing coordinates of the bounding box
[x, y, x + width, y + height] for every crabs in the frame.
- 'id': The ground truth ID
"""
with open(self.gt_dir, "r") as csvfile:
csvreader = csv.reader(csvfile)
next(csvreader) # Skip the header row
ground_truth_data = [
extract_bounding_box_info(row) for row in csvreader
]

# Format as a dictionary with key = frame number
ground_truth_dict: dict = {}
for data in ground_truth_data:
frame_number = data["frame_number"]
bbox = np.array(
[
data["x"],
data["y"],
data["x"] + data["width"],
data["y"] + data["height"],
],
dtype=np.float32,
)
track_id = int(float(data["id"]))

if frame_number not in ground_truth_dict:
ground_truth_dict[frame_number] = {"bbox": [], "id": []}

ground_truth_dict[frame_number]["bbox"].append(bbox)
ground_truth_dict[frame_number]["id"].append(track_id)

# format as numpy arrays
for frame_number in ground_truth_dict:
ground_truth_dict[frame_number]["bbox"] = np.array(
ground_truth_dict[frame_number]["bbox"], dtype=np.float32
)
ground_truth_dict[frame_number]["id"] = np.array(
ground_truth_dict[frame_number]["id"], dtype=np.float32
)
return ground_truth_dict

def calculate_iou(self, box1: np.ndarray, box2: np.ndarray) -> float:
"""
Calculate IoU (Intersection over Union) of two bounding boxes.

Parameters
----------
box1 (np.ndarray):
Coordinates [x1, y1, x2, y2] of the first bounding box.
Here, (x1, y1) represents the top-left corner, and (x2, y2) represents the bottom-right corner.
box2 (np.ndarray):
Coordinates [x1, y1, x2, y2] of the second bounding box.
Here, (x1, y1) represents the top-left corner, and (x2, y2) represents the bottom-right corner.

Returns
-------
float:
IoU value.
"""
x1_box1, y1_box1, x2_box1, y2_box1 = box1
x1_box2, y1_box2, x2_box2, y2_box2 = box2

# Calculate intersection coordinates
x1_intersect = max(x1_box1, x1_box2)
y1_intersect = max(y1_box1, y1_box2)
x2_intersect = min(x2_box1, x2_box2)
y2_intersect = min(y2_box1, y2_box2)

# Calculate area of intersection rectangle
intersect_width = max(0, x2_intersect - x1_intersect + 1)
intersect_height = max(0, y2_intersect - y1_intersect + 1)
intersect_area = intersect_width * intersect_height

# Calculate area of individual bounding boxes
box1_area = (x2_box1 - x1_box1 + 1) * (y2_box1 - y1_box1 + 1)
box2_area = (x2_box2 - x1_box2 + 1) * (y2_box2 - y1_box2 + 1)

iou = intersect_area / float(box1_area + box2_area - intersect_area)

return iou

def count_identity_switches(
self,
gt_to_tracked_id_previous_frame: Optional[Dict[int, int]],
Expand Down Expand Up @@ -293,7 +170,7 @@ def evaluate_mota(

for j, gt_box in enumerate(gt_boxes):
if j not in indices_of_matched_gt_boxes:
iou = self.calculate_iou(gt_box, pred_box)
iou = calculate_iou(gt_box, pred_box)
if iou > iou_threshold and iou > best_iou:
best_iou = iou
index_gt_best_match = j
Expand Down Expand Up @@ -363,12 +240,10 @@ def evaluate_tracking(

return mota_values

def run_evaluation(self) -> None:
def run_evaluation(self, predicted_dict, ground_truth_dict) -> None:
"""
Run evaluation of tracking based on tracking ground truth.
"""
predicted_dict = self.get_predicted_data()
ground_truth_dict = self.get_ground_truth_data()
mota_values = self.evaluate_tracking(ground_truth_dict, predicted_dict)

overall_mota = np.mean(mota_values)
Expand Down
22 changes: 18 additions & 4 deletions crabs/tracker/track_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@
release_video,
save_required_output,
)
from crabs.tracker.utils.tracking import prep_sort
from crabs.tracker.utils.tracking import (
get_ground_truth_data,
get_predicted_data,
prep_sort,
)


class Tracking:
Expand Down Expand Up @@ -153,11 +157,20 @@ def run_tracking(self):
Run object detection + tracking on the video frames.
"""
# If we pass ground truth: check the path exist
if self.args.gt_path and not os.path.exists(self.args.gt_path):
if (self.args.gt_path is not None) and (
not os.path.exists(str(self.args.gt_path))
):
logging.info(
f"Ground truth file {self.args.gt_path} does not exist. Exiting..."
)
return
# if the path exist, we get the ground_truth_dict
elif self.args.gt_path is not None:
ground_truth_dict = get_ground_truth_data(
self.args.gt_path,
)
elif self.args.gt_path is None:
ground_truth_dict = {}

# initialisation
frame_idx = 0
Expand Down Expand Up @@ -203,6 +216,7 @@ def run_tracking(self):
frame,
frame_idx + 1,
pred_scores,
ground_truth_dict,
nikk-nikaznan marked this conversation as resolved.
Show resolved Hide resolved
)

# update frame number
Expand All @@ -211,10 +225,10 @@ def run_tracking(self):
if self.args.gt_path:
evaluation = TrackerEvaluate(
self.args.gt_path,
self.tracked_bbox_id,
self.config["iou_threshold"],
)
evaluation.run_evaluation()
predicted_dict = get_predicted_data(self.tracked_bbox_id)
evaluation.run_evaluation(predicted_dict, ground_truth_dict)

# Close input video
self.video.release()
Expand Down
50 changes: 41 additions & 9 deletions crabs/tracker/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import os
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional

import cv2
import numpy as np
Expand Down Expand Up @@ -108,6 +109,7 @@ def save_required_output(
frame: np.ndarray,
frame_number: int,
pred_scores: np.ndarray,
ground_truth_dict: Optional[Dict[int, Dict[str, Any]]] = None,
) -> None:
"""
Handle the output based on argument options.
Expand All @@ -134,6 +136,8 @@ def save_required_output(
The frame number.
pred_scores : np.ndarray
The prediction score from detector
ground_truth_dict : dict
Dictionary containing ground truth bounding boxes and IDs for each frame, organized by frame number.
"""
frame_name = f"{video_file_root}_frame_{frame_number:08d}.png"

Expand All @@ -152,15 +156,43 @@ def save_required_output(

if save_video:
frame_copy = frame.copy()
for bbox in tracked_boxes:
xmin, ymin, xmax, ymax, id = bbox
draw_bbox(
frame_copy,
(xmin, ymin),
(xmax, ymax),
(0, 0, 255),
f"id : {int(id)}",
)

if ground_truth_dict and frame_number in ground_truth_dict:
for bbox, obj_id in zip(
ground_truth_dict[frame_number]["bbox"],
ground_truth_dict[frame_number]["id"],
):
x1, y1, x2, y2 = map(int, bbox)
draw_bbox(
frame_copy,
(x1, y1),
(x2, y2),
(0, 255, 0), # Green for ground truth
f"GT ID: {int(obj_id)}",
)

# Draw tracked bounding boxes and IDs
for bbox in tracked_boxes:
xmin, ymin, xmax, ymax, obj_id = bbox
draw_bbox(
frame_copy,
(int(xmin), int(ymin)),
(int(xmax), int(ymax)),
(0, 0, 255), # Red for predictions
f"Pred ID: {int(obj_id)}",
)

else:
for bbox in tracked_boxes:
xmin, ymin, xmax, ymax, id = bbox
draw_bbox(
frame_copy,
(xmin, ymin),
(xmax, ymax),
(0, 0, 255),
f"id : {int(id)}",
)

video_output.write(frame_copy)


Expand Down
Loading