From e8a2bca8b9f290fab113f92b63f4de3a27ad5b3c Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 27 Jun 2024 11:08:31 +0100 Subject: [PATCH] Evaluate script: review of command-line arguments (#172) * move eval params to config (WIP) * follow train CLI: add debugger options, add experiment name, use score_threshold from config * fix prettier * edit CLI defaults and get dataset params from ckpt if not defined (WIP) * fix ninja comma * Add sections to config * Rename to evaluate utils * Match current train script and add slurm logs as artifacts * Fix evaluate_utils * Use config from ckpt if not passed. Use dataset, annot files and seed from ckpt if not passed. * Clarify CLI help (hopefully) * Add score threshold for visualisation as CLI argument * Small fix to config yaml * Clean up * Fix save frames and add output_dir * Fix tests * Move get_ functions to evaluate utils * Replace assert by try-except --- .../config/faster_rcnn.yaml | 32 ++- crabs/detection_tracking/evaluate.py | 97 ------- crabs/detection_tracking/evaluate_model.py | 217 ++++++++++------ crabs/detection_tracking/evaluate_utils.py | 240 ++++++++++++++++++ crabs/detection_tracking/models.py | 4 +- crabs/detection_tracking/visualization.py | 16 +- tests/test_unit/test_evaluate.py | 2 +- tests/test_unit/test_visualization.py | 19 +- 8 files changed, 434 insertions(+), 193 deletions(-) delete mode 100644 crabs/detection_tracking/evaluate.py create mode 100644 crabs/detection_tracking/evaluate_utils.py diff --git a/crabs/detection_tracking/config/faster_rcnn.yaml b/crabs/detection_tracking/config/faster_rcnn.yaml index ac6b4c51..cbdfaad3 100644 --- a/crabs/detection_tracking/config/faster_rcnn.yaml +++ b/crabs/detection_tracking/config/faster_rcnn.yaml @@ -1,12 +1,22 @@ +# Dataset +#------------- +train_fraction: 0.8 +val_over_test_fraction: 0.5 +num_workers: 4 + +# ------------------- +# Model architecture +# ------------------- +num_classes: 2 + +# ------------------------------- +# Training & validation parameters +# ------------------------------- n_epochs: 250 learning_rate: 0.00005 wdecay: 0.00005 batch_size_train: 4 -batch_size_test: 4 batch_size_val: 4 -num_classes: 2 -train_fraction: 0.8 -val_over_test_fraction: 0.5 checkpoint_saving: every_n_epochs: 50 keep_last_n_ckpts: 5 @@ -17,8 +27,16 @@ checkpoint_saving: # if all, all checkpoints for every epoch are added as artifacts during training, # if False, no checkpoints are added as artifacts. save_last: True + +# ----------------------- +# Evaluation parameters +# ----------------------- iou_threshold: 0.1 -num_workers: 4 +batch_size_test: 4 + +# ------------------- +# Data augmentation +# ------------------- transform_brightness: 0.5 transform_hue: 0.3 gaussian_blur_params: @@ -28,6 +46,10 @@ gaussian_blur_params: sigma: - 0.1 - 5.0 + +# ---------------------------- +# Hyperparameter optimisation +# ----------------------------- # when we run Optuna, the n_trials and n_epochs above will be overwritten by the parameters set by Optuna optuna: # Parameters for hyperparameter optimisation with Optuna: diff --git a/crabs/detection_tracking/evaluate.py b/crabs/detection_tracking/evaluate.py deleted file mode 100644 index af82e46e..00000000 --- a/crabs/detection_tracking/evaluate.py +++ /dev/null @@ -1,97 +0,0 @@ -import logging - -import torchvision - -logging.basicConfig(level=logging.INFO) - - -def compute_precision_recall(class_stats) -> tuple[float, float, dict]: - """ - Compute precision and recall. - - Parameters - ---------- - class_stats : dict - Statistics or information about different classes. - - Returns - ---------- - Tuple[float, float] - precision and recall - """ - for _, stats in class_stats.items(): - precision = stats["tp"] / max(stats["tp"] + stats["fp"], 1) - recall = stats["tp"] / max(stats["tp"] + stats["fn"], 1) - - return precision, recall, class_stats - - -def compute_confusion_matrix_elements( - targets, detections, ious_threshold -) -> tuple[float, float, dict]: - """ - Compute metrics (true positive, false positive, false negative) for object detection. - - Parameters - ---------- - targets : list - Ground truth annotations. - detections : list - Detected objects. - ious_threshold : float - The threshold value for the intersection-over-union (IOU). - Only detections whose IOU relative to the ground truth is above the - threshold are true positive candidates. - class_stats : dict - Statistics or information about different classes. - - Returns - ---------- - Tuple[float, float] - precision and recall - """ - class_stats = {"crab": {"tp": 0, "fp": 0, "fn": 0}} - for target, detection in zip(targets, detections): - gt_boxes = target["boxes"] - pred_boxes = detection["boxes"] - pred_labels = detection["labels"] - - ious = torchvision.ops.box_iou(pred_boxes, gt_boxes) - - max_ious, max_indices = ious.max(dim=1) - - # Identify true positives, false positives, and false negatives - for idx, iou in enumerate(max_ious): - if iou.item() > ious_threshold: - pred_class_idx = pred_labels[idx].item() - true_label = target["labels"][max_indices[idx]].item() - - if pred_class_idx == true_label: - class_stats["crab"]["tp"] += 1 - else: - class_stats["crab"]["fp"] += 1 - else: - class_stats["crab"]["fp"] += 1 - - for target_box_index, target_box in enumerate(gt_boxes): - found_match = False - for idx, iou in enumerate(max_ious): - if ( - iou.item() - > ious_threshold # we need this condition because the max overlap is not necessarily above the threshold - and max_indices[idx] - == target_box_index # the matching index is the index of the GT box with which it has max overlap - ): - # There's an IoU match and the matched index corresponds to the current target_box_index - found_match = True - break # Exit loop, a match was found - - if not found_match: - # print(found_match) - class_stats["crab"][ - "fn" - ] += 1 # Ground truth box has no corresponding detection - - precision, recall, class_stats = compute_precision_recall(class_stats) - - return precision, recall, class_stats diff --git a/crabs/detection_tracking/evaluate_model.py b/crabs/detection_tracking/evaluate_model.py index cc2132ae..069c9ee9 100644 --- a/crabs/detection_tracking/evaluate_model.py +++ b/crabs/detection_tracking/evaluate_model.py @@ -1,17 +1,22 @@ import argparse +import logging +import os import sys -from pathlib import Path import lightning -import yaml # type: ignore -from lightning.pytorch.loggers import MLFlowLogger +import torch from crabs.detection_tracking.datamodules import CrabsDataModule from crabs.detection_tracking.detection_utils import ( - prep_annotation_files, - prep_img_directories, set_mlflow_run_name, setup_mlflow_logger, + slurm_logs_as_artifacts, +) +from crabs.detection_tracking.evaluate_utils import ( + get_annotation_files_from_ckpt, + get_cli_arg_from_ckpt, + get_config_from_ckpt, + get_img_directories_from_ckpt, ) from crabs.detection_tracking.models import FasterRCNN from crabs.detection_tracking.visualization import save_images_with_boxes @@ -19,80 +24,80 @@ class DetectorEvaluation: """ - A class for evaluating an object detector using trained model. + A class for evaluating an object detector. Parameters ---------- args : argparse Command-line arguments containing configuration settings. - config_file : str - Path to the directory containing configuration file. - images_dirs : list[str] - list of paths to the image directories of the datasets. - annotation_files : list[str] - list of filenames for the COCO annotations. - score_threshold : float - The score threshold for confidence detection. - ious_threshold : float - The ious threshold for detection bounding boxes. - evaluate_dataloader: - The DataLoader for the test dataset. + """ - def __init__( - self, - args: argparse.Namespace, - ) -> None: + def __init__(self, args: argparse.Namespace) -> None: + # CLI inputs self.args = args + + # trained model + self.trained_model_path = args.trained_model_path + + # config: retreieve from ckpt if not passed as CLI argument self.config_file = args.config_file - self.images_dirs = prep_img_directories(args.dataset_dirs) - self.annotation_files = prep_annotation_files( - args.annotation_files, args.dataset_dirs + self.config = get_config_from_ckpt( + config_file=self.config_file, + trained_model_path=self.trained_model_path, ) - self.seed_n = args.seed_n - self.ious_threshold = args.ious_threshold - self.score_threshold = args.score_threshold + + # dataset: retrieve from ckpt if no CLI arguments are passed + self.images_dirs = get_img_directories_from_ckpt( + args=self.args, trained_model_path=self.trained_model_path + ) + self.annotation_files = get_annotation_files_from_ckpt( + args=self.args, trained_model_path=self.trained_model_path + ) + self.seed_n = get_cli_arg_from_ckpt( + args=self.args, + cli_arg_str="seed_n", + trained_model_path=self.trained_model_path, + ) + + # Hardware + self.accelerator = args.accelerator + + # MLflow + self.experiment_name = args.experiment_name self.mlflow_folder = args.mlflow_folder - self.load_config_yaml() - def load_config_yaml(self): - with open(self.config_file, "r") as f: - self.config = yaml.safe_load(f) + # Debugging + self.fast_dev_run = args.fast_dev_run + self.limit_test_batches = args.limit_test_batches - def set_run_name(self): - self.run_name = set_mlflow_run_name() + logging.info("Dataset") + logging.info(f"Images directories: {self.images_dirs}") + logging.info(f"Annotation files: {self.annotation_files}") + logging.info(f"Seed: {self.seed_n}") - def setup_logger(self) -> MLFlowLogger: + def setup_trainer(self): """ - Setup MLflow logger for testing. - - Includes logging metadata about the job (CLI arguments and SLURM job IDs). + Setup trainer object with logging for testing. """ + # Assign run name - self.set_run_name() + self.run_name = set_mlflow_run_name() # Setup logger mlf_logger = setup_mlflow_logger( - experiment_name="Sep2023_evaluation", + experiment_name=self.experiment_name, run_name=self.run_name, mlflow_folder=self.mlflow_folder, cli_args=self.args, ) - return mlf_logger - - def setup_trainer(self): - """ - Setup trainer object with logging for testing. - """ - - # Get MLflow logger - mlf_logger = self.setup_logger() - # Return trainer linked to logger return lightning.Trainer( - accelerator=self.args.accelerator, + accelerator=self.accelerator, logger=mlf_logger, + fast_dev_run=self.fast_dev_run, + limit_test_batches=self.limit_test_batches, ) def evaluate_model(self) -> None: @@ -109,7 +114,7 @@ def evaluate_model(self) -> None: # Get trained model trained_model = FasterRCNN.load_from_checkpoint( - self.args.checkpoint_path + self.trained_model_path, config=self.config ) # Run testing @@ -122,15 +127,21 @@ def evaluate_model(self) -> None: # Save images if required if self.args.save_frames: save_images_with_boxes( - data_module.test_dataloader(), - trained_model, - self.score_threshold, + test_dataloader=data_module.test_dataloader(), + trained_model=trained_model, + output_dir=self.args.frames_output_dir, + score_threshold=self.args.frames_score_threshold, ) + # if this is a slurm job: add slurm logs as artifacts + slurm_job_id = os.environ.get("SLURM_JOB_ID") + if slurm_job_id: + slurm_logs_as_artifacts(trainer.logger, slurm_job_id) + def main(args) -> None: """ - Main function to orchestrate the testing process using Detector_Test. + Main function to orchestrate the testing process. Parameters ---------- @@ -147,36 +158,52 @@ def main(args) -> None: def evaluate_parse_args(args): parser = argparse.ArgumentParser() + parser.add_argument( + "--trained_model_path", + type=str, + required=True, + help="Location of trained model (a .ckpt file)", + ) + parser.add_argument( + "--config_file", + type=str, + default="", + help=( + "Location of YAML config to control evaluation. " + " If None is povided, the config used to train the model is used (recommended)." + ), + ) parser.add_argument( "--dataset_dirs", nargs="+", - required=True, - help="List of dataset directories", + default=[], + help=( + "List of dataset directories. " + "If none is provided (recommended), the datasets used for " + "the trained model are used." + ), ) parser.add_argument( "--annotation_files", nargs="+", default=[], help=( - "List of paths to annotation files. The full path or the filename can be provided. " + "List of paths to annotation files. " + "If none are provided (recommended), the annotations from the dataset of the trained model are used." + "The full path or the filename can be provided. " "If only filename is provided, it is assumed to be under dataset/annotations." ), ) parser.add_argument( - "--checkpoint_path", - type=str, - required=True, - help="Location of trained model", - ) - parser.add_argument( - "--config_file", - type=str, - default=str(Path(__file__).parent / "config" / "faster_rcnn.yaml"), + "--seed_n", + type=int, help=( - "Location of YAML config to control training. " - "Default: crabs-exploration/crabs/detection_tracking/config/faster_rcnn.yaml" + "Seed for dataset splits. " + "If none is provided (recommended), the seed from the dataset of " + "the trained model is used." ), ) + parser.add_argument( "--accelerator", type=str, @@ -188,22 +215,28 @@ def evaluate_parse_args(args): ), ) parser.add_argument( - "--score_threshold", - type=float, - default=0.1, - help="Threshold for confidence score. Default: 0.1", + "--experiment_name", + type=str, + default="Sept2023_evaluation", + help=( + "Name of the experiment in MLflow, under which the current run will be logged. " + "For example, the name of the dataset could be used, to group runs using the same data. " + "Default: Sept2023_evaluation" + ), ) parser.add_argument( - "--ious_threshold", - type=float, - default=0.1, - help="Threshold for IOU. Default: 0.1", + "--fast_dev_run", + action="store_true", + help="Debugging option to run training for one batch and one epoch", ) parser.add_argument( - "--seed_n", - type=int, - default=42, - help="Seed for dataset splits. Default: 42", + "--limit_test_batches", + type=float, + default=1.0, + help=( + "Debugging option to run training on a fraction of the training set." + "Default: 1.0 (all the training set)" + ), ) parser.add_argument( "--mlflow_folder", @@ -216,10 +249,30 @@ def evaluate_parse_args(args): action="store_true", help=("Save predicted frames with bounding boxes."), ) + parser.add_argument( + "--frames_score_threshold", + type=float, + default=0.5, + help=( + "Score threshold for visualising detections on output frames. Default: 0.5" + ), + ) + parser.add_argument( + "--frames_output_dir", + type=str, + default="", + help=( + "Output directory for the exported frames. " + "By default, the frames are saved in a `results_ folder " + "under the current working directory." + ), + ) return parser.parse_args(args) def app_wrapper(): + torch.set_float32_matmul_precision("medium") + eval_args = evaluate_parse_args(sys.argv[1:]) main(eval_args) diff --git a/crabs/detection_tracking/evaluate_utils.py b/crabs/detection_tracking/evaluate_utils.py new file mode 100644 index 00000000..307c1af0 --- /dev/null +++ b/crabs/detection_tracking/evaluate_utils.py @@ -0,0 +1,240 @@ +import argparse +import ast +import logging +import sys +from pathlib import Path + +import torchvision +import yaml # type: ignore + +from crabs.detection_tracking.detection_utils import ( + prep_annotation_files, + prep_img_directories, +) + +logging.basicConfig(level=logging.INFO) + + +def compute_precision_recall(class_stats) -> tuple[float, float, dict]: + """ + Compute precision and recall. + + Parameters + ---------- + class_stats : dict + Statistics or information about different classes. + + Returns + ---------- + Tuple[float, float] + precision and recall + """ + for _, stats in class_stats.items(): + precision = stats["tp"] / max(stats["tp"] + stats["fp"], 1) + recall = stats["tp"] / max(stats["tp"] + stats["fn"], 1) + + return precision, recall, class_stats + + +def compute_confusion_matrix_elements( + targets, detections, ious_threshold +) -> tuple[float, float, dict]: + """ + Compute metrics (true positive, false positive, false negative) for object detection. + + Parameters + ---------- + targets : list + Ground truth annotations. + detections : list + Detected objects. + ious_threshold : float + The threshold value for the intersection-over-union (IOU). + Only detections whose IOU relative to the ground truth is above the + threshold are true positive candidates. + class_stats : dict + Statistics or information about different classes. + + Returns + ---------- + Tuple[float, float] + precision and recall + """ + class_stats = {"crab": {"tp": 0, "fp": 0, "fn": 0}} + for target, detection in zip(targets, detections): + gt_boxes = target["boxes"] + pred_boxes = detection["boxes"] + pred_labels = detection["labels"] + + ious = torchvision.ops.box_iou(pred_boxes, gt_boxes) + + max_ious, max_indices = ious.max(dim=1) + + # Identify true positives, false positives, and false negatives + for idx, iou in enumerate(max_ious): + if iou.item() > ious_threshold: + pred_class_idx = pred_labels[idx].item() + true_label = target["labels"][max_indices[idx]].item() + + if pred_class_idx == true_label: + class_stats["crab"]["tp"] += 1 + else: + class_stats["crab"]["fp"] += 1 + else: + class_stats["crab"]["fp"] += 1 + + for target_box_index, target_box in enumerate(gt_boxes): + found_match = False + for idx, iou in enumerate(max_ious): + if ( + iou.item() + > ious_threshold # we need this condition because the max overlap is not necessarily above the threshold + and max_indices[idx] + == target_box_index # the matching index is the index of the GT box with which it has max overlap + ): + # There's an IoU match and the matched index corresponds to the current target_box_index + found_match = True + break # Exit loop, a match was found + + if not found_match: + # print(found_match) + class_stats["crab"][ + "fn" + ] += 1 # Ground truth box has no corresponding detection + + precision, recall, class_stats = compute_precision_recall(class_stats) + + return precision, recall, class_stats + + +def get_mlflow_parameters_from_ckpt(trained_model_path: str) -> dict: + """Get MLflow client from ckpt path and associated params.""" + import mlflow + + # roughly assert the format of the path is correct + # Note: to check if this is an MLflow chekcpoint, + # we simply check if the parent directory is called + # 'checkpoints', so it is not a very strict check. + try: + assert ( + Path(trained_model_path).parent.stem == "checkpoints" + ), "The parent directory to an MLflow checkpoint is expected to be called 'checkpoints'" + except AssertionError as e: + print(f"Assertion failed: {e}") + sys.exit(1) + + # get mlruns path, experiment and run ID associated to this checkpoint + ckpt_mlruns_path = str(Path(trained_model_path).parents[3]) + # ckpt_experimentID = Path(trained_model_path).parents[2].stem + ckpt_runID = Path(trained_model_path).parents[1].stem + + # create an Mlflow client to interface with mlflow runs + mlrun_client = mlflow.tracking.MlflowClient( + tracking_uri=ckpt_mlruns_path, + ) + + # get parameters of the run + run = mlrun_client.get_run(ckpt_runID) + params = run.data.params + + return params + + +def get_config_from_ckpt(config_file: str, trained_model_path: str) -> dict: + """Get config from checkpoint if config is not passed as a CLI argument.""" + + # If config in CLI arguments: used passed config + if config_file: + with open(config_file, "r") as f: + config_dict = yaml.safe_load(f) + + # If not: used config from ckpt + else: + params = get_mlflow_parameters_from_ckpt( + trained_model_path + ) # string-dict + + # create a 1-level dict + config_dict = {} + for p in params: + if p.startswith("config"): + config_dict[p.replace("config/", "")] = ast.literal_eval( + params[p] + ) + + # format as a 2-levels nested dict + # forward slashes in a key indicate a nested dict + for key in list(config_dict): # list makes a copy of original keys + if "/" in key: + key_parts = key.split("/") + assert len(key_parts) == 2 + if key_parts[0] not in config_dict: + config_dict[key_parts[0]] = { + key_parts[1]: config_dict.pop(key) + } + else: + config_dict[key_parts[0]].update( + {key_parts[1]: config_dict.pop(key)} + ) + + # check there are no more levels + assert all(["/" not in key for key in config_dict]) + + return config_dict + + +def get_cli_arg_from_ckpt( + args: argparse.Namespace, cli_arg_str: str, trained_model_path: str +): + """Get CLI argument from checkpoint if not in args.""" + if getattr(args, cli_arg_str): + cli_arg = getattr(args, cli_arg_str) + else: + params = get_mlflow_parameters_from_ckpt(trained_model_path) + cli_arg = ast.literal_eval(params[f"cli_args/{cli_arg_str}"]) + + return cli_arg + + +def get_img_directories_from_ckpt( + args: argparse.Namespace, trained_model_path: str +) -> list[str]: + """Get image directories from checkpoint if not passed as CLI argument.""" + + # Get dataset directories from ckpt if not defined + dataset_dirs = get_cli_arg_from_ckpt( + args=args, + cli_arg_str="dataset_dirs", + trained_model_path=trained_model_path, + ) + + # Extract image directories + images_dirs = prep_img_directories(dataset_dirs) + + return images_dirs + + +def get_annotation_files_from_ckpt( + args: argparse.Namespace, trained_model_path: str +) -> list[str]: + """Get annotation files from checkpoint if not passed as CLI argument.""" + + # Get path to input annotation files from ckpt if not defined + input_annotation_files = get_cli_arg_from_ckpt( + args=args, + cli_arg_str="annotation_files", + trained_model_path=trained_model_path, + ) + + # Get dataset dirs from ckpt if not defined + dataset_dirs = get_cli_arg_from_ckpt( + args=args, + cli_arg_str="dataset_dirs", + trained_model_path=trained_model_path, + ) + + # Extract annotation files + annotation_files = prep_annotation_files( + input_annotation_files, dataset_dirs + ) + return annotation_files diff --git a/crabs/detection_tracking/models.py b/crabs/detection_tracking/models.py index 5cb183b2..132064ba 100644 --- a/crabs/detection_tracking/models.py +++ b/crabs/detection_tracking/models.py @@ -8,7 +8,9 @@ fasterrcnn_resnet50_fpn_v2, ) -from crabs.detection_tracking.evaluate import compute_confusion_matrix_elements +from crabs.detection_tracking.evaluate_utils import ( + compute_confusion_matrix_elements, +) class FasterRCNN(LightningModule): diff --git a/crabs/detection_tracking/visualization.py b/crabs/detection_tracking/visualization.py index 199d5051..e6fc98bc 100644 --- a/crabs/detection_tracking/visualization.py +++ b/crabs/detection_tracking/visualization.py @@ -1,4 +1,5 @@ import os +from datetime import datetime from typing import Any, Optional import cv2 @@ -152,6 +153,7 @@ def draw_detection( def save_images_with_boxes( test_dataloader: torch.utils.data.DataLoader, trained_model: torch.nn.Module, + output_dir: str, score_threshold: float, ) -> None: """ @@ -175,20 +177,28 @@ def save_images_with_boxes( if torch.cuda.is_available() else torch.device("cpu") ) + + trained_model.to(device) trained_model.eval() - directory = "results" - os.makedirs(directory, exist_ok=True) + + if not output_dir: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = f"results_{timestamp}" + os.makedirs(output_dir, exist_ok=True) + with torch.no_grad(): imgs_id = 0 for imgs, annotations in test_dataloader: imgs_id += 1 imgs = list(img.to(device) for img in imgs) + detections = trained_model(imgs) image_with_boxes = draw_detection( imgs, annotations, detections, score_threshold ) - cv2.imwrite(f"{directory}/imgs{imgs_id}.jpg", image_with_boxes) + + cv2.imwrite(f"{output_dir}/imgs{imgs_id}.jpg", image_with_boxes) def plot_sample(imgs: list, row_title: Optional[str] = None, **imshow_kwargs): diff --git a/tests/test_unit/test_evaluate.py b/tests/test_unit/test_evaluate.py index 2527d6b8..9f693b7a 100644 --- a/tests/test_unit/test_evaluate.py +++ b/tests/test_unit/test_evaluate.py @@ -1,6 +1,6 @@ import torch -from crabs.detection_tracking.evaluate import ( +from crabs.detection_tracking.evaluate_utils import ( compute_confusion_matrix_elements, compute_precision_recall, ) diff --git a/tests/test_unit/test_visualization.py b/tests/test_unit/test_visualization.py index 67edb852..271e1dbc 100644 --- a/tests/test_unit/test_visualization.py +++ b/tests/test_unit/test_visualization.py @@ -1,3 +1,4 @@ +import re from unittest.mock import MagicMock, patch import numpy as np @@ -149,6 +150,10 @@ def test_draw_detection(annotations, detections): assert image_with_boxes is not None +@pytest.mark.parametrize( + "output_dir_name, expected_dir_name", + [("output", r"^output$"), ("", r"^results_\d{8}_\d{6}$")], +) @pytest.mark.parametrize( "detections", [ @@ -168,9 +173,11 @@ def test_draw_detection(annotations, detections): ), ], ) -@patch("cv2.imwrite") -@patch("os.makedirs") -def test_save_images_with_boxes(mock_makedirs, mock_imwrite, detections): +@patch("crabs.detection_tracking.visualization.cv2.imwrite") +@patch("crabs.detection_tracking.visualization.os.makedirs") +def test_save_images_with_boxes( + mock_makedirs, mock_imwrite, detections, output_dir_name, expected_dir_name +): trained_model = MagicMock() test_dataloader = MagicMock() trained_model.return_value = detections @@ -178,8 +185,12 @@ def test_save_images_with_boxes(mock_makedirs, mock_imwrite, detections): save_images_with_boxes( test_dataloader, trained_model, + output_dir=output_dir_name, score_threshold=0.5, ) - assert mock_makedirs.called_once_with("results", exist_ok=True) + # extract and check first positional argument to (mocked) os.makedirs + input_path_makedirs = mock_makedirs.call_args[0][0] + assert re.match(expected_dir_name, input_path_makedirs) + assert mock_imwrite.call_count == len(test_dataloader)