From dda043dd439a9af78a33362b448f420a403c2ff9 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 18 Apr 2024 15:13:21 +0000 Subject: [PATCH 01/17] move eval params to config (WIP) --- .../config/faster_rcnn.yaml | 30 ++++++++++++++----- crabs/detection_tracking/evaluate_model.py | 14 +-------- 2 files changed, 24 insertions(+), 20 deletions(-) diff --git a/crabs/detection_tracking/config/faster_rcnn.yaml b/crabs/detection_tracking/config/faster_rcnn.yaml index bf967d1f..42889fc4 100644 --- a/crabs/detection_tracking/config/faster_rcnn.yaml +++ b/crabs/detection_tracking/config/faster_rcnn.yaml @@ -1,12 +1,23 @@ -num_epochs: 4 +# --------- +# Dataset +# --------- +train_fraction: 0.8 +val_over_test_fraction: 0.5 +num_workers: 0 # for all dataloaders + +# --------- +# Network +# --------- +num_classes: 2 + +# ------------------------------- +# Training/validation parameters +# ------------------------------- +num_epochs: 1 learning_rate: 0.00005 wdecay: 0.00005 batch_size_train: 4 -batch_size_test: 1 batch_size_val: 1 -num_classes: 2 -train_fraction: 0.8 -val_over_test_fraction: 0.5 checkpoint_saving: every_n_epochs: 1 keep_last_n_ckpts: 2 @@ -17,8 +28,6 @@ checkpoint_saving: # if all, all checkpoints for every epoch are added as artifacts during training, # if False, no checkpoints are added as artifacts. save_last: True -iou_threshold: 0.1 -num_workers: 0 transform_brightness: 0.5 transform_hue: 0.3 gaussian_blur_params: @@ -28,3 +37,10 @@ gaussian_blur_params: sigma: - 0.1 - 5.0 + +# ------------------------------ +# Testing parameters +# ------------------------------ +batch_size_test: 1 +iou_threshold: 0.1 +score_threshold: 0.1 \ No newline at end of file diff --git a/crabs/detection_tracking/evaluate_model.py b/crabs/detection_tracking/evaluate_model.py index 9c7003e5..4c89cc46 100644 --- a/crabs/detection_tracking/evaluate_model.py +++ b/crabs/detection_tracking/evaluate_model.py @@ -168,7 +168,7 @@ def evaluate_parse_args(args): parser.add_argument( "--checkpoint_path", type=str, - required=True, + required=True, #--------- can we pass experiment and runid? help="Location of trained model", ) parser.add_argument( @@ -190,18 +190,6 @@ def evaluate_parse_args(args): "and https://lightning.ai/docs/pytorch/stable/accelerators/mps_basic.html#run-on-apple-silicon-gpus" ), ) - parser.add_argument( - "--score_threshold", - type=float, - default=0.1, - help="Threshold for confidence score. Default: 0.1", - ) - parser.add_argument( - "--ious_threshold", - type=float, - default=0.1, - help="Threshold for IOU. Default: 0.1", - ) parser.add_argument( "--seed_n", type=int, From 5815be866f2a3ba18f558e674a00ceb06d027680 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 18 Apr 2024 16:45:03 +0100 Subject: [PATCH 02/17] follow train CLI: add debugger options, add experiment name, use score_threshold from config --- crabs/detection_tracking/evaluate_model.py | 67 +++++++++++++++------- 1 file changed, 47 insertions(+), 20 deletions(-) diff --git a/crabs/detection_tracking/evaluate_model.py b/crabs/detection_tracking/evaluate_model.py index 4c89cc46..a9b8095c 100644 --- a/crabs/detection_tracking/evaluate_model.py +++ b/crabs/detection_tracking/evaluate_model.py @@ -26,35 +26,35 @@ class DetectorEvaluation: ---------- args : argparse Command-line arguments containing configuration settings. - config_file : str - Path to the directory containing configuration file. - images_dirs : list[str] - list of paths to the image directories of the datasets. - annotation_files : list[str] - list of filenames for the COCO annotations. - score_threshold : float - The score threshold for confidence detection. - ious_threshold : float - The ious threshold for detection bounding boxes. - evaluate_dataloader: - The DataLoader for the test dataset. + """ def __init__( self, args: argparse.Namespace, ) -> None: + # inputs self.args = args self.config_file = args.config_file + self.load_config_yaml() + + # dataset self.images_dirs = prep_img_directories(args.dataset_dirs) self.annotation_files = prep_annotation_files( args.annotation_files, args.dataset_dirs ) self.seed_n = args.seed_n - self.ious_threshold = args.ious_threshold - self.score_threshold = args.score_threshold + + # Hardware + self.accelerator = args.accelerator # -------- + + # MLflow + self.experiment_name = args.experiment_name self.mlflow_folder = args.mlflow_folder - self.load_config_yaml() + + # Debugging + self.fast_dev_run = args.fast_dev_run + self.limit_test_batches = args.limit_test_batches def load_config_yaml(self): with open(self.config_file, "r") as f: @@ -74,12 +74,13 @@ def setup_logger(self) -> MLFlowLogger: # Setup logger (no checkpointing) mlf_logger = setup_mlflow_logger( - experiment_name="Sep2023_evaluation", + experiment_name=self.experiment_name, # "Sep2023_evaluation", run_name=self.run_name, mlflow_folder=self.mlflow_folder, ) # Log metadata to logger: CLI arguments and SLURM (if required) + # log model name? mlf_logger = log_metadata_to_logger(mlf_logger, self.args) return mlf_logger @@ -94,8 +95,10 @@ def setup_trainer(self): # Return trainer linked to logger return lightning.Trainer( - accelerator=self.args.accelerator, + accelerator=self.accelerator, logger=mlf_logger, + fast_dev_run=self.fast_dev_run, + limit_test_batches=self.limit_test_batches, ) def evaluate_model(self) -> None: @@ -127,13 +130,13 @@ def evaluate_model(self) -> None: save_images_with_boxes( data_module.test_dataloader(), trained_model, - self.score_threshold, + self.config["score_threshold"], ) def main(args) -> None: """ - Main function to orchestrate the testing process using Detector_Test. + Main function to orchestrate the testing process. Parameters ---------- @@ -168,7 +171,7 @@ def evaluate_parse_args(args): parser.add_argument( "--checkpoint_path", type=str, - required=True, #--------- can we pass experiment and runid? + required=True, # --------- can we pass experiment and run-id? help="Location of trained model", ) parser.add_argument( @@ -190,12 +193,36 @@ def evaluate_parse_args(args): "and https://lightning.ai/docs/pytorch/stable/accelerators/mps_basic.html#run-on-apple-silicon-gpus" ), ) + parser.add_argument( + "--experiment_name", + type=str, + default="Sept2023_evaluation", + help=( + "Name of the experiment in MLflow, under which the current run will be logged. " + "For example, the name of the dataset could be used, to group runs using the same data. " + "Default: Sept2023_evaluation" + ), + ) parser.add_argument( "--seed_n", type=int, default=42, help="Seed for dataset splits. Default: 42", ) + parser.add_argument( + "--fast_dev_run", + action="store_true", + help="Debugging option to run training for one batch and one epoch", + ) + parser.add_argument( + "--limit_test_batches", + type=float, + default=1.0, + help=( + "Debugging option to run training on a fraction of the training set." + "Default: 1.0 (all the training set)" + ), + ) parser.add_argument( "--mlflow_folder", type=str, From 2e262a780b8dd390c260b734d7de81a46df11d6c Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 18 Apr 2024 16:50:17 +0100 Subject: [PATCH 03/17] fix prettier --- crabs/detection_tracking/config/faster_rcnn.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crabs/detection_tracking/config/faster_rcnn.yaml b/crabs/detection_tracking/config/faster_rcnn.yaml index 42889fc4..009b9658 100644 --- a/crabs/detection_tracking/config/faster_rcnn.yaml +++ b/crabs/detection_tracking/config/faster_rcnn.yaml @@ -3,7 +3,7 @@ # --------- train_fraction: 0.8 val_over_test_fraction: 0.5 -num_workers: 0 # for all dataloaders +num_workers: 0 # for all dataloaders # --------- # Network @@ -43,4 +43,4 @@ gaussian_blur_params: # ------------------------------ batch_size_test: 1 iou_threshold: 0.1 -score_threshold: 0.1 \ No newline at end of file +score_threshold: 0.1 From a180f2e45d83ba954ba23a042ce3c3e5134f68ca Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 18 Apr 2024 19:32:34 +0100 Subject: [PATCH 04/17] edit CLI defaults and get dataset params from ckpt if not defined (WIP) --- crabs/detection_tracking/evaluate_model.py | 160 +++++++++++++++++---- 1 file changed, 132 insertions(+), 28 deletions(-) diff --git a/crabs/detection_tracking/evaluate_model.py b/crabs/detection_tracking/evaluate_model.py index a9b8095c..ab6f7e59 100644 --- a/crabs/detection_tracking/evaluate_model.py +++ b/crabs/detection_tracking/evaluate_model.py @@ -1,4 +1,7 @@ import argparse +import ast + +# import pdb import sys from pathlib import Path @@ -38,12 +41,16 @@ def __init__( self.config_file = args.config_file self.load_config_yaml() - # dataset - self.images_dirs = prep_img_directories(args.dataset_dirs) - self.annotation_files = prep_annotation_files( - args.annotation_files, args.dataset_dirs - ) - self.seed_n = args.seed_n + # trained model + self.checkpoint_path = args.checkpoint_path + + # dataset: retrieve from ckpt if possible + # maybe a different name? + self.images_dirs = ( + self.get_img_directories_from_ckpt() + ) # maybe in detection utils? - I think good here cause nothing else uses it for now + self.annotation_files = self.get_annotation_files_from_ckpt() + self.seed_n = self.get_seed_from_ckpt() # Hardware self.accelerator = args.accelerator # -------- @@ -60,6 +67,97 @@ def load_config_yaml(self): with open(self.config_file, "r") as f: self.config = yaml.safe_load(f) + def get_mlflow_client_from_ckpt(self): + # we assume an mlflow ckpt + + import mlflow + + # roughly assert the format of the path + assert Path(self.checkpoint_path).parent.stem == "checkpoints" + + # get mlruns path, experiment and run ID associated to this checkpoint + self.ckpt_mlruns_path = str(Path(self.checkpoint_path).parents[3]) + self.ckpt_experimentID = Path(self.checkpoint_path).parents[2].stem + self.ckpt_runID = Path(self.checkpoint_path).parents[1].stem + + # create an Mlflow client to interface with mlflow runs + self.mlrun_client = mlflow.tracking.MlflowClient( + tracking_uri=self.ckpt_mlruns_path, + ) + + # get params of the run + run = self.mlrun_client.get_run(self.ckpt_runID) + params = run.data.params + + # pdb.set_trace() + return params + + def get_img_directories_from_ckpt(self) -> list[str]: + # if dataset_dirs is empty: + # retrieve from ckpt path + # We assume we always pass a mlflow chckpoint + # would this work with a remote? + if not self.args.dataset_dirs: + # get mlflow client for the ml-runs folder containing the checkpoint + params = self.get_mlflow_client_from_ckpt() + + # get dataset_dirs used in training job + train_cli_dataset_dirs = ast.literal_eval( + params["cli_args/dataset_dirs"] + ) + + # pass that to prep image directories + # pdb.set_trace() + images_dirs = prep_img_directories(train_cli_dataset_dirs) + # pdb.set_trace() + + # if not empty, call the regular one + else: + images_dirs = prep_img_directories(self.args.dataset_dirs) + + return images_dirs + + def get_annotation_files_from_ckpt(self) -> list[str]: + # if no annotation files passed: + # retrieve from checkpoint + # pdb.set_trace() + if not self.args.annotation_files: + # get mlflow client for the ml-runs folder containing the checkpoint + params = self.get_mlflow_client_from_ckpt() + + # pdb.set_trace() + train_cli_dataset_dirs = ast.literal_eval( + params["cli_args/dataset_dirs"] + ) + train_cli_annotation_files = ast.literal_eval( + params["cli_args/annotation_files"] + ) + + # pdb.set_trace() + annotation_files = prep_annotation_files( + train_cli_annotation_files, train_cli_dataset_dirs + ) + # pdb.set_trace() + + else: + annotation_files = prep_annotation_files( + self.args.annotation_files, self.args.dataset_dirs + ) + return annotation_files + + def get_seed_from_ckpt(self) -> int: + # pdb.set_trace() + if not self.args.seed_n: + # get mlflow client for the ml-runs folder containing the checkpoint + params = self.get_mlflow_client_from_ckpt() + + # pdb.set_trace() + seed_n = ast.literal_eval(params["cli_args/seed_n"]) + # pdb.set_trace() + else: + seed_n = self.args.seed_n + return seed_n + def set_run_name(self): self.run_name = set_mlflow_run_name() @@ -114,9 +212,7 @@ def evaluate_model(self) -> None: ) # Get trained model - trained_model = FasterRCNN.load_from_checkpoint( - self.args.checkpoint_path - ) + trained_model = FasterRCNN.load_from_checkpoint(self.checkpoint_path) # Run testing trainer = self.setup_trainer() @@ -153,11 +249,29 @@ def main(args) -> None: def evaluate_parse_args(args): parser = argparse.ArgumentParser() + parser.add_argument( + "--checkpoint_path", + type=str, + required=True, # --------- can we pass experiment and run-id? + help="Location of trained model", + ) + parser.add_argument( + "--config_file", + type=str, + default=str(Path(__file__).parent / "config" / "faster_rcnn.yaml"), + help=( + "Location of YAML config to control training. " + "Default: crabs-exploration/crabs/detection_tracking/config/faster_rcnn.yaml" + ), + ) parser.add_argument( "--dataset_dirs", nargs="+", - required=True, - help="List of dataset directories", + default=[], # required=True, + help=( + "List of dataset directories. If none provided, the ones used for " + "the ckpt training are used." + ), ) parser.add_argument( "--annotation_files", @@ -166,23 +280,19 @@ def evaluate_parse_args(args): help=( "List of paths to annotation files. The full path or the filename can be provided. " "If only filename is provided, it is assumed to be under dataset/annotations." + "If none is provided, the annotations from the dataset of the checkpoint are used." ), ) parser.add_argument( - "--checkpoint_path", - type=str, - required=True, # --------- can we pass experiment and run-id? - help="Location of trained model", - ) - parser.add_argument( - "--config_file", - type=str, - default=str(Path(__file__).parent / "config" / "faster_rcnn.yaml"), + "--seed_n", + type=int, + # default=42, help=( - "Location of YAML config to control training. " - "Default: crabs-exploration/crabs/detection_tracking/config/faster_rcnn.yaml" + "Seed for dataset splits. If none is provided, the seed from the dataset of " + "the checkpoint is used", # No default ), ) + parser.add_argument( "--accelerator", type=str, @@ -203,12 +313,6 @@ def evaluate_parse_args(args): "Default: Sept2023_evaluation" ), ) - parser.add_argument( - "--seed_n", - type=int, - default=42, - help="Seed for dataset splits. Default: 42", - ) parser.add_argument( "--fast_dev_run", action="store_true", From 03b66644b75c2025655cbe0360c37b58babca2c1 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Thu, 18 Apr 2024 20:06:58 +0100 Subject: [PATCH 05/17] fix ninja comma --- crabs/detection_tracking/evaluate_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crabs/detection_tracking/evaluate_model.py b/crabs/detection_tracking/evaluate_model.py index ab6f7e59..e9c94ca4 100644 --- a/crabs/detection_tracking/evaluate_model.py +++ b/crabs/detection_tracking/evaluate_model.py @@ -289,7 +289,7 @@ def evaluate_parse_args(args): # default=42, help=( "Seed for dataset splits. If none is provided, the seed from the dataset of " - "the checkpoint is used", # No default + "the checkpoint is used." # No default ), ) From 442b9ee5a7a2b239f93cbc8eaa572b5f5d70691b Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Tue, 25 Jun 2024 18:22:59 +0100 Subject: [PATCH 06/17] Add sections to config --- .../config/faster_rcnn.yaml | 36 +++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/crabs/detection_tracking/config/faster_rcnn.yaml b/crabs/detection_tracking/config/faster_rcnn.yaml index ac6b4c51..04ba789e 100644 --- a/crabs/detection_tracking/config/faster_rcnn.yaml +++ b/crabs/detection_tracking/config/faster_rcnn.yaml @@ -1,12 +1,24 @@ +--- +#---------- +# Dataset +#------------- +train_fraction: 0.8 +val_over_test_fraction: 0.5 +num_workers: 4 + +# ------------------- +# Model architecture +# ------------------- +num_classes: 2 + +# ------------------------------- +# Training/validation parameters +# ------------------------------- n_epochs: 250 learning_rate: 0.00005 wdecay: 0.00005 batch_size_train: 4 -batch_size_test: 4 batch_size_val: 4 -num_classes: 2 -train_fraction: 0.8 -val_over_test_fraction: 0.5 checkpoint_saving: every_n_epochs: 50 keep_last_n_ckpts: 5 @@ -17,8 +29,10 @@ checkpoint_saving: # if all, all checkpoints for every epoch are added as artifacts during training, # if False, no checkpoints are added as artifacts. save_last: True -iou_threshold: 0.1 -num_workers: 4 + +# ------------------- +# Data augmentation +# ------------------- transform_brightness: 0.5 transform_hue: 0.3 gaussian_blur_params: @@ -28,6 +42,16 @@ gaussian_blur_params: sigma: - 0.1 - 5.0 + +# ----------------------- +# Evaluation parameters +# ----------------------- +iou_threshold: 0.1 +batch_size_test: 4 + +# ---------------------------- +# Hyperparameter optimisation +# ----------------------------- # when we run Optuna, the n_trials and n_epochs above will be overwritten by the parameters set by Optuna optuna: # Parameters for hyperparameter optimisation with Optuna: From 1f5ead4e3e6fa2e7370730cac1cfb55d9cbf43c6 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Tue, 25 Jun 2024 18:24:13 +0100 Subject: [PATCH 07/17] Rename to evaluate utils --- crabs/detection_tracking/{evaluate.py => evaluate_utils.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename crabs/detection_tracking/{evaluate.py => evaluate_utils.py} (100%) diff --git a/crabs/detection_tracking/evaluate.py b/crabs/detection_tracking/evaluate_utils.py similarity index 100% rename from crabs/detection_tracking/evaluate.py rename to crabs/detection_tracking/evaluate_utils.py From ba5ed954d85cc140cdba6cff6a1d15ad6ce06dd0 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Tue, 25 Jun 2024 18:33:26 +0100 Subject: [PATCH 08/17] Match current train script and add slurm logs as artifacts --- crabs/detection_tracking/evaluate_model.py | 60 +++++++--------------- 1 file changed, 19 insertions(+), 41 deletions(-) diff --git a/crabs/detection_tracking/evaluate_model.py b/crabs/detection_tracking/evaluate_model.py index 380507bc..59a1978d 100644 --- a/crabs/detection_tracking/evaluate_model.py +++ b/crabs/detection_tracking/evaluate_model.py @@ -1,13 +1,11 @@ import argparse import ast - -# import pdb +import os import sys from pathlib import Path import lightning import yaml # type: ignore -from lightning.pytorch.loggers import MLFlowLogger from crabs.detection_tracking.datamodules import CrabsDataModule from crabs.detection_tracking.detection_utils import ( @@ -15,6 +13,7 @@ prep_img_directories, set_mlflow_run_name, setup_mlflow_logger, + slurm_logs_as_artifacts, ) from crabs.detection_tracking.models import FasterRCNN from crabs.detection_tracking.visualization import save_images_with_boxes @@ -31,10 +30,7 @@ class DetectorEvaluation: """ - def __init__( - self, - args: argparse.Namespace, - ) -> None: + def __init__(self, args: argparse.Namespace) -> None: # inputs self.args = args self.config_file = args.config_file @@ -45,10 +41,11 @@ def __init__( # dataset: retrieve from ckpt if possible # maybe a different name? - self.images_dirs = ( - self.get_img_directories_from_ckpt() - ) # maybe in detection utils? - I think good here cause nothing else uses it for now - self.annotation_files = self.get_annotation_files_from_ckpt() + self.images_dirs = self.get_img_directories_from_ckpt() # if defined + # maybe in detection utils? - I think good here cause nothing else uses it for now + self.annotation_files = ( + self.get_annotation_files_from_ckpt() + ) # if defined self.seed_n = self.get_seed_from_ckpt() # Hardware @@ -88,12 +85,10 @@ def get_mlflow_client_from_ckpt(self): run = self.mlrun_client.get_run(self.ckpt_runID) params = run.data.params - # pdb.set_trace() return params def get_img_directories_from_ckpt(self) -> list[str]: - # if dataset_dirs is empty: - # retrieve from ckpt path + # if dataset_dirs is empty: retrieve from ckpt path # We assume we always pass a mlflow chckpoint # would this work with a remote? if not self.args.dataset_dirs: @@ -106,9 +101,7 @@ def get_img_directories_from_ckpt(self) -> list[str]: ) # pass that to prep image directories - # pdb.set_trace() images_dirs = prep_img_directories(train_cli_dataset_dirs) - # pdb.set_trace() # if not empty, call the regular one else: @@ -124,7 +117,6 @@ def get_annotation_files_from_ckpt(self) -> list[str]: # get mlflow client for the ml-runs folder containing the checkpoint params = self.get_mlflow_client_from_ckpt() - # pdb.set_trace() train_cli_dataset_dirs = ast.literal_eval( params["cli_args/dataset_dirs"] ) @@ -132,11 +124,9 @@ def get_annotation_files_from_ckpt(self) -> list[str]: params["cli_args/annotation_files"] ) - # pdb.set_trace() annotation_files = prep_annotation_files( train_cli_annotation_files, train_cli_dataset_dirs ) - # pdb.set_trace() else: annotation_files = prep_annotation_files( @@ -145,29 +135,22 @@ def get_annotation_files_from_ckpt(self) -> list[str]: return annotation_files def get_seed_from_ckpt(self) -> int: - # pdb.set_trace() if not self.args.seed_n: # get mlflow client for the ml-runs folder containing the checkpoint params = self.get_mlflow_client_from_ckpt() - # pdb.set_trace() seed_n = ast.literal_eval(params["cli_args/seed_n"]) - # pdb.set_trace() else: seed_n = self.args.seed_n return seed_n - def set_run_name(self): - self.run_name = set_mlflow_run_name() - - def setup_logger(self) -> MLFlowLogger: + def setup_trainer(self): """ - Setup MLflow logger for testing. - - Includes logging metadata about the job (CLI arguments and SLURM job IDs). + Setup trainer object with logging for testing. """ + # Assign run name - self.set_run_name() + self.run_name = set_mlflow_run_name() # Setup logger mlf_logger = setup_mlflow_logger( @@ -177,16 +160,6 @@ def setup_logger(self) -> MLFlowLogger: cli_args=self.args, ) - return mlf_logger - - def setup_trainer(self): - """ - Setup trainer object with logging for testing. - """ - - # Get MLflow logger - mlf_logger = self.setup_logger() - # Return trainer linked to logger return lightning.Trainer( accelerator=self.accelerator, @@ -225,6 +198,11 @@ def evaluate_model(self) -> None: self.config["score_threshold"], ) + # if this is a slurm job: add slurm logs as artifacts + slurm_job_id = os.environ.get("SLURM_JOB_ID") + if slurm_job_id: + slurm_logs_as_artifacts(trainer.logger, slurm_job_id) + def main(args) -> None: """ @@ -256,7 +234,7 @@ def evaluate_parse_args(args): type=str, default=str(Path(__file__).parent / "config" / "faster_rcnn.yaml"), help=( - "Location of YAML config to control training. " + "Location of YAML config to control evaluation. " "Default: crabs-exploration/crabs/detection_tracking/config/faster_rcnn.yaml" ), ) From a77216ae752be39d28890595c956e887b244e61d Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Tue, 25 Jun 2024 18:37:14 +0100 Subject: [PATCH 09/17] Fix evaluate_utils --- crabs/detection_tracking/models.py | 4 +++- tests/test_unit/test_evaluate.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/crabs/detection_tracking/models.py b/crabs/detection_tracking/models.py index f83f4a1e..d341c6c3 100644 --- a/crabs/detection_tracking/models.py +++ b/crabs/detection_tracking/models.py @@ -8,7 +8,9 @@ fasterrcnn_resnet50_fpn_v2, ) -from crabs.detection_tracking.evaluate import compute_confusion_matrix_elements +from crabs.detection_tracking.evaluate_utils import ( + compute_confusion_matrix_elements, +) class FasterRCNN(LightningModule): diff --git a/tests/test_unit/test_evaluate.py b/tests/test_unit/test_evaluate.py index 2527d6b8..9f693b7a 100644 --- a/tests/test_unit/test_evaluate.py +++ b/tests/test_unit/test_evaluate.py @@ -1,6 +1,6 @@ import torch -from crabs.detection_tracking.evaluate import ( +from crabs.detection_tracking.evaluate_utils import ( compute_confusion_matrix_elements, compute_precision_recall, ) From 61609b6a1ddd98325e6dfe0d9f491151e85aa43d Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Wed, 26 Jun 2024 12:39:56 +0100 Subject: [PATCH 10/17] Use config from ckpt if not passed. Use dataset, annot files and seed from ckpt if not passed. --- crabs/detection_tracking/evaluate_model.py | 168 ++++++++++++--------- 1 file changed, 96 insertions(+), 72 deletions(-) diff --git a/crabs/detection_tracking/evaluate_model.py b/crabs/detection_tracking/evaluate_model.py index 59a1978d..fad3a056 100644 --- a/crabs/detection_tracking/evaluate_model.py +++ b/crabs/detection_tracking/evaluate_model.py @@ -1,10 +1,12 @@ import argparse import ast +import logging import os import sys from pathlib import Path import lightning +import torch import yaml # type: ignore from crabs.detection_tracking.datamodules import CrabsDataModule @@ -31,25 +33,22 @@ class DetectorEvaluation: """ def __init__(self, args: argparse.Namespace) -> None: - # inputs + # CLI inputs self.args = args - self.config_file = args.config_file - self.load_config_yaml() # trained model - self.checkpoint_path = args.checkpoint_path + self.trained_model_path = args.trained_model_path + self.config_file = args.config_file + self.get_config_from_ckpt() + # self.load_config_yaml() # adds self.config from yaml ----> instead get from ckpt! # dataset: retrieve from ckpt if possible - # maybe a different name? self.images_dirs = self.get_img_directories_from_ckpt() # if defined - # maybe in detection utils? - I think good here cause nothing else uses it for now - self.annotation_files = ( - self.get_annotation_files_from_ckpt() - ) # if defined - self.seed_n = self.get_seed_from_ckpt() + self.annotation_files = self.get_annotation_files_from_ckpt() + self.seed_n = self.get_cli_arg_from_ckpt("seed_n") # Hardware - self.accelerator = args.accelerator # -------- + self.accelerator = args.accelerator # MLflow self.experiment_name = args.experiment_name @@ -59,22 +58,21 @@ def __init__(self, args: argparse.Namespace) -> None: self.fast_dev_run = args.fast_dev_run self.limit_test_batches = args.limit_test_batches - def load_config_yaml(self): - with open(self.config_file, "r") as f: - self.config = yaml.safe_load(f) - - def get_mlflow_client_from_ckpt(self): - # we assume an mlflow ckpt + logging.info(f"Images directories: {self.images_dirs}") + logging.info(f"Annotation files: {self.annotation_files}") + logging.info(f"Seed: {self.seed_n}") + def get_mlflow_parameters_from_ckpt(self): + """Get MLflow client from ckpt path and associated hparams""" import mlflow # roughly assert the format of the path - assert Path(self.checkpoint_path).parent.stem == "checkpoints" + assert Path(self.trained_model_path).parent.stem == "checkpoints" # get mlruns path, experiment and run ID associated to this checkpoint - self.ckpt_mlruns_path = str(Path(self.checkpoint_path).parents[3]) - self.ckpt_experimentID = Path(self.checkpoint_path).parents[2].stem - self.ckpt_runID = Path(self.checkpoint_path).parents[1].stem + self.ckpt_mlruns_path = str(Path(self.trained_model_path).parents[3]) + self.ckpt_experimentID = Path(self.trained_model_path).parents[2].stem + self.ckpt_runID = Path(self.trained_model_path).parents[1].stem # create an Mlflow client to interface with mlflow runs self.mlrun_client = mlflow.tracking.MlflowClient( @@ -87,62 +85,84 @@ def get_mlflow_client_from_ckpt(self): return params - def get_img_directories_from_ckpt(self) -> list[str]: - # if dataset_dirs is empty: retrieve from ckpt path - # We assume we always pass a mlflow chckpoint - # would this work with a remote? - if not self.args.dataset_dirs: - # get mlflow client for the ml-runs folder containing the checkpoint - params = self.get_mlflow_client_from_ckpt() - - # get dataset_dirs used in training job - train_cli_dataset_dirs = ast.literal_eval( - params["cli_args/dataset_dirs"] - ) + def get_config_from_ckpt(self): + """Get config from checkpoint if not passed as CLI arg""" - # pass that to prep image directories - images_dirs = prep_img_directories(train_cli_dataset_dirs) + # If passed: used passed config + if self.config_file: + with open(self.config_file, "r") as f: + config_dict = yaml.safe_load(f) - # if not empty, call the regular one + # If not passed: used config from ckpt + else: + params = self.get_mlflow_parameters_from_ckpt() # string-dict + + # create a 1-level dict + config_dict = {} + for p in params: + if p.startswith("config"): + config_dict[p.replace("config/", "")] = ast.literal_eval( + params[p] + ) + + # format as a 2-levels nested dict + # forward slashes indicate a nested dict + for key in list(config_dict): # makes a copy of original keys! + if "/" in key: + key_parts = key.split("/") + if key_parts[0] not in config_dict: + config_dict[key_parts[0]] = { + key_parts[1]: config_dict.pop(key) + } # config_dict[key]} # initialise + else: + config_dict[key_parts[0]].update( + {key_parts[1]: config_dict.pop(key)} + ) + + # check there are no more levels + assert all(["/" not in key for key in config_dict]) + + self.config = config_dict + + def get_cli_arg_from_ckpt(self, cli_arg_str): + """Get CLI argument from checkpoint if not defined""" + if getattr(self.args, cli_arg_str): + cli_arg = getattr(self.args, cli_arg_str) else: - images_dirs = prep_img_directories(self.args.dataset_dirs) + params = self.get_mlflow_parameters_from_ckpt() + + cli_arg = ast.literal_eval(params[f"cli_args/{cli_arg_str}"]) + + return cli_arg + + def get_img_directories_from_ckpt(self) -> list[str]: + """Get image directories from checkpoint if not defined.""" + # Get dataset directories from ckpt if not defined + dataset_dirs = self.get_cli_arg_from_ckpt("dataset_dirs") + + # Extract image directories + images_dirs = prep_img_directories(dataset_dirs) return images_dirs def get_annotation_files_from_ckpt(self) -> list[str]: - # if no annotation files passed: - # retrieve from checkpoint - # pdb.set_trace() - if not self.args.annotation_files: - # get mlflow client for the ml-runs folder containing the checkpoint - params = self.get_mlflow_client_from_ckpt() - - train_cli_dataset_dirs = ast.literal_eval( - params["cli_args/dataset_dirs"] - ) - train_cli_annotation_files = ast.literal_eval( - params["cli_args/annotation_files"] - ) + """Get annotation files from checkpoint if not defined. - annotation_files = prep_annotation_files( - train_cli_annotation_files, train_cli_dataset_dirs - ) + If annotation_files is not pass as CLI arg to evaluate: + retrieve annotation_files from ckpt path. + """ - else: - annotation_files = prep_annotation_files( - self.args.annotation_files, self.args.dataset_dirs - ) - return annotation_files + # Get path to input annotation files from ckpt if not defined + input_annotation_files = self.get_cli_arg_from_ckpt("annotation_files") - def get_seed_from_ckpt(self) -> int: - if not self.args.seed_n: - # get mlflow client for the ml-runs folder containing the checkpoint - params = self.get_mlflow_client_from_ckpt() + # Get dataset dirs from ckpt if not defined + dataset_dirs = self.get_cli_arg_from_ckpt("dataset_dirs") - seed_n = ast.literal_eval(params["cli_args/seed_n"]) - else: - seed_n = self.args.seed_n - return seed_n + # Extract annotation files + annotation_files = prep_annotation_files( + input_annotation_files, dataset_dirs + ) + return annotation_files def setup_trainer(self): """ @@ -181,7 +201,9 @@ def evaluate_model(self) -> None: ) # Get trained model - trained_model = FasterRCNN.load_from_checkpoint(self.checkpoint_path) + trained_model = FasterRCNN.load_from_checkpoint( + self.trained_model_path, config=self.config + ) # Run testing trainer = self.setup_trainer() @@ -224,7 +246,7 @@ def main(args) -> None: def evaluate_parse_args(args): parser = argparse.ArgumentParser() parser.add_argument( - "--checkpoint_path", + "--trained_model_path", type=str, required=True, # --------- can we pass experiment and run-id? help="Location of trained model", @@ -232,10 +254,10 @@ def evaluate_parse_args(args): parser.add_argument( "--config_file", type=str, - default=str(Path(__file__).parent / "config" / "faster_rcnn.yaml"), + default="", help=( "Location of YAML config to control evaluation. " - "Default: crabs-exploration/crabs/detection_tracking/config/faster_rcnn.yaml" + "Default: '' (the config used to train the model is used)" ), ) parser.add_argument( @@ -243,8 +265,8 @@ def evaluate_parse_args(args): nargs="+", default=[], # required=True, help=( - "List of dataset directories. If none provided, the ones used for " - "the ckpt training are used." + "List of dataset directories. If none provided, the same datasets used for " + "the provided model are used." ), ) parser.add_argument( @@ -316,6 +338,8 @@ def evaluate_parse_args(args): def app_wrapper(): + torch.set_float32_matmul_precision("medium") + eval_args = evaluate_parse_args(sys.argv[1:]) main(eval_args) From d28680b36f5c8cbc70900809e88897182eabe19d Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Wed, 26 Jun 2024 12:44:45 +0100 Subject: [PATCH 11/17] Clarify CLI help (hopefully) --- crabs/detection_tracking/evaluate_model.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/crabs/detection_tracking/evaluate_model.py b/crabs/detection_tracking/evaluate_model.py index fad3a056..564b2f56 100644 --- a/crabs/detection_tracking/evaluate_model.py +++ b/crabs/detection_tracking/evaluate_model.py @@ -249,7 +249,7 @@ def evaluate_parse_args(args): "--trained_model_path", type=str, required=True, # --------- can we pass experiment and run-id? - help="Location of trained model", + help="Location of trained model (a .ckpt file)", ) parser.add_argument( "--config_file", @@ -257,16 +257,17 @@ def evaluate_parse_args(args): default="", help=( "Location of YAML config to control evaluation. " - "Default: '' (the config used to train the model is used)" + " If None is povided, the config used to train the model is used (recommended)." ), ) parser.add_argument( "--dataset_dirs", nargs="+", - default=[], # required=True, + default=[], help=( - "List of dataset directories. If none provided, the same datasets used for " - "the provided model are used." + "List of dataset directories. " + "If none is provided (recommended), the datasets used for " + "the trained model are used." ), ) parser.add_argument( @@ -274,18 +275,19 @@ def evaluate_parse_args(args): nargs="+", default=[], help=( - "List of paths to annotation files. The full path or the filename can be provided. " + "List of paths to annotation files. " + "If none are provided (recommended), the annotations from the dataset of the trained model are used." + "The full path or the filename can be provided. " "If only filename is provided, it is assumed to be under dataset/annotations." - "If none is provided, the annotations from the dataset of the checkpoint are used." ), ) parser.add_argument( "--seed_n", type=int, - # default=42, help=( - "Seed for dataset splits. If none is provided, the seed from the dataset of " - "the checkpoint is used." # No default + "Seed for dataset splits. " + "If none is provided (recommended), the seed from the dataset of " + "the trained model is used." ), ) From 0ec4a74e6f7e94da5aba49bc71ac16aa92a957d7 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Wed, 26 Jun 2024 12:51:50 +0100 Subject: [PATCH 12/17] Add score threshold for visualisation as CLI argument --- crabs/detection_tracking/evaluate_model.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/crabs/detection_tracking/evaluate_model.py b/crabs/detection_tracking/evaluate_model.py index 564b2f56..e63f6c32 100644 --- a/crabs/detection_tracking/evaluate_model.py +++ b/crabs/detection_tracking/evaluate_model.py @@ -217,7 +217,7 @@ def evaluate_model(self) -> None: save_images_with_boxes( data_module.test_dataloader(), trained_model, - self.config["score_threshold"], + self.args.viz_score_threshold, ) # if this is a slurm job: add slurm logs as artifacts @@ -336,6 +336,12 @@ def evaluate_parse_args(args): action="store_true", help=("Save predicted frames with bounding boxes."), ) + parser.add_argument( + "--viz_score_threshold", + type=float, + default=0.5, + help=("Score threshold for visualisation. Default: 0.5"), + ) return parser.parse_args(args) From 06119d3e7c8ff297fbe3d2dd51c1841f9ab6eb34 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Wed, 26 Jun 2024 16:14:25 +0100 Subject: [PATCH 13/17] Small fix to config yaml --- crabs/detection_tracking/config/faster_rcnn.yaml | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/crabs/detection_tracking/config/faster_rcnn.yaml b/crabs/detection_tracking/config/faster_rcnn.yaml index 04ba789e..cbdfaad3 100644 --- a/crabs/detection_tracking/config/faster_rcnn.yaml +++ b/crabs/detection_tracking/config/faster_rcnn.yaml @@ -1,5 +1,3 @@ ---- -#---------- # Dataset #------------- train_fraction: 0.8 @@ -12,7 +10,7 @@ num_workers: 4 num_classes: 2 # ------------------------------- -# Training/validation parameters +# Training & validation parameters # ------------------------------- n_epochs: 250 learning_rate: 0.00005 @@ -30,6 +28,12 @@ checkpoint_saving: # if False, no checkpoints are added as artifacts. save_last: True +# ----------------------- +# Evaluation parameters +# ----------------------- +iou_threshold: 0.1 +batch_size_test: 4 + # ------------------- # Data augmentation # ------------------- @@ -43,12 +47,6 @@ gaussian_blur_params: - 0.1 - 5.0 -# ----------------------- -# Evaluation parameters -# ----------------------- -iou_threshold: 0.1 -batch_size_test: 4 - # ---------------------------- # Hyperparameter optimisation # ----------------------------- From ead4f72d2b1bc11a0bee97b501783c01e755d2bc Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Wed, 26 Jun 2024 16:24:53 +0100 Subject: [PATCH 14/17] Clean up --- crabs/detection_tracking/evaluate_model.py | 39 ++++++++++------------ 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/crabs/detection_tracking/evaluate_model.py b/crabs/detection_tracking/evaluate_model.py index e63f6c32..bc5967d4 100644 --- a/crabs/detection_tracking/evaluate_model.py +++ b/crabs/detection_tracking/evaluate_model.py @@ -23,7 +23,7 @@ class DetectorEvaluation: """ - A class for evaluating an object detector using trained model. + A class for evaluating an object detector. Parameters ---------- @@ -40,10 +40,9 @@ def __init__(self, args: argparse.Namespace) -> None: self.trained_model_path = args.trained_model_path self.config_file = args.config_file self.get_config_from_ckpt() - # self.load_config_yaml() # adds self.config from yaml ----> instead get from ckpt! - # dataset: retrieve from ckpt if possible - self.images_dirs = self.get_img_directories_from_ckpt() # if defined + # dataset: retrieve from ckpt if no CLI arguments are passed + self.images_dirs = self.get_img_directories_from_ckpt() self.annotation_files = self.get_annotation_files_from_ckpt() self.seed_n = self.get_cli_arg_from_ckpt("seed_n") @@ -58,12 +57,13 @@ def __init__(self, args: argparse.Namespace) -> None: self.fast_dev_run = args.fast_dev_run self.limit_test_batches = args.limit_test_batches + logging.info("Dataset") logging.info(f"Images directories: {self.images_dirs}") logging.info(f"Annotation files: {self.annotation_files}") logging.info(f"Seed: {self.seed_n}") def get_mlflow_parameters_from_ckpt(self): - """Get MLflow client from ckpt path and associated hparams""" + """Get MLflow client from ckpt path and associated params.""" import mlflow # roughly assert the format of the path @@ -79,21 +79,21 @@ def get_mlflow_parameters_from_ckpt(self): tracking_uri=self.ckpt_mlruns_path, ) - # get params of the run + # get parameters of the run run = self.mlrun_client.get_run(self.ckpt_runID) params = run.data.params return params def get_config_from_ckpt(self): - """Get config from checkpoint if not passed as CLI arg""" + """Get config from checkpoint if config is not passed as a CLI argument.""" - # If passed: used passed config + # If config in CLI arguments: used passed config if self.config_file: with open(self.config_file, "r") as f: config_dict = yaml.safe_load(f) - # If not passed: used config from ckpt + # If not: used config from ckpt else: params = self.get_mlflow_parameters_from_ckpt() # string-dict @@ -106,14 +106,15 @@ def get_config_from_ckpt(self): ) # format as a 2-levels nested dict - # forward slashes indicate a nested dict - for key in list(config_dict): # makes a copy of original keys! + # forward slashes in a key indicate a nested dict + for key in list(config_dict): # list makes a copy of original keys if "/" in key: key_parts = key.split("/") + assert len(key_parts) == 2 if key_parts[0] not in config_dict: config_dict[key_parts[0]] = { key_parts[1]: config_dict.pop(key) - } # config_dict[key]} # initialise + } else: config_dict[key_parts[0]].update( {key_parts[1]: config_dict.pop(key)} @@ -125,18 +126,18 @@ def get_config_from_ckpt(self): self.config = config_dict def get_cli_arg_from_ckpt(self, cli_arg_str): - """Get CLI argument from checkpoint if not defined""" + """Get CLI argument from checkpoint if not in self.args.""" if getattr(self.args, cli_arg_str): cli_arg = getattr(self.args, cli_arg_str) else: params = self.get_mlflow_parameters_from_ckpt() - cli_arg = ast.literal_eval(params[f"cli_args/{cli_arg_str}"]) return cli_arg def get_img_directories_from_ckpt(self) -> list[str]: - """Get image directories from checkpoint if not defined.""" + """Get image directories from checkpoint if not passed as CLI argument.""" + # Get dataset directories from ckpt if not defined dataset_dirs = self.get_cli_arg_from_ckpt("dataset_dirs") @@ -146,11 +147,7 @@ def get_img_directories_from_ckpt(self) -> list[str]: return images_dirs def get_annotation_files_from_ckpt(self) -> list[str]: - """Get annotation files from checkpoint if not defined. - - If annotation_files is not pass as CLI arg to evaluate: - retrieve annotation_files from ckpt path. - """ + """Get annotation files from checkpoint if not passed as CLI argument.""" # Get path to input annotation files from ckpt if not defined input_annotation_files = self.get_cli_arg_from_ckpt("annotation_files") @@ -174,7 +171,7 @@ def setup_trainer(self): # Setup logger mlf_logger = setup_mlflow_logger( - experiment_name=self.experiment_name, # "Sep2023_evaluation", + experiment_name=self.experiment_name, run_name=self.run_name, mlflow_folder=self.mlflow_folder, cli_args=self.args, From d9d4783c8cdda715a82b72f7203a26b42d1054f5 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Wed, 26 Jun 2024 17:14:16 +0100 Subject: [PATCH 15/17] Fix save frames and add output_dir --- crabs/detection_tracking/evaluate_model.py | 23 +++++++++++++++++----- crabs/detection_tracking/visualization.py | 16 ++++++++++++--- 2 files changed, 31 insertions(+), 8 deletions(-) diff --git a/crabs/detection_tracking/evaluate_model.py b/crabs/detection_tracking/evaluate_model.py index bc5967d4..a15e7f55 100644 --- a/crabs/detection_tracking/evaluate_model.py +++ b/crabs/detection_tracking/evaluate_model.py @@ -212,9 +212,10 @@ def evaluate_model(self) -> None: # Save images if required if self.args.save_frames: save_images_with_boxes( - data_module.test_dataloader(), - trained_model, - self.args.viz_score_threshold, + test_dataloader=data_module.test_dataloader(), + trained_model=trained_model, + output_dir=self.args.frames_output_dir, + score_threshold=self.args.frames_score_threshold, ) # if this is a slurm job: add slurm logs as artifacts @@ -334,10 +335,22 @@ def evaluate_parse_args(args): help=("Save predicted frames with bounding boxes."), ) parser.add_argument( - "--viz_score_threshold", + "--frames_score_threshold", type=float, default=0.5, - help=("Score threshold for visualisation. Default: 0.5"), + help=( + "Score threshold for visualising detections on output frames. Default: 0.5" + ), + ) + parser.add_argument( + "--frames_output_dir", + type=str, + default="", + help=( + "Output directory for the exported frames. " + "By default, the frames are saved in a `results_ folder " + "under the current working directory." + ), ) return parser.parse_args(args) diff --git a/crabs/detection_tracking/visualization.py b/crabs/detection_tracking/visualization.py index 199d5051..e6fc98bc 100644 --- a/crabs/detection_tracking/visualization.py +++ b/crabs/detection_tracking/visualization.py @@ -1,4 +1,5 @@ import os +from datetime import datetime from typing import Any, Optional import cv2 @@ -152,6 +153,7 @@ def draw_detection( def save_images_with_boxes( test_dataloader: torch.utils.data.DataLoader, trained_model: torch.nn.Module, + output_dir: str, score_threshold: float, ) -> None: """ @@ -175,20 +177,28 @@ def save_images_with_boxes( if torch.cuda.is_available() else torch.device("cpu") ) + + trained_model.to(device) trained_model.eval() - directory = "results" - os.makedirs(directory, exist_ok=True) + + if not output_dir: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_dir = f"results_{timestamp}" + os.makedirs(output_dir, exist_ok=True) + with torch.no_grad(): imgs_id = 0 for imgs, annotations in test_dataloader: imgs_id += 1 imgs = list(img.to(device) for img in imgs) + detections = trained_model(imgs) image_with_boxes = draw_detection( imgs, annotations, detections, score_threshold ) - cv2.imwrite(f"{directory}/imgs{imgs_id}.jpg", image_with_boxes) + + cv2.imwrite(f"{output_dir}/imgs{imgs_id}.jpg", image_with_boxes) def plot_sample(imgs: list, row_title: Optional[str] = None, **imshow_kwargs): From 9ae87302670452bbd2a45ced226926f5144e3021 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Wed, 26 Jun 2024 18:05:55 +0100 Subject: [PATCH 16/17] Fix tests --- crabs/detection_tracking/evaluate_model.py | 2 +- tests/test_unit/test_visualization.py | 19 +++++++++++++++---- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/crabs/detection_tracking/evaluate_model.py b/crabs/detection_tracking/evaluate_model.py index a15e7f55..aec64dda 100644 --- a/crabs/detection_tracking/evaluate_model.py +++ b/crabs/detection_tracking/evaluate_model.py @@ -246,7 +246,7 @@ def evaluate_parse_args(args): parser.add_argument( "--trained_model_path", type=str, - required=True, # --------- can we pass experiment and run-id? + required=True, help="Location of trained model (a .ckpt file)", ) parser.add_argument( diff --git a/tests/test_unit/test_visualization.py b/tests/test_unit/test_visualization.py index 67edb852..271e1dbc 100644 --- a/tests/test_unit/test_visualization.py +++ b/tests/test_unit/test_visualization.py @@ -1,3 +1,4 @@ +import re from unittest.mock import MagicMock, patch import numpy as np @@ -149,6 +150,10 @@ def test_draw_detection(annotations, detections): assert image_with_boxes is not None +@pytest.mark.parametrize( + "output_dir_name, expected_dir_name", + [("output", r"^output$"), ("", r"^results_\d{8}_\d{6}$")], +) @pytest.mark.parametrize( "detections", [ @@ -168,9 +173,11 @@ def test_draw_detection(annotations, detections): ), ], ) -@patch("cv2.imwrite") -@patch("os.makedirs") -def test_save_images_with_boxes(mock_makedirs, mock_imwrite, detections): +@patch("crabs.detection_tracking.visualization.cv2.imwrite") +@patch("crabs.detection_tracking.visualization.os.makedirs") +def test_save_images_with_boxes( + mock_makedirs, mock_imwrite, detections, output_dir_name, expected_dir_name +): trained_model = MagicMock() test_dataloader = MagicMock() trained_model.return_value = detections @@ -178,8 +185,12 @@ def test_save_images_with_boxes(mock_makedirs, mock_imwrite, detections): save_images_with_boxes( test_dataloader, trained_model, + output_dir=output_dir_name, score_threshold=0.5, ) - assert mock_makedirs.called_once_with("results", exist_ok=True) + # extract and check first positional argument to (mocked) os.makedirs + input_path_makedirs = mock_makedirs.call_args[0][0] + assert re.match(expected_dir_name, input_path_makedirs) + assert mock_imwrite.call_count == len(test_dataloader) From e87a57c3ea88d76f55f8758aadceb23a45c4e6a0 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Wed, 26 Jun 2024 18:28:23 +0100 Subject: [PATCH 17/17] Add skeleton for tests --- tests/test_unit/test_evaluate_model.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 tests/test_unit/test_evaluate_model.py diff --git a/tests/test_unit/test_evaluate_model.py b/tests/test_unit/test_evaluate_model.py new file mode 100644 index 00000000..b43a3105 --- /dev/null +++ b/tests/test_unit/test_evaluate_model.py @@ -0,0 +1,19 @@ +def test_get_mlflow_parameters_from_ckpt(): + pass + + +def test_get_config_from_ckpt(): + pass + + +def test_get_cli_arg_from_ckpt(): + pass + + +def test_get_img_directories_from_ckpt(): + # If specified in CLI get those, otherwise from ckpt + pass + + +def test_get_annotation_files_from_ckpt(): + pass