diff --git a/crabs/detector/config/faster_rcnn.yaml b/crabs/detector/config/faster_rcnn.yaml
index 66893684..00a13b5a 100644
--- a/crabs/detector/config/faster_rcnn.yaml
+++ b/crabs/detector/config/faster_rcnn.yaml
@@ -12,7 +12,7 @@ num_classes: 2
 # -------------------------------
 # Training & validation parameters
 # -------------------------------
-n_epochs: 250
+n_epochs: 1
 learning_rate: 0.00005
 wdecay: 0.00005
 batch_size_train: 4
diff --git a/crabs/tracker/evaluate_tracker.py b/crabs/tracker/evaluate_tracker.py
index 74f65db7..f6f4aa8f 100644
--- a/crabs/tracker/evaluate_tracker.py
+++ b/crabs/tracker/evaluate_tracker.py
@@ -1,6 +1,6 @@
 import csv
 import logging
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, Tuple
 
 import numpy as np
 
@@ -8,31 +8,77 @@
 class TrackerEvaluate:
-    def __init__(self, gt_dir: str, tracked_list: list, iou_threshold: float):
-        self.gt_dir = gt_dir
-        self.tracked_list = tracked_list
-        self.iou_threshold = iou_threshold
-
-    def create_gt_list(
+    def __init__(
         self,
-        ground_truth_data: list[Dict[str, Any]],
-        gt_boxes_list: list[np.ndarray],
-    ) -> list[np.ndarray]:
+        gt_dir: str,
+        predicted_boxes_id: list[np.ndarray],
+        iou_threshold: float,
+    ):
         """
-        Creates a list of ground truth bounding boxes organized by frame number.
+        Initialize the TrackerEvaluate class with the ground truth directory, the predicted bounding boxes and IDs, and the IoU threshold.
 
         Parameters
         ----------
-        ground_truth_data : list[Dict[str, Any]]
-            A list containing ground truth bounding box data organized by frame number.
-        gt_boxes_list : list[np.ndarray]
-            A list to store the ground truth bounding boxes for each frame.
+        gt_dir : str
+            Path to the ground truth CSV file.
+        predicted_boxes_id : list[np.ndarray]
+            A list where each element is a numpy array representing tracked objects in a frame.
+            Each numpy array has shape (N, 5), where N is the number of objects.
+            The columns are [x1, y1, x2, y2, id], where (x1, y1) and (x2, y2)
+            define the bounding box and id is the object ID.
+        iou_threshold : float
+            Intersection over Union (IoU) threshold for evaluating tracking performance.
+        """
+        self.gt_dir = gt_dir
+        self.predicted_boxes_id = predicted_boxes_id
+        self.iou_threshold = iou_threshold
+
+    def get_predicted_data(self) -> Dict[int, Dict[str, Any]]:
+        """
+        Convert the predicted bounding boxes and IDs into a dictionary organized by frame number.
 
         Returns
         -------
-        list[np.ndarray]:
-            A list containing ground truth bounding boxes organized by frame number.
+        Dict[int, Dict[str, Any]]:
+            A dictionary where the key is the frame number and the value is another dictionary containing:
+            - 'bbox': A numpy array with shape (N, 4) containing coordinates of the bounding boxes
+            [x, y, x + width, y + height] for every object in the frame.
+            - 'id': A numpy array containing the IDs of the tracked objects.
         """
+        predicted_dict: Dict[int, Dict[str, Any]] = {}
+
+        for frame_number, frame_data in enumerate(self.predicted_boxes_id):
+            if frame_data.size == 0:
+                continue
+
+            bboxes = frame_data[:, :4]
+            ids = frame_data[:, 4]
+
+            predicted_dict[frame_number] = {"bbox": bboxes, "id": ids}
+
+        return predicted_dict
+
+    def get_ground_truth_data(self) -> Dict[int, Dict[str, Any]]:
+        """
+        Extract ground truth bounding box data from a CSV file and organize it by frame number.
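+
+        The CSV is assumed to follow the VIA-annotation export layout used
+        elsewhere in this repo (see `tests/data/gt_test.csv`): the bounding
+        box geometry is read from the `region_shape_attributes` column and
+        the track ID from the `region_attributes` column, via
+        `extract_bounding_box_info`.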
+
+        Returns
+        -------
+        Dict[int, Dict[str, Any]]:
+            A dictionary where the key is the frame number and the value is another dictionary containing:
+            - 'bbox': A numpy array with shape (N, 4) containing the coordinates of the bounding boxes
+            [x, y, x + width, y + height] for every crab in the frame.
+            - 'id': A numpy array containing the ground truth IDs.
+        """
+        with open(self.gt_dir, "r") as csvfile:
+            csvreader = csv.reader(csvfile)
+            next(csvreader)  # Skip the header row
+            ground_truth_data = [
+                extract_bounding_box_info(row) for row in csvreader
+            ]
+
+        # Format as a dictionary with key = frame number
+        ground_truth_dict: dict = {}
         for data in ground_truth_data:
             frame_number = data["frame_number"]
             bbox = np.array(
@@ -41,53 +87,26 @@ def create_gt_list(
                     data["y"],
                     data["x"] + data["width"],
                     data["y"] + data["height"],
-                    data["id"],
                 ],
                 dtype=np.float32,
             )
-            if gt_boxes_list[frame_number].size == 0:
-                gt_boxes_list[frame_number] = bbox.reshape(
-                    1, -1
-                )  # Initialize as a 2D array
-            else:
-                gt_boxes_list[frame_number] = np.vstack(
-                    [gt_boxes_list[frame_number], bbox]
-                )
-        return gt_boxes_list
-
-    def get_ground_truth_data(self) -> list[np.ndarray]:
-        """
-        Extract ground truth bounding box data from a CSV file.
+            track_id = int(float(data["id"]))
 
-        Parameters
-        ----------
-        gt_dir : str
-            The path to the CSV file containing ground truth data.
+            if frame_number not in ground_truth_dict:
+                ground_truth_dict[frame_number] = {"bbox": [], "id": []}
 
-        Returns
-        -------
-        list[np.ndarray]:
-            A list containing ground truth bounding box data organized by frame number.
-            The numpy array represent the coordinates and ID of the bounding box in the order:
-            x, y, x + width, y + height, ID
-        """
-        ground_truth_data = []
-        max_frame_number = 0
+            ground_truth_dict[frame_number]["bbox"].append(bbox)
+            ground_truth_dict[frame_number]["id"].append(track_id)
 
-        # Open the CSV file and read its contents line by line
-        with open(self.gt_dir, "r") as csvfile:
-            csvreader = csv.reader(csvfile)
-            next(csvreader)  # Skip the header row
-            for row in csvreader:
-                data = extract_bounding_box_info(row)
-                ground_truth_data.append(data)
-                max_frame_number = max(max_frame_number, data["frame_number"])
-
-        # Initialize a list to store the ground truth bounding boxes for each frame
-        gt_boxes_list = [np.array([]) for _ in range(max_frame_number + 1)]
-
-        gt_boxes_list = self.create_gt_list(ground_truth_data, gt_boxes_list)
-        return gt_boxes_list
+        # format as numpy arrays
+        for frame_number in ground_truth_dict:
+            ground_truth_dict[frame_number]["bbox"] = np.array(
+                ground_truth_dict[frame_number]["bbox"], dtype=np.float32
+            )
+            ground_truth_dict[frame_number]["id"] = np.array(
+                ground_truth_dict[frame_number]["id"], dtype=np.float32
+            )
+        return ground_truth_dict
 
     def calculate_iou(self, box1: np.ndarray, box2: np.ndarray) -> float:
         """
@@ -131,127 +150,171 @@ def calculate_iou(self, box1: np.ndarray, box2: np.ndarray) -> float:
 
     def count_identity_switches(
         self,
-        prev_frame_ids: Optional[list[list[int]]],
-        current_frame_ids: Optional[list[list[int]]],
+        gt_to_tracked_id_previous_frame: Optional[Dict[int, int]],
+        gt_to_tracked_id_current_frame: Dict[int, int],
     ) -> int:
         """
         Count the number of identity switches between two sets of object IDs.
 
         Parameters
         ----------
-        prev_frame_ids : Optional[list[list[int]]]
-            List of object IDs in the previous frame.
-        current_frame_ids : Optional[list[list[int]]]
-            List of object IDs in the current frame.
+        gt_to_tracked_id_previous_frame : Optional[Dict[int, int]]
+            A dictionary mapping ground truth IDs to predicted IDs from the previous frame.
+        gt_to_tracked_id_current_frame : Dict[int, int]
+            A dictionary mapping ground truth IDs to predicted IDs for the current frame.
 
         Returns
         -------
         int
             The number of identity switches between the two sets of object IDs.
         """
-
-        if prev_frame_ids is None or current_frame_ids is None:
+        if gt_to_tracked_id_previous_frame is None:
             return 0
 
-        # Initialize count of identity switches
-        num_switches = 0
-
-        prev_ids = set(prev_frame_ids[0])
-        current_ids = set(current_frame_ids[0])
-
-        # Calculate the number of switches by finding the difference in IDs
-        num_switches = len(prev_ids.symmetric_difference(current_ids))
-
-        return num_switches
+        switch_counter = 0
+
+        # Compute sets of ground truth IDs for current and previous frames
+        gt_ids_current_frame = set(gt_to_tracked_id_current_frame.keys())
+        gt_ids_prev_frame = set(gt_to_tracked_id_previous_frame.keys())
+
+        # Compute lists of ground truth IDs that continue, disappear, and appear
+        gt_ids_cont = list(gt_ids_current_frame & gt_ids_prev_frame)
+        gt_ids_disappear = list(gt_ids_prev_frame - gt_ids_current_frame)
+        gt_ids_appear = list(gt_ids_current_frame - gt_ids_prev_frame)
+
+        # Store used predicted IDs to avoid double counting.
+        # In `used_pred_ids` we log IDs from either the current or the
+        # previous frame that have been involved in an already counted
+        # ID switch.
+        used_pred_ids = set()
+
+        # Case 1: Objects that continue to exist
+        for gt_id in gt_ids_cont:
+            previous_pred_id = gt_to_tracked_id_previous_frame.get(gt_id)
+            current_pred_id = gt_to_tracked_id_current_frame.get(gt_id)
+            if not np.isnan(previous_pred_id) and not np.isnan(
+                current_pred_id
+            ):
+                if current_pred_id != previous_pred_id:
+                    switch_counter += 1
+                    used_pred_ids.add(current_pred_id)
+
+        # Case 2: Objects that disappear
+        for gt_id in gt_ids_disappear:
+            previous_pred_id = gt_to_tracked_id_previous_frame.get(gt_id)
+            if not np.isnan(
+                previous_pred_id
+            ):  # Exclude if missed detection in previous frame
+                if previous_pred_id in gt_to_tracked_id_current_frame.values():
+                    if previous_pred_id not in used_pred_ids:
+                        switch_counter += 1
+                        used_pred_ids.add(previous_pred_id)
+
+        # Case 3: Objects that appear
+        for gt_id in gt_ids_appear:
+            current_pred_id = gt_to_tracked_id_current_frame.get(gt_id)
+            if not np.isnan(
+                current_pred_id
+            ):  # Exclude if missed detection in current frame
+                if current_pred_id in gt_to_tracked_id_previous_frame.values():
+                    if current_pred_id not in used_pred_ids:
+                        switch_counter += 1
+
+        return switch_counter
 
     def evaluate_mota(
         self,
-        gt_boxes: np.ndarray,
-        tracked_boxes: np.ndarray,
+        gt_data: Dict[str, np.ndarray],
+        pred_data: Dict[str, np.ndarray],
         iou_threshold: float,
-        prev_frame_ids: Optional[list[list[int]]],
-    ) -> float:
+        gt_to_tracked_id_previous_frame: Optional[Dict[int, int]],
+    ) -> Tuple[float, Dict[int, int]]:
         """
         Evaluate MOTA (Multiple Object Tracking Accuracy).
 
-        MOTA is a metric used to evaluate the performance of object tracking algorithms.
-
         Parameters
         ----------
-        gt_boxes : np.ndarray
-            Ground truth bounding boxes of objects.
-        tracked_boxes : np.ndarray
-            Tracked bounding boxes of objects.
+        gt_data : Dict[str, np.ndarray]
+            Dictionary containing ground truth bounding boxes and IDs.
+            - 'bbox': Bounding boxes with shape (N, 4).
+            - 'id': Ground truth IDs with shape (N,).
+        pred_data : Dict[str, np.ndarray]
+            Dictionary containing predicted bounding boxes and IDs.
+            - 'bbox': Bounding boxes with shape (N, 4).
+            - 'id': Predicted IDs with shape (N,).
         iou_threshold : float
             Intersection over Union (IoU) threshold for considering a match.
-        prev_frame_ids : Optional[list[list[int]]]
-            IDs from the previous frame for identity switch detection.
+        gt_to_tracked_id_previous_frame : Optional[Dict[int, int]]
+            A dictionary mapping ground truth IDs to predicted IDs from the previous frame.
 
         Returns
         -------
         float
             The computed MOTA (Multi-Object Tracking Accuracy) score for the tracking performance.
-
-        Notes
-        -----
-        MOTA is calculated using the following formula:
-
-        MOTA = 1 - (Missed Detections + False Positives + Identity Switches) / Total Ground Truth
-
-        - Missed Detections: Instances where the ground truth objects were not detected by the tracking algorithm.
-        - False Positives: Instances where the tracking algorithm produces a detection where there is no corresponding ground truth object.
-        - Identity Switches: Instances where the tracking algorithm assigns a different ID to an object compared to its ID in the previous frame.
-        - Total Ground Truth: The total number of ground truth objects in the scene.
-
-        The MOTA score ranges from 0 to 1, with higher values indicating better tracking performance.
-        A MOTA score of 1 indicates perfect tracking, where there are no missed detections, false positives, or identity switches.
+        Dict[int, int]
+            A dictionary mapping ground truth IDs to predicted IDs for the current frame.
         """
-        total_gt = len(gt_boxes)
+        total_gt = len(gt_data["bbox"])
         false_positive = 0
+        indices_of_matched_gt_boxes = set()
+        gt_to_tracked_id_current_frame = {}
 
-        for i, tracked_box in enumerate(tracked_boxes):
+        pred_boxes = pred_data["bbox"]
+        pred_ids = pred_data["id"]
+
+        gt_boxes = gt_data["bbox"]
+        gt_ids = gt_data["id"]
+
+        for i, (pred_box, pred_id) in enumerate(zip(pred_boxes, pred_ids)):
             best_iou = 0.0
-            best_match = None
+            index_gt_best_match = None
+            index_gt_not_match = None
 
             for j, gt_box in enumerate(gt_boxes):
-                iou = self.calculate_iou(gt_box[:4], tracked_box[:4])
-                if iou > iou_threshold and iou > best_iou:
-                    best_iou = iou
-                    best_match = j
-            if best_match is not None:
-                # successfully found a matching ground truth box for the tracked box.
-                # set the corresponding ground truth box to None.
-                gt_boxes[best_match] = None
+                if j not in indices_of_matched_gt_boxes:
+                    iou = self.calculate_iou(gt_box, pred_box)
+                    if iou > iou_threshold and iou > best_iou:
+                        best_iou = iou
+                        index_gt_best_match = j
+                    else:
+                        index_gt_not_match = j
+
+            if index_gt_best_match is not None:
+                # Successfully found a matching ground truth box for the tracked box.
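+                # Record the matched GT index so this GT box cannot be
+                # matched again by a later predicted box in the same frame.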
+                indices_of_matched_gt_boxes.add(index_gt_best_match)
+                # Map ground truth ID to tracked ID
+                gt_to_tracked_id_current_frame[
+                    int(gt_ids[index_gt_best_match])
+                ] = int(pred_id)
             else:
                 false_positive += 1
+                if index_gt_not_match is not None:
+                    gt_to_tracked_id_current_frame[
+                        int(gt_ids[index_gt_not_match])
+                    ] = np.nan
 
-        missed_detections = 0
-        for box in gt_boxes:
-            if box is not None and not np.all(np.isnan(box)):
-                # if true ground truth box was not matched with any tracked box
-                missed_detections += 1
-
-        tracked_ids = [[box[-1] for box in tracked_boxes]]
-
+        missed_detections = total_gt - len(indices_of_matched_gt_boxes)
         num_switches = self.count_identity_switches(
-            prev_frame_ids, tracked_ids
+            gt_to_tracked_id_previous_frame, gt_to_tracked_id_current_frame
        )
 
+        # MOTA = 1 - (missed detections + false positives + ID switches) / total GT
         mota = (
             1
             - (missed_detections + false_positive + num_switches) / total_gt
         )
-        return mota
+        return mota, gt_to_tracked_id_current_frame
 
-    def evaluate_tracking(self, gt_boxes_list: list) -> list[float]:
+    def evaluate_tracking(
+        self,
+        ground_truth_dict: Dict[int, Dict[str, Any]],
+        predicted_dict: Dict[int, Dict[str, Any]],
+    ) -> list[float]:
         """
         Evaluate tracking performance using the Multi-Object Tracking Accuracy (MOTA) metric.
 
         Parameters
         ----------
-        gt_boxes_list : list[list[float]]
-            List of ground truth bounding boxes for each frame.
-        tracked_boxes_list : list[list[float]]
-            List of tracked bounding boxes for each frame.
+        ground_truth_dict : dict
+            Dictionary containing ground truth bounding boxes and IDs for each frame, organized by frame number.
+        predicted_dict : dict
+            Dictionary containing predicted bounding boxes and IDs for each frame, organized by frame number.
 
         Returns
         -------
@@ -259,17 +322,20 @@ def evaluate_tracking(self, gt_boxes_list: list) -> list[float]:
             The computed MOTA (Multi-Object Tracking Accuracy) score for the tracking performance.
         """
         mota_values = []
-        prev_frame_ids: Optional[list[list[int]]] = None
-        for gt_boxes, tracked_boxes in zip(gt_boxes_list, self.tracked_list):
-            mota = self.evaluate_mota(
-                gt_boxes,
-                tracked_boxes,
-                self.iou_threshold,
-                prev_frame_ids,
-            )
-            mota_values.append(mota)
-            # Update previous frame IDs for the next iteration
-            prev_frame_ids = [[box[-1] for box in tracked_boxes]]
+        prev_frame_id_map: Optional[dict] = None
+
+        for frame_number in sorted(ground_truth_dict.keys()):
+            gt_data_frame = ground_truth_dict[frame_number]
+
+            if frame_number < len(predicted_dict):
+                pred_data_frame = predicted_dict[frame_number]
+
+                mota, prev_frame_id_map = self.evaluate_mota(
+                    gt_data_frame,
+                    pred_data_frame,
+                    self.iou_threshold,
+                    prev_frame_id_map,
+                )
+                mota_values.append(mota)
 
         return mota_values
 
@@ -277,7 +343,8 @@ def run_evaluation(self) -> None:
         """
         Run evaluation of tracking based on tracking ground truth.
""" - gt_boxes_list = self.get_ground_truth_data() - mota_values = self.evaluate_tracking(gt_boxes_list) + predicted_dict = self.get_predicted_data() + ground_truth_dict = self.get_ground_truth_data() + mota_values = self.evaluate_tracking(ground_truth_dict, predicted_dict) overall_mota = np.mean(mota_values) logging.info("Overall MOTA: %f" % overall_mota) diff --git a/crabs/tracker/track_video.py b/crabs/tracker/track_video.py index 99d1c211..e65fd07c 100644 --- a/crabs/tracker/track_video.py +++ b/crabs/tracker/track_video.py @@ -22,8 +22,6 @@ ) from crabs.tracker.utils.tracking import prep_sort -DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - class Tracking: """ @@ -47,15 +45,13 @@ class Tracking: def __init__(self, args: argparse.Namespace) -> None: self.args = args - self.config_file = args.config_file - self.load_config_yaml() # TODO: load config from trained model (like in evaluation)? - self.video_path = args.video_path - self.video_file_root = f"{Path(self.video_path).stem}" self.trained_model_path = self.args.trained_model_path + self.device = self.args.device - self.trained_model = self.load_trained_model() + self.setup() + self.prep_outputs() self.sort_tracker = Sort( max_age=self.config["max_age"], @@ -63,42 +59,35 @@ def __init__(self, args: argparse.Namespace) -> None: iou_threshold=self.config["iou_threshold"], ) - ( - self.csv_writer, - self.csv_file, - self.tracking_output_dir, - ) = prep_csv_writer(self.args.output_dir, self.video_file_root) - - def load_config_yaml(self): + def setup(self): """ - Load yaml file that contains config parameters. + Load tracking config, trained model and input video path. """ with open(self.config_file, "r") as f: self.config = yaml.safe_load(f) - def load_trained_model(self) -> torch.nn.Module: - """ - Load the trained model. - - Returns - ------- - torch.nn.Module - """ # Get trained model - trained_model = FasterRCNN.load_from_checkpoint( + self.trained_model = FasterRCNN.load_from_checkpoint( self.trained_model_path ) - trained_model.eval() - trained_model.to(DEVICE) # Should device be a CLI? - return trained_model + self.trained_model.eval() + self.trained_model.to(self.device) - def load_video(self) -> None: - """ - Load the input video, and prepare the output video if required. - """ + # Load the input video self.video = cv2.VideoCapture(self.video_path) if not self.video.isOpened(): raise Exception("Error opening video file") + self.video_file_root = f"{Path(self.video_path).stem}" + + def prep_outputs(self): + """ + Prepare csv writer and if required, video writer. + """ + ( + self.csv_writer, + self.csv_file, + self.tracking_output_dir, + ) = prep_csv_writer(self.args.output_dir, self.video_file_root) if self.args.save_video: frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH)) @@ -107,7 +96,6 @@ def load_video(self) -> None: self.video_output = prep_video_writer( self.tracking_output_dir, - self.video_file_root, frame_width, frame_height, cap_fps, @@ -135,7 +123,7 @@ def get_prediction(self, frame: np.ndarray) -> torch.Tensor: transforms.ToDtype(torch.float32, scale=True), ] ) - img = transform(frame).to(DEVICE) + img = transform(frame).to(self.device) img = img.unsqueeze(0) with torch.no_grad(): prediction = self.trained_model(img) @@ -156,9 +144,9 @@ def update_tracking(self, prediction: dict) -> list[list[float]]: list of tracked bounding boxes after updating the tracking system. 
""" pred_sort = prep_sort(prediction, self.config["score_threshold"]) - tracked_boxes = self.sort_tracker.update(pred_sort) - self.tracked_list.append(tracked_boxes) - return tracked_boxes + tracked_boxes_id_per_frame = self.sort_tracker.update(pred_sort) + self.tracked_bbox_id.append(tracked_boxes_id_per_frame) + return tracked_boxes_id_per_frame def run_tracking(self): """ @@ -171,24 +159,31 @@ def run_tracking(self): ) return - # In any case run inference # initialisation - frame_number = 1 - self.tracked_list = [] + frame_idx = 0 + self.tracked_bbox_id = [] # Loop through frames of the video in batches while self.video.isOpened(): # Break if beyond end frame (mostly for debugging) if ( self.args.max_frames_to_read - and frame_number > self.args.max_frames_to_read + and frame_idx + 1 > self.args.max_frames_to_read ): break + # get total n frames + total_frames = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT)) + # read frame ret, frame = self.video.read() - if not ret: - print("No frame read. Exiting...") + if not ret and (frame_idx == total_frames): + logging.info(f"All {total_frames} frames processed") + break + elif not ret: + logging.info( + f"Cannot read frame {frame_idx+1}/{total_frames}. Exiting..." + ) break # predict bounding boxes @@ -196,7 +191,7 @@ def run_tracking(self): pred_scores = prediction[0]["scores"].detach().cpu().numpy() # run tracking - tracked_boxes = self.update_tracking(prediction) + tracked_boxes_id_per_frame = self.update_tracking(prediction) save_required_output( self.video_file_root, self.args.save_frames, @@ -204,19 +199,19 @@ def run_tracking(self): self.csv_writer, self.args.save_video, self.video_output, - tracked_boxes, + tracked_boxes_id_per_frame, frame, - frame_number, + frame_idx + 1, pred_scores, ) # update frame number - frame_number += 1 + frame_idx += 1 if self.args.gt_path: evaluation = TrackerEvaluate( self.args.gt_path, - self.tracked_list, + self.tracked_bbox_id, self.config["iou_threshold"], ) evaluation.run_evaluation() @@ -247,7 +242,6 @@ def main(args) -> None: """ inference = Tracking(args) - inference.load_video() inference.run_tracking() @@ -277,7 +271,7 @@ def tracking_parse_args(args): parser.add_argument( "--output_dir", type=str, - default="crabs_track_output", + default="tracking_output", help="Directory to save the track output", # is this a csv or a video? (or both) ) parser.add_argument( @@ -305,6 +299,12 @@ def tracking_parse_args(args): action="store_true", help="Save frame to be used in correcting track labelling", ) + parser.add_argument( + "--device", + type=str, + default="cuda", + help="device for pytorch either cpu or cuda", + ) return parser.parse_args(args) diff --git a/crabs/tracker/utils/io.py b/crabs/tracker/utils/io.py index 36d0e544..b01b9750 100644 --- a/crabs/tracker/utils/io.py +++ b/crabs/tracker/utils/io.py @@ -1,5 +1,6 @@ import csv import os +from datetime import datetime from pathlib import Path import cv2 @@ -29,13 +30,13 @@ def prep_csv_writer(output_dir: str, video_file_root: str): A tuple containing the CSV writer, the CSV file object, and the tracking output directory path. 
""" - crabs_tracks_label_dir = Path(output_dir) / "crabs_tracks_label" - tracking_output_dir = crabs_tracks_label_dir / video_file_root + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + tracking_output_dir = Path(output_dir + f"_{timestamp}") / video_file_root # Create the subdirectory for the specific video file root tracking_output_dir.mkdir(parents=True, exist_ok=True) csv_file = open( - f"{str(tracking_output_dir / video_file_root)}.csv", + f"{str(tracking_output_dir)}/predicted_tracks.csv", "w", ) csv_writer = csv.writer(csv_file) @@ -59,7 +60,6 @@ def prep_csv_writer(output_dir: str, video_file_root: str): def prep_video_writer( output_dir: str, - video_file_root: str, frame_width: int, frame_height: int, cap_fps: float, @@ -87,7 +87,7 @@ def prep_video_writer( """ output_file = os.path.join( output_dir, - f"{os.path.basename(video_file_root)}_output_video.mp4", + "tracked_video.mp4", ) output_codec = cv2.VideoWriter_fourcc("m", "p", "4", "v") video_output = cv2.VideoWriter( diff --git a/crabs/tracker/utils/tracking.py b/crabs/tracker/utils/tracking.py index 3954ab14..b18e9045 100644 --- a/crabs/tracker/utils/tracking.py +++ b/crabs/tracker/utils/tracking.py @@ -31,7 +31,7 @@ def extract_bounding_box_info(row: list[str]) -> Dict[str, Any]: height = region_shape_attributes["height"] track_id = region_attributes["track"] - frame_number = int(filename.split("_")[-1].split(".")[0]) - 1 + frame_number = int(filename.split("_")[-1].split(".")[0]) return { "frame_number": frame_number, "x": x, diff --git a/tests/data/gt_test.csv b/tests/data/gt_test.csv index bec6d39f..1c130e07 100644 --- a/tests/data/gt_test.csv +++ b/tests/data/gt_test.csv @@ -1,4 +1,4 @@ filename,file_size,file_attributes,region_count,region_id,region_shape_attributes,region_attributes -frame_00000001.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":2894.860594987354,""y"":975.8516839863181,""width"":51,""height"":41}","{""track"":""2.0""}" -frame_00000001.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":940.6088870891139,""y"":1192.6369631796642,""width"":49,""height"":38}","{""track"":""1.0""}" -frame_00000002.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":940.6088870891139,""y"":1192.6369631796642,""width"":49,""height"":38}","{""track"":""2.0""}" +frame_00000011.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":2894.860594987354,""y"":975.8516839863181,""width"":51,""height"":41}","{""track"":""2.0""}" +frame_00000011.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":940.6088870891139,""y"":1192.6369631796642,""width"":49,""height"":38}","{""track"":""1.0""}" +frame_00000021.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":940.6088870891139,""y"":1192.6369631796642,""width"":49,""height"":38}","{""track"":""2.0""}" diff --git a/tests/test_unit/test_evaluate_tracker.py b/tests/test_unit/test_evaluate_tracker.py new file mode 100644 index 00000000..9ccd645c --- /dev/null +++ b/tests/test_unit/test_evaluate_tracker.py @@ -0,0 +1,452 @@ +from pathlib import Path + +import numpy as np +import pytest + +from crabs.tracker.evaluate_tracker import TrackerEvaluate + + +@pytest.fixture +def evaluation(): + test_csv_file = Path(__file__).parents[1] / "data" / "gt_test.csv" + return TrackerEvaluate( + test_csv_file, predicted_boxes_id=[], iou_threshold=0.1 + ) + + +def test_get_ground_truth_data(evaluation): + ground_truth_dict = evaluation.get_ground_truth_data() + + assert isinstance(ground_truth_dict, dict) + assert all( + isinstance(frame_data, dict) + 
+        for frame_data in ground_truth_dict.values()
+    )
+
+    for frame_number, data in ground_truth_dict.items():
+        assert isinstance(frame_number, int)
+        assert isinstance(data["bbox"], np.ndarray)
+        assert isinstance(data["id"], np.ndarray)
+        assert data["bbox"].shape[1] == 4
+
+
+def test_ground_truth_data_from_csv(evaluation):
+    expected_data = {
+        11: {
+            "bbox": np.array(
+                [
+                    [2894.8606, 975.8517, 2945.8606, 1016.8517],
+                    [940.6089, 1192.637, 989.6089, 1230.637],
+                ],
+                dtype=np.float32,
+            ),
+            "id": np.array([2.0, 1.0], dtype=np.float32),
+        },
+        21: {
+            "bbox": np.array(
+                [[940.6089, 1192.637, 989.6089, 1230.637]], dtype=np.float32
+            ),
+            "id": np.array([2.0], dtype=np.float32),
+        },
+    }
+
+    ground_truth_dict = evaluation.get_ground_truth_data()
+
+    for frame_number, expected_frame_data in expected_data.items():
+        assert frame_number in ground_truth_dict
+
+        assert len(ground_truth_dict[frame_number]["bbox"]) == len(
+            expected_frame_data["bbox"]
+        )
+        for bbox, expected_bbox in zip(
+            ground_truth_dict[frame_number]["bbox"],
+            expected_frame_data["bbox"],
+        ):
+            assert np.allclose(
+                bbox, expected_bbox
+            ), f"Frame {frame_number}, bbox mismatch"
+
+        assert np.array_equal(
+            ground_truth_dict[frame_number]["id"], expected_frame_data["id"]
+        ), f"Frame {frame_number}, id mismatch"
+
+
+@pytest.mark.parametrize(
+    "prev_frame_id_map, current_frame_id_map, expected_output",
+    [
+        (None, {1: 11, 2: 12, 3: 13, 4: 14}, 0),  # no previous frame
+        # ----- a crab (GT=3) that continues to exist ---------
+        (
+            {1: 11, 2: 12, 3: 13, 4: 14},
+            {1: 11, 2: 12, 3: 13, 4: 14},
+            0,
+        ),  # correct
+        (
+            {1: 11, 2: 12, 3: 13, 4: 14},
+            {1: 11, 2: 12, 3: np.nan, 4: 14},
+            0,
+        ),  # crab is missed detection in current frame
+        (
+            {1: 11, 2: 12, 3: np.nan, 4: 14},
+            {1: 11, 2: 12, 3: 13, 4: 14},
+            0,
+        ),  # crab is missed detection in previous frame
+        (
+            {1: 11, 2: 12, 3: 13, 4: 14},
+            {1: 11, 2: 12, 3: 15, 4: 14},
+            1,
+        ),  # crab is re-IDed in current frame
+        (
+            {1: 11, 2: 12, 3: 13, 4: 14},
+            {1: 11, 2: 12, 3: 14},
+            1,
+        ),  # crab swaps ID with a disappearing crab
+        (
+            {1: 11, 2: 12, 3: 13},
+            {1: 11, 2: 12, 4: 13},
+            1,
+        ),  # disappear crab swaps ID with an appearing crab
+        (
+            {1: 11, 2: 12, 3: 13},
+            {1: 11, 2: 12, 3: 99, 4: 13},
+            2,
+        ),  # crab swaps ID with an appearing crab
+        (
+            {1: 11, 2: 12, 3: 13, 4: 14},
+            {1: 11, 2: 12, 3: 14, 4: 13},
+            2,
+        ),  # crab swaps ID with another crab that continues to exist
+        # ----- a crab (GT=4) disappears ---------
+        (
+            {1: 11, 2: 12, 3: 13, 4: 14},
+            {1: 11, 2: 12, 3: 13},
+            0,
+        ),  # correct
+        (
+            {1: 11, 2: 12, 3: 13, 4: 14},
+            {1: 11, 2: 12, 3: 14},
+            1,
+        ),  # crab disappears and another pre-existing one takes its ID
+        (
+            {1: 11, 2: 12, 3: 13, 4: 14},
+            {1: 11, 2: 12, 3: 13, 5: 14},
+            1,
+        ),  # crab disappears and an appearing one takes its ID
+        (
+            {1: 11, 2: 12, 3: 13, 4: np.nan},
+            {1: 11, 2: 12, 3: 13},
+            0,
+        ),  # crab disappears but was missed detection in frame f-1
+        (
+            {1: 11, 2: 12, 3: 13, 4: np.nan},
+            {1: 11, 2: 12, 3: 13, 5: np.nan},
+            0,
+        ),  # crab disappears but was missed detection in frame f-1, with a new missed crab in frame f
+        (
+            {1: 11, 2: 12, 3: 13, 4: np.nan},
+            {1: 11, 2: 12, 3: np.nan},
+            0,
+        ),  # crab disappears but was missed detection in frame f-1, and existing crab was missed in frame f
+        # ----- a crab (GT=4) appears ---------
+        (
+            {1: 11, 2: 12, 3: 13},
+            {1: 11, 2: 12, 3: 13, 4: 14},
+            0,
+        ),  # correct
+        (
+            {1: 11, 2: 12, 3: 14},
+            {1: 11, 2: 12, 3: 13, 4: 14},
+            2,
+        ),  # crab that appears gets ID of a pre-existing crab
+        (
+            {1: 11, 2: 12, 3: 13},
+            {1: 11, 2: 12, 4: 13},
+            1,
+        ),  # crab that appears gets ID of a crab that disappears
+        (
+            {1: 11, 2: 12, 3: 13},
+            {1: 11, 2: 12, 3: 13, 4: np.nan},
+            0,
+        ),  # missed detection in current frame
+        (
+            {1: 11, 2: 12, 3: 13, 5: np.nan},
+            {1: 11, 2: 12, 3: 13, 4: np.nan},
+            0,
+        ),  # crab that appears is missed detection in current frame, and another missed detection in previous frame disappears
+        (
+            {1: 11, 2: 12, 3: np.nan},
+            {1: 11, 2: 12, 3: 13, 4: np.nan},
+            0,
+        ),  # crab that appears is missed detection in current frame, and a pre-existing crab is missed detection in previous frame
+    ],
+)
+def test_count_identity_switches(
+    evaluation, prev_frame_id_map, current_frame_id_map, expected_output
+):
+    assert (
+        evaluation.count_identity_switches(
+            prev_frame_id_map, current_frame_id_map
+        )
+        == expected_output
+    )
+
+
+@pytest.mark.parametrize(
+    "box1, box2, expected_iou",
+    [
+        ([0, 0, 10, 10], [5, 5, 12, 12], 0.25),
+        ([0, 0, 10, 10], [0, 0, 10, 10], 1.0),
+        ([0, 0, 10, 10], [20, 20, 30, 30], 0.0),
+        ([0, 0, 10, 10], [5, 15, 15, 25], 0.0),
+    ],
+)
+def test_calculate_iou(box1, box2, expected_iou, evaluation):
+    box1 = np.array(box1)
+    box2 = np.array(box2)
+
+    iou = evaluation.calculate_iou(box1, box2)
+
+    # Check if IoU matches expected value
+    assert iou == pytest.approx(expected_iou, abs=1e-2)
+
+
+@pytest.mark.parametrize(
+    "gt_data, pred_data, prev_frame_id_map, expected_mota",
+    [
+        # perfect tracking
+        (
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([1, 2, 3]),
+            },
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([11, 12, 13]),
+            },
+            {1: 11, 2: 12, 3: 13},
+            1.0,
+        ),
+        (
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([1, 2, 3]),
+            },
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([11, 12, 13]),
+            },
+            {1: 11, 12: 2, 3: np.nan},
+            1.0,
+        ),
+        # ID switch
+        (
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([1, 2, 3]),
+            },
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([11, 12, 14]),
+            },
+            {1: 11, 2: 12, 3: 13},
+            2 / 3,
+        ),
+        # missed detection
+        (
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([1, 2, 4]),
+            },
+            {
+                "bbox": np.array(
+                    [[10.0, 10.0, 20.0, 20.0], [30.0, 30.0, 40.0, 40.0]]
+                ),
+                "id": np.array([11, 12]),
+            },
+            {1: 11, 2: 12, 3: 13},
+            2 / 3,
+        ),
+        # false positive
+        (
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([1, 2, 3]),
+            },
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                        [70.0, 70.0, 80.0, 80.0],
+                    ]
+                ),
+                "id": np.array([11, 12, 13, 14]),
+            },
+            {1: 11, 2: 12, 3: 13},
+            2 / 3,
+        ),
+        # low IOU and ID switch
+        (
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([1, 2, 3]),
+            },
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 30.0, 30.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([11, 12, 14]),
+            },
+            {1: 11, 2: 12, 3: 13},
+            0,
+        ),
+        # low IOU and ID switch on same box
+        (
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([1, 2, 3]),
+            },
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 30.0, 30.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([11, 14, 13]),
+            },
+            {1: 11, 2: 12, 3: 13},
+            1 / 3,
+        ),
+        # current tracked id = prev tracked id, but prev_gt_id != current gt id
+        (
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([1, 2, 4]),
+            },
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([11, 12, 13]),
+            },
+            {1: 11, 2: 12, 3: 13},
+            2 / 3,
+        ),
+        # ID swapped
+        (
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([1, 2, 3]),
+            },
+            {
+                "bbox": np.array(
+                    [
+                        [10.0, 10.0, 20.0, 20.0],
+                        [30.0, 30.0, 40.0, 40.0],
+                        [50.0, 50.0, 60.0, 60.0],
+                    ]
+                ),
+                "id": np.array([11, 13, 12]),
+            },
+            {1: 11, 2: 12, 3: 13},
+            1 / 3,
+        ),
+    ],
+)
+def test_evaluate_mota(
+    gt_data,
+    pred_data,
+    prev_frame_id_map,
+    expected_mota,
+    evaluation,
+):
+    mota, _ = evaluation.evaluate_mota(
+        gt_data,
+        pred_data,
+        0.1,
+        prev_frame_id_map,
+    )
+    assert mota == pytest.approx(expected_mota)
diff --git a/tests/test_unit/test_track_video.py b/tests/test_unit/test_track_video.py
new file mode 100644
index 00000000..6833242b
--- /dev/null
+++ b/tests/test_unit/test_track_video.py
@@ -0,0 +1,60 @@
+import tempfile
+from argparse import Namespace
+from pathlib import Path
+from unittest.mock import MagicMock, mock_open, patch
+
+import pytest
+
+from crabs.tracker.track_video import Tracking
+
+
+@pytest.fixture
+def mock_args():
+    temp_dir = tempfile.mkdtemp()
+
+    return Namespace(
+        config_file="/path/to/config.yaml",
+        video_path="/path/to/video.mp4",
+        trained_model_path="/path/to/model.ckpt",
+        output_dir=temp_dir,
+        device="cuda",
+        gt_path=None,
+        save_video=None,
+    )
+
+
+@patch(
+    "builtins.open",
+    new_callable=mock_open,
+    read_data="max_age: 10\nmin_hits: 3\niou_threshold: 0.1",
+)
+@patch("yaml.safe_load")
+@patch("cv2.VideoCapture")
+@patch("crabs.tracker.track_video.FasterRCNN.load_from_checkpoint")
+@patch("crabs.tracker.track_video.Sort")
+def test_tracking_setup(
+    mock_sort,
+    mock_load_from_checkpoint,
+    mock_videocapture,
+    mock_yaml_load,
+    mock_open,
+    mock_args,
+):
+    mock_yaml_load.return_value = {
+        "max_age": 10,
+        "min_hits": 3,
+        "iou_threshold": 0.1,
+    }
+
+    mock_model = MagicMock()
+    mock_load_from_checkpoint.return_value = mock_model
+
+    mock_video_capture = MagicMock()
+    mock_video_capture.isOpened.return_value = True
+    mock_videocapture.return_value = mock_video_capture
+
+    tracker = Tracking(mock_args)
+
+    assert tracker.args.output_dir == mock_args.output_dir
+
+    Path(mock_args.output_dir).rmdir()
diff --git a/tests/test_unit/test_tracking_evaluation.py b/tests/test_unit/test_tracking_evaluation.py
deleted file mode 100644
index b2ee039f..00000000
--- a/tests/test_unit/test_tracking_evaluation.py
+++ /dev/null
@@ -1,232 +0,0 @@
-from pathlib import Path
-
-import numpy as np
-import pytest
-
-from crabs.tracker.evaluate_tracker import TrackerEvaluate
-
-
-@pytest.fixture
-def evaluation():
-    test_csv_file = Path(__file__).parents[1] / "data" / "gt_test.csv"
-    return TrackerEvaluate(test_csv_file, tracked_list=[], iou_threshold=0.1)
-
-
-def test_get_ground_truth_data(evaluation):
-    gt_data = evaluation.get_ground_truth_data()
-
-    assert len(gt_data) == 2
-
-    for i, frame_data in enumerate(gt_data):
-        for j, detection_data in enumerate(frame_data):
-            assert detection_data.shape == (
-                5,
-            ), f"Detection data shape mismatch for frame {i}"
-
-    expected_ids = [2.0, 1.0]
-    for i, frame_data in enumerate(gt_data):
-        for j, detection_data in enumerate(frame_data):
-            assert (
-                detection_data[4] == expected_ids[j]
-            ), f"Failed for frame {i}, detection {j}"
-
-
-@pytest.fixture
-def ground_truth_data():
-    return [
-        {
-            "frame_number": 0,
-            "x": 10,
-            "y": 20,
-            "width": 30,
-            "height": 40,
-            "id": 1,
-        },
-        {
-            "frame_number": 0,
-            "x": 50,
-            "y": 60,
-            "width": 70,
-            "height": 80,
-            "id": 2,
-        },
-        {
-            "frame_number": 1,
-            "x": 100,
-            "y": 200,
-            "width": 300,
-            "height": 400,
-            "id": 1,
-        },
-    ]
-
-
-@pytest.fixture
-def gt_boxes_list():
-    return [np.array([]) for _ in range(2)]  # Two frames
-
-
-def test_create_gt_list(ground_truth_data, gt_boxes_list, evaluation):
-    created_gt = evaluation.create_gt_list(ground_truth_data, gt_boxes_list)
-
-    assert isinstance(created_gt, list)
-
-    for item in created_gt:
-        assert isinstance(item, np.ndarray)
-
-    assert len(created_gt) == len(gt_boxes_list)
-
-    for i, array in enumerate(created_gt):
-        for box in array:
-            assert box.shape == (5,)
-
-    i = 0
-    for gt_created in created_gt:
-        for frame_number in range(len(gt_created)):
-            gt_data = ground_truth_data[i]
-            gt_boxes = gt_created[frame_number]
-
-            assert gt_boxes[0] == gt_data["x"]
-            assert gt_boxes[1] == gt_data["y"]
-            assert gt_boxes[2] == gt_data["x"] + gt_data["width"]
-            assert gt_boxes[3] == gt_data["y"] + gt_data["height"]
-            assert gt_boxes[4] == gt_data["id"]
-            i += 1
-
-
-def test_create_gt_list_invalid_data(ground_truth_data, evaluation):
-    invalid_data = ground_truth_data[:]
-
-    del invalid_data[0]["x"]
-    with pytest.raises(KeyError):
-        evaluation.create_gt_list(
-            invalid_data, [np.array([]) for _ in range(2)]
-        )
-
-
-def test_create_gt_list_insufficient_gt_boxes_list(
-    ground_truth_data, evaluation
-):
-    with pytest.raises(IndexError):
-        evaluation.create_gt_list(ground_truth_data, [np.array([])])
-
-
-@pytest.mark.parametrize(
-    "prev_frame_id, current_frame_id, expected_output",
-    [
-        (None, [[6, 5, 4, 3, 2, 1]], 0),
-        (
-            [[6, 5, 4, 3, 2, 1]],
-            [[6, 5, 4, 3, 2, 1]],
-            0,
-        ),  # no identity switches
-        ([[5, 6, 4, 3, 1, 2]], [[6, 5, 4, 3, 2, 1]], 0),
-        ([[6, 5, 4, 3, 2, 1]], [[6, 5, 4, 2, 1]], 1),
-        ([[6, 5, 4, 2, 1]], [[6, 5, 4, 2, 1, 7]], 1),
-        ([[6, 5, 4, 2, 1, 7]], [[6, 5, 4, 2, 7, 8]], 2),
-        ([[6, 5, 4, 2, 7, 8]], [[6, 5, 4, 2, 7, 8, 3]], 1),
-    ],
-)
-def test_count_identity_switches(
-    evaluation, prev_frame_id, current_frame_id, expected_output
-):
-    assert (
-        evaluation.count_identity_switches(prev_frame_id, current_frame_id)
-        == expected_output
-    )
-
-
-@pytest.mark.parametrize(
-    "box1, box2, expected_iou",
-    [
-        ([0, 0, 10, 10], [5, 5, 12, 12], 0.25),
-        ([0, 0, 10, 10], [0, 0, 10, 10], 1.0),
-        ([0, 0, 10, 10], [20, 20, 30, 30], 0.0),
-        ([0, 0, 10, 10], [5, 15, 15, 25], 0.0),
-    ],
-)
-def test_calculate_iou(box1, box2, expected_iou, evaluation):
-    box1 = np.array(box1)
-    box2 = np.array(box2)
-
-    iou = evaluation.calculate_iou(box1, box2)
-
-    # Check if IoU matches expected value
-    assert iou == pytest.approx(expected_iou, abs=1e-2)
-
-
-@pytest.fixture
-def gt_boxes():
-    return np.array(
-        [
-            [10.0, 10.0, 20.0, 20.0, 1.0],
-            [30.0, 30.0, 40.0, 40.0, 2.0],
-            [50.0, 50.0, 60.0, 60.0, 3.0],
-        ]
-    )
-
-
-@pytest.fixture
-def tracked_boxes():
-    return np.array(
-        [
-            [10.0, 10.0, 20.0, 20.0, 1.0],
-            [30.0, 30.0, 40.0, 40.0, 2.0],
-            [50.0, 50.0, 60.0, 60.0, 3.0],
-        ]
-    )
-
-
-@pytest.fixture
-def prev_frame_ids():
-    return [[1.0, 2.0, 3.0]]
-
-
-def test_perfect_tracking(gt_boxes, tracked_boxes, prev_frame_ids, evaluation):
-    mota = evaluation.evaluate_mota(
-        gt_boxes,
-        tracked_boxes,
-        iou_threshold=0.1,
-        prev_frame_ids=prev_frame_ids,
-    )
-    assert mota == pytest.approx(1.0)
-
-
-def test_missed_detections(
-    gt_boxes, tracked_boxes, prev_frame_ids, evaluation
-):
-    # Remove one ground truth box to simulate a missed detection
-    gt_boxes = np.delete(gt_boxes, 0, axis=0)
-    mota = evaluation.evaluate_mota(
-        gt_boxes,
-        tracked_boxes,
-        iou_threshold=0.1,
-        prev_frame_ids=prev_frame_ids,
-    )
-    assert mota < 1.0
-
-
-def test_false_positives(gt_boxes, tracked_boxes, prev_frame_ids, evaluation):
-    # Add one extra tracked box to simulate a false positive
-    tracked_boxes = np.vstack([tracked_boxes, [70, 70, 80, 80, 4]])
-    mota = evaluation.evaluate_mota(
-        gt_boxes,
-        tracked_boxes,
-        iou_threshold=0.1,
-        prev_frame_ids=prev_frame_ids,
-    )
-    assert mota < 1.0
-
-
-def test_identity_switches(
-    gt_boxes, tracked_boxes, prev_frame_ids, evaluation
-):
-    # Change ID of one tracked box to simulate an identity switch
-    tracked_boxes[0][-1] = 5
-    mota = evaluation.evaluate_mota(
-        gt_boxes,
-        tracked_boxes,
-        iou_threshold=0.5,
-        prev_frame_ids=prev_frame_ids,
-    )
-    assert mota < 1.0
diff --git a/tests/test_unit/test_tracking_utils.py b/tests/test_unit/test_tracking_utils.py
index 37f581b0..3550f134 100644
--- a/tests/test_unit/test_tracking_utils.py
+++ b/tests/test_unit/test_tracking_utils.py
@@ -24,7 +24,7 @@ def test_extract_bounding_box_info():
     result = extract_bounding_box_info(csv_row)
 
     expected_result = {
-        "frame_number": 0,
+        "frame_number": 1,
         "x": 2894.860594987354,
         "y": 975.8516839863181,
         "width": 51,