From 7105c4c3613d3f689743df77daeeecffacc35fd8 Mon Sep 17 00:00:00 2001 From: sfmig <33267254+sfmig@users.noreply.github.com> Date: Tue, 29 Oct 2024 16:43:18 +0000 Subject: [PATCH] Maintenance of configs and update README (#229) * Edit pre-commit config to fix missing `wheel` dependency * Check if problem is macos15 * Update pyproject.toml to match movement * Update precommit to match movement * Add precommit CI * Run CI on intel macOS and macos-15 * Make new precommits happy * Make new precommits happy * Some more pre-commit changes * Make ruff precommit happy with tests - pending mypy * Make mypy pass * Remove sleap comment * Update readme * Fix test with typer and ellipsis in argument * Remove macOS-15 from CI * Fixed check-manifest issue * Update evaluate command description * Update readme and cli help * Change cli of detect+track to better match the other entry points. Simplify structure of outputs. * Update readme of detect+track to reflect current status * Fix test on track video CLI --- .github/workflows/test_and_deploy.yml | 6 +- .pre-commit-config.yaml | 101 ++++++---- README.md | 158 +++++++++++---- conftest.py | 2 + .../additional_channels_extraction.py | 40 ++-- crabs/bboxes_labelling/annotations_utils.py | 39 ++-- crabs/bboxes_labelling/clip_video.py | 51 +++-- .../combine_and_format_annotations.py | 7 +- .../extract_frames_to_label_w_sleap.py | 64 +++--- crabs/detector/datamodules.py | 32 +-- crabs/detector/datasets.py | 24 ++- crabs/detector/evaluate_model.py | 52 ++--- crabs/detector/models.py | 77 ++++---- crabs/detector/train_model.py | 79 +++++--- crabs/detector/utils/detection.py | 52 ++--- crabs/detector/utils/evaluate.py | 44 +++-- crabs/detector/utils/hpo.py | 7 +- crabs/detector/utils/train.py | 2 +- crabs/detector/utils/visualization.py | 37 ++-- .../extract_pairs_of_frames.py | 43 ++-- crabs/tracker/evaluate_tracker.py | 184 ++++++++++-------- crabs/tracker/sort.py | 66 ++++--- crabs/tracker/track_video.py | 129 +++++++----- crabs/tracker/utils/io.py | 57 +++--- crabs/tracker/utils/sort.py | 92 +++++---- crabs/tracker/utils/tracking.py | 53 +++-- guides/ManualLabellingSteps.md | 0 notebooks/notebook_data_augm.py | 2 +- ...ook_detect_chessboard_in_sampled_frames.py | 4 +- .../notebook_detect_chessboard_in_video.py | 4 +- notebooks/notebook_overlay_video.py | 1 - pyproject.toml | 66 ++++--- scripts/output_video.py | 7 +- tests/data/COCO_VIA_JSONS/VIA_JSON_1.json | 0 tests/data/COCO_VIA_JSONS/VIA_JSON_2.json | 0 tests/fixtures/frame_extraction.py | 16 +- tests/test_integration/test_annotations.py | 10 +- .../test_integration/test_frame_extraction.py | 17 +- tests/test_unit/test_datamodules.py | 34 ++-- tests/test_unit/test_datasets.py | 3 +- tests/test_unit/test_evaluate_tracker.py | 37 +++- tests/test_unit/test_track_video.py | 4 +- tests/test_unit/test_train_model.py | 14 +- 43 files changed, 1001 insertions(+), 716 deletions(-) mode change 100755 => 100644 crabs/bboxes_labelling/extract_frames_to_label_w_sleap.py mode change 100755 => 100644 guides/ManualLabellingSteps.md mode change 100755 => 100644 tests/data/COCO_VIA_JSONS/VIA_JSON_1.json mode change 100755 => 100644 tests/data/COCO_VIA_JSONS/VIA_JSON_2.json diff --git a/.github/workflows/test_and_deploy.yml b/.github/workflows/test_and_deploy.yml index c9549874..6c8d3f01 100644 --- a/.github/workflows/test_and_deploy.yml +++ b/.github/workflows/test_and_deploy.yml @@ -30,9 +30,11 @@ jobs: # Run all supported Python versions on linux os: [ubuntu-latest] python-version: ["3.9", "3.10"] - # Include one macos run 
+ # Include 1 Intel macos (13) and 1 M1 macos (latest) include: - - os: macos-latest + - os: macos-13 # intel macOS + python-version: "3.10" + - os: macos-latest # M1 macOS python-version: "3.10" steps: - uses: neuroinformatics-unit/actions/test@v2 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e16d66da..ed19f3ed 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,37 +1,66 @@ +# exclude: 'conf.py' --- relevant for docs +# Configuring https://pre-commit.ci/ +ci: + autoupdate_schedule: monthly repos: - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.0-alpha.9-for-vscode - hooks: - - id: prettier - args: [--ignore-path=guides/CorrectingTrackLabellingSteps.md] - - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 - hooks: - - id: check-docstring-first - # - id: check-executables-have-shebangs TODO: fix later - - id: check-merge-conflict - - id: check-toml - - id: end-of-file-fixer - - id: mixed-line-ending - args: [--fix=lf] - - id: trailing-whitespace - - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.280 - hooks: - - id: ruff - - repo: https://github.com/psf/black - rev: 23.7.0 - hooks: - - id: black - - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.3.0 - hooks: - - id: mypy - additional_dependencies: - - types-setuptools - - repo: https://github.com/mgedmin/check-manifest - rev: "0.49" - hooks: - - id: check-manifest - args: [--no-build-isolation] - additional_dependencies: [setuptools-scm] + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: check-added-large-files + - id: check-docstring-first + - id: check-executables-have-shebangs + - id: check-case-conflict + - id: check-merge-conflict + - id: check-symlinks + - id: check-yaml + - id: check-toml + - id: debug-statements + - id: end-of-file-fixer + - id: mixed-line-ending + args: [--fix=lf] + - id: name-tests-test + args: ["--pytest-test-first"] + exclude: ^tests/fixtures + - id: requirements-txt-fixer + - id: trailing-whitespace + # - repo: https://github.com/pre-commit/pygrep-hooks + # rev: v1.10.0 + # hooks: + # - id: rst-backticks + # - id: rst-directive-colons + # - id: rst-inline-touching-normal + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.9 + hooks: + - id: ruff + - id: ruff-format + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.11.2 + hooks: + - id: mypy + additional_dependencies: + - attrs + - types-setuptools + - pandas-stubs + - types-attrs + - types-PyYAML + - types-requests + - repo: https://github.com/mgedmin/check-manifest + rev: "0.49" + hooks: + - id: check-manifest + args: [--no-build-isolation] + additional_dependencies: [setuptools-scm] + # - repo: https://github.com/codespell-project/codespell + # # Configuration for codespell is in pyproject.toml + # rev: v2.3.0 + # hooks: + # - id: codespell + # additional_dependencies: + # # tomli dependency can be removed when we drop support for Python 3.10 + # - tomli +exclude: | + (?x)( + ^notebooks/| + ^tests/data/ + ) diff --git a/README.md b/README.md index 65e96e0a..7cb75563 100644 --- a/README.md +++ b/README.md @@ -12,76 +12,162 @@ A toolkit for detecting and tracking crabs in the field. -requires Python 3.9 or 3.10 or 3.11. +`crabs` uses neural networks to detect and track multiple crabs in the field. The detection model is based on the [Faster R-CNN](https://arxiv.org/abs/1506.01497) architecture. 
The tracking model is based on the [SORT](https://github.com/abewley/sort) tracking algorithm. + +The package supports Python 3.9 or 3.10, and is tested on Linux and MacOS. + +We highly recommend running `crabs` on a machine with a dedicated graphics device, such as an NVIDIA GPU or an Apple M1+ chip. + ### Installation - +#### Users +To install the `crabs` package, first clone this git repository. +```bash +git clone https://github.com/SainsburyWellcomeCentre/crabs-exploration.git +``` -### Data Structure +Then, navigate to the root directory of the repository and install the `crabs` package in a conda environment: -We assume the following structure for the dataset directory: +```bash +conda create -n crabs-env python=3.10 -y +conda activate crabs-env +pip install . +``` +#### Developers +For development, we recommend installing the package in editable mode and with additional `dev` dependencies: + +```bash +pip install -e .[dev] # or ".[dev]" if you are using zsh ``` -|_ Dataset - |_ frames - |_ annotations - |_ VIA_JSON_combined_coco_gen.json + +### CrabsField - Sept2023 dataset + +We trained the detector model on our [CrabsField - Sept2023](https://gin.g-node.org/SainsburyWellcomeCentre/CrabsField) dataset. The dataset consists of 53041 annotations (bounding boxes) over 544 frames extracted from 28 videos of crabs in the field. + +The dataset is currently private. If you have access to the [GIN](https://gin.g-node.org/) repository, you can download the dataset using the GIN CLI tool. To set up the GIN CLI tool: +1. Create [a GIN account](https://gin.g-node.org/user/sign_up). +2. [Download GIN CLI](https://gin.g-node.org/G-Node/Info/wiki/GIN+CLI+Setup#setup-gin-client) and set it up by running: + ``` + $ gin login + ``` + You will be prompted for your GIN username and password. +3. Confirm that everything is working properly by typing: + ``` + $ gin --version + ``` + +Then to download the dataset, run the following command from the directory you want the data to be in: ``` +gin get SainsburyWellcomeCentre/CrabsField +``` +This command will clone the data repository to the current working directory, and download the large files in the dataset as lightweight placeholder files. To download the content of these placeholder files, run: +``` +gin download --content +``` +Because the large files in the dataset are **locked**, this command will download the content to the git annex subdirectory, and turn the placeholder files in the working directory into symlinks that point to that content. For more information on how to work with a GIN repository, see the corresponding [NIU HowTo guide](https://howto.neuroinformatics.dev/open_science/GIN-repositories.html). -The default name assumed for the annotations file is `VIA_JSON_combined_coco_gen.json`. This is used if no input files are passed. Other filenames (or fullpaths) can be passed with the `--annotation_files` command-line argument. 
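+If you want to sanity-check the download, the annotations are distributed as a COCO-style JSON file and can be inspected with a few lines of Python. The snippet below is a minimal sketch: the annotation file path is a placeholder and may differ from the actual layout of the downloaded dataset.
+
+```python
+import json
+
+# hypothetical path: adjust to wherever the GIN dataset was downloaded
+annotations_path = "CrabsField/annotations/VIA_JSON_combined_coco_gen.json"
+
+with open(annotations_path) as f:
+    coco = json.load(f)
+
+# the dataset should contain 544 annotated frames and 53041 bounding boxes
+print(f"{len(coco['images'])} frames, {len(coco['annotations'])} annotations")
+```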
+## Basic commands

-### Running Locally
+### Train a detector

-For training
+To train a detector on an existing dataset, run the following command:

-```bash
-python train-detector --dataset_dirs {parent_directory_of_frames_and_annotation} {optional_second_parent_directory_of_frames_and_annotation} --annotation_files {path_to_annotation_file.json} {path_to_optional_second_annotation_file.json}
+```
+train-detector --dataset_dirs <list-of-dataset-directories>
 ```

-Example (using default annotation file and one dataset):
+This command assumes each dataset directory has the following structure:

-```bash
-python train-detector --dataset_dirs /home/data/dataset1
+```
+dataset
+|_ frames
+|_ annotations
+    |_ VIA_JSON_combined_coco_gen.json
 ```

-Example (passing the full path of the annotation file):
+The default name assumed for the annotations file is `VIA_JSON_combined_coco_gen.json`. Other filenames (or full paths to annotation files) can be passed with the `--annotation_files` command-line argument.

-```bash
-python train-detector --dataset_dirs /home/data/dataset1 --annotation_files /home/user/annotations/annotations42.json
+To see the full list of possible arguments to the `train-detector` command, run:
+```
+train-detector --help
 ```

-Example (passing several datasets with annotation filenames different from the default):
+### Monitor a training job
+
+We use [MLflow](https://mlflow.org) to monitor the training of the detector and log the hyperparameters used.
+
+To run MLflow, execute the following command from your `crabs-env` conda environment:

-```bash
-python train-detector --dataset_dirs /home/data/dataset1 /home/data/dataset2 --annotation_files annotation_dataset1.json annotation_dataset2.json
+```
+mlflow ui --backend-store-uri file:///<path-to-ml-runs>
 ```

-For evaluation
+Replace `<path-to-ml-runs>` with the path to the directory where the MLflow output is stored. By default, the output is placed in an `ml-runs` folder under the directory from which `train-detector` is launched.

-```bash
-python evaluate-detector --model_dir {directory_to_saved_model} --images_dirs {parent_directory_of_frames_and_annotation} {optional_second_parent_directory_of_frames_and_annotation} --annotation_files {annotation_file.json} {optional_second_annotation_file.json}
+In the MLflow browser-based user interface, you can find the path to the checkpoints directory for any run under the `path_to_checkpoints` parameter. This will be useful to evaluate the trained model. The model from the end of the training job is saved as `last.ckpt` in the `path_to_checkpoints` directory.
+
+### Evaluate a detector
+
+To evaluate a trained detector on the test split of the dataset, run the following command:
+
+```
+evaluate-detector --trained_model_path <path-to-ckpt-file>
 ```

-Example:
+This command assumes the trained detector model (a `.ckpt` checkpoint file) is saved in an MLflow database structure. That is, the checkpoint is assumed to be under a `checkpoints` directory, which in turn should be under a `<experiment-ID>/<run-ID>` directory. This will be the case if the model has been trained using the `train-detector` command.

-```bash
-python evaluate-detector --model_dir model/model_00.pt --main_dir /home/data/dataset1/frames /home/data/dataset2/frames --annotation_files /home/data/dataset1/annotations/annotation_dataset1.json /home/data/dataset2/annotations/annotation_dataset2.json
+The `evaluate-detector` command will print to screen the average precision and average recall of the detector on the test set. It will also log those metrics to the MLflow database, along with the hyperparameters of the evaluation job.
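+For example, for a detector trained with the default MLflow output location, the checkpoint path passed to `evaluate-detector` will typically look like the following (the experiment and run IDs below are placeholders):
+
+```
+evaluate-detector --trained_model_path ml-runs/<experiment-ID>/<run-ID>/checkpoints/last.ckpt
+```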
+To visualise the MLflow summary of the evaluation job, run:
+```
+mlflow ui --backend-store-uri file:///<path-to-ml-runs>
+```
+where `<path-to-ml-runs>` is the path to the directory where the MLflow output is stored.

-For running inference
+To see the full list of possible arguments to the `evaluate-detector` command, run it with the `--help` flag.
+
+### Run detector+tracking on a video
+
+To track crabs in a new video, using a trained detector and a tracker, run the following command:

-```bash
-python crabs/detection_tracking/inference_model.py --model_dir {oath_to_trained_model} --vid_path {path_to_input_video}
 ```
+detect-and-track-video --trained_model_path <path-to-ckpt-file> --video_path <path-to-input-video>
+```
+
+This will produce a `tracking_output_<timestamp>` directory with the output from tracking.
+
+The tracking output consists of:
+- a .csv file named `<video-name>_tracks.csv`, with the tracked bounding boxes data;
+- if the flag `--save_video` is added to the command: a video file named `<video-name>_tracks.mp4`, with the tracked bounding boxes;
+- if the flag `--save_frames` is added to the command: a subdirectory named `<video-name>_frames` is created, and the video frames are saved in it.
+
+The .csv file with tracked bounding boxes can be imported into [movement](https://github.com/neuroinformatics-unit/movement) for further analysis. See the [movement documentation](https://movement.neuroinformatics.dev/getting_started/input_output.html#loading-bounding-boxes-tracks) for more details.
+
+Note that when using `--save_frames`, the frames of the video are saved as-is, without added bounding boxes. The aim is to support the visualisation and correction of the predictions using the [VGG Image Annotator (VIA)](https://www.robots.ox.ac.uk/~vgg/software/via/) tool. To do so, follow the instructions of the [VIA Face track annotation tutorial](https://www.robots.ox.ac.uk/~vgg/software/via/docs/face_track_annotation.html).
+
+If a file with ground-truth annotations is passed to the command (with the `--annotations_file` flag), the MOTA metric for evaluating tracking is computed and printed to screen.

-### MLFLow
+

-We are using [MLflow](https://mlflow.org) to log our training loss and the hyperparameters used.
-To run MLflow, execute the following command in your terminal:
+To see the full list of possible arguments to the `detect-and-track-video` command, run it with the `--help` flag.
+
diff --git a/conftest.py b/conftest.py
index be156ca5..f762c948 100644
--- a/conftest.py
+++ b/conftest.py
@@ -1,3 +1,5 @@
+"""Pytest configuration file."""
+
 pytest_plugins = [
     "tests.fixtures.frame_extraction",
 ]
diff --git a/crabs/bboxes_labelling/additional_channels_extraction.py b/crabs/bboxes_labelling/additional_channels_extraction.py
index 58487dcb..bf43e3da 100644
--- a/crabs/bboxes_labelling/additional_channels_extraction.py
+++ b/crabs/bboxes_labelling/additional_channels_extraction.py
@@ -1,3 +1,5 @@
+"""Script to compute additional channels."""
+
 import argparse
 import os
 from pathlib import Path
@@ -9,16 +11,15 @@


 def apply_grayscale_and_blur(
-    frame: np.array,
+    frame: np.ndarray,
     kernel_size: list,
     sigmax: int,
-) -> np.array:
-    """
-    Convert the frame to grayscale and apply Gaussian blurring.
+) -> tuple:
+    """Convert the frame to grayscale and apply Gaussian blurring.
Parameters ---------- - frame : np.array + frame : np.ndarray frame array read from the video capture kernel_size : list kernel size for GaussianBlur @@ -27,10 +28,11 @@ def apply_grayscale_and_blur( Returns ------- - gray_frame : np.array + gray_frame grayscaled input frame - blurred_frame : np.array + blurred_frame Gaussian-blurred grayscaled input frame + """ # convert the frame to grayscale frame gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) @@ -41,9 +43,7 @@ def apply_grayscale_and_blur( def compute_mean_and_max_abs_blurred_frame(cap, kernel_size, sigmax): - """ - Compute the mean blurred frame and the maximum absolute-value - blurred frame for a video capture cap. + """Compute mean blurred frame and maximum absolute-value blurred frame. Parameters ---------- @@ -60,6 +60,7 @@ def compute_mean_and_max_abs_blurred_frame(cap, kernel_size, sigmax): mean of all blurred frames in the video max_abs_blurred_frame : np.array pixelwise max absolute value across all blurred frames in the video + """ frame_counter = 0 @@ -105,8 +106,9 @@ def compute_background_subtracted_frame( mean_blurred_frame, max_abs_blurred_frame, ): - """ - Compute the background subtracted frame for the + """Compute background subtracted frame. + + Compute background subtracted frame for the input blurred frame, given the mean and max absolute frames of its corresponding video. @@ -124,6 +126,7 @@ def compute_background_subtracted_frame( background_subtracted_frame : np.array normalised difference between the blurred frame f and the mean blurred frame + """ return ( ((blurred_frame - mean_blurred_frame) / max_abs_blurred_frame) + 1 @@ -136,8 +139,7 @@ def compute_motion_frame( mean_blurred_frame, max_abs_blurred_frame, ): - """ - _summary_. + """_summary_. Parameters ---------- @@ -157,6 +159,7 @@ def compute_motion_frame( motion_frame : np.array absolute difference between the background subtracted frame f and the background subtracted frame f+delta + """ # compute the blurred frame frame_idx+delta _, blurred_frame_delta = apply_grayscale_and_blur( @@ -178,8 +181,9 @@ def compute_motion_frame( def compute_stacked_inputs(args: argparse.Namespace) -> None: - """ - Compute the stacked inputs consist of + """Compute stacked inputs. + + Stack consist of grayscale, background subtracted and motion signal. Parameters @@ -279,8 +283,7 @@ def compute_stacked_inputs(args: argparse.Namespace) -> None: def argument_parser() -> argparse.Namespace: - """ - Parse command-line arguments for the script. + """Parse command-line arguments for the script. Returns ------- @@ -288,6 +291,7 @@ def argument_parser() -> argparse.Namespace: An object containing the parsed command-line arguments. The attributes of this object correspond to the defined command-line arguments in the script. + """ parser = argparse.ArgumentParser() parser.add_argument( diff --git a/crabs/bboxes_labelling/annotations_utils.py b/crabs/bboxes_labelling/annotations_utils.py index 761a3690..4d07594b 100644 --- a/crabs/bboxes_labelling/annotations_utils.py +++ b/crabs/bboxes_labelling/annotations_utils.py @@ -1,3 +1,5 @@ +"""Utility functions to work with annotations in JSON format.""" + import json import os import re @@ -8,8 +10,7 @@ def read_json_file( file_path: str, ) -> dict: - """ - Read JSON file as dict. + """Read JSON file as dict. 
Parameters ---------- @@ -20,6 +21,7 @@ def read_json_file( ------- Optional[dict] Dictionary with the JSON data + """ try: with open(file_path) as file: @@ -41,8 +43,7 @@ def combine_multiple_via_jsons( via_default_dir: Optional[str] = None, via_project_name: Optional[str] = None, ) -> str: - """ - Combine all the input VIA JSON files into one. + r"""Combine all the input VIA JSON files into one. A VIA JSON file is a json file specific to the VIA tool that defines the annotations and also the visualisation settings @@ -78,6 +79,7 @@ def combine_multiple_via_jsons( ------- json_out_fullpath: str full path to the combined VIA JSON file + """ # Initialise data structures for the combined VIA JSON file via_data_combined = {} @@ -128,15 +130,15 @@ def combine_multiple_via_jsons( raise ValueError(msg) # assign directory path to the VIA combined dictionary - via_data_combined["_via_settings"]["core"][ - "default_filepath" - ] = via_default_dir + via_data_combined["_via_settings"]["core"]["default_filepath"] = ( + via_default_dir + ) # If required: change _via_settings > project > name if via_project_name: - via_data_combined["_via_settings"]["project"][ - "name" - ] = via_project_name + via_data_combined["_via_settings"]["project"]["name"] = ( + via_project_name + ) # Save the VIA combined data as a new JSON file # if no output directory is passed, use the parent directory @@ -160,8 +162,7 @@ def convert_via_json_to_coco( coco_out_filename: Optional[str] = None, coco_out_dir: Optional[str] = None, ) -> str: - """ - Convert annotation data for one category from VIA-JSON format to COCO. + """Convert annotation data for one category from VIA-JSON format to COCO. This function takes annotation data in a VIA JSON format and converts it into COCO format, which is widely used for object detection datasets. @@ -175,12 +176,9 @@ def convert_via_json_to_coco( ---------- json_file_path : str Path to the VIA-JSON file containing the annotation data. - coco_category_ID : int, optional - category ID of all the annotations - coco_category_name : str, optional - category name of all the annotations - coco_supercategory_name : str, optional - supercategory for all the annotations + coco_category : dict, optional + Dictionary with the category ID, name and supercategory for all the + annotations. coco_out_filename : str, optional Name of the COCO output file. If None (default), the input VIA JSON filename is used with the suffix '_coco_gen' @@ -193,6 +191,7 @@ def convert_via_json_to_coco( ------- str path to the COCO json file. 
+ """ # Load the annotation data in VIA JSON format with open(json_file_path) as json_file: @@ -213,8 +212,8 @@ def convert_via_json_to_coco( for image_info in annotation_data["_via_img_metadata"].values(): image_data = { "id": image_id, - "width": 0, # TODO: find how we can get this (not available in JSON) - "height": 0, # TODO: find how we can get this (not available in JSON) + "width": 0, # TODO: find how we can get this data from json + "height": 0, # TODO: find how we can get this data from json "file_name": image_info["filename"], } coco_data["images"].append(image_data) diff --git a/crabs/bboxes_labelling/clip_video.py b/crabs/bboxes_labelling/clip_video.py index aa08a0ac..68d2e8f3 100644 --- a/crabs/bboxes_labelling/clip_video.py +++ b/crabs/bboxes_labelling/clip_video.py @@ -1,3 +1,5 @@ +"""Script to clip a video file.""" + import argparse from datetime import datetime from pathlib import Path @@ -8,16 +10,22 @@ def real_time_to_frame_number( real_time: datetime, video_fps: float, start_real_time: datetime ) -> int: - """ - Convert a real-time timestamp to the corresponding frame number in a video. + """Convert a real-time timestamp to the corresponding frame number. + + Parameters + ---------- + real_time : datetime + The real-time timestamp. + video_fps : float + Frames per second of the video. + start_real_time : datetime + The starting real-time timestamp of the video. - Parameters: - real_time (datetime): The real-time timestamp. - video_fps (float): Frames per second of the video. - start_real_time (datetime): The starting real-time timestamp of the video. + Returns + ------- + int + The corresponding frame number in the video. - Returns: - int: The corresponding frame number in the video. """ time_difference = real_time - start_real_time total_seconds = time_difference.total_seconds() @@ -27,18 +35,23 @@ def real_time_to_frame_number( def create_clip( input_file: str, start_frame: int, end_frame: int, output_file: str ) -> None: - """ - Create a video clip from the input video file, starting from a specific frame - and ending at another frame. + """Create a video clip from the input video file. + + Parameters + ---------- + input_file : str + Path to the input video file. + start_frame : int + Starting frame number. + end_frame : int + Ending frame number. + output_file : str + Path to the output video file to be created. - Parameters: - input_file (str): Path to the input video file. - start_frame (int): Starting frame number. - end_frame (int): Ending frame number. - output_file (str): Path to the output video file to be created. + Returns + ------- + None - Returns: - None """ cap = cv2.VideoCapture(input_file) video_fps = cap.get(cv2.CAP_PROP_FPS) @@ -75,8 +88,8 @@ def argument_parser() -> argparse.Namespace: An object containing the parsed command-line arguments. The attributes of this object correspond to the defined command-line arguments in the script. 
- """ + """ parser = argparse.ArgumentParser() parser.add_argument( "--video_path", diff --git a/crabs/bboxes_labelling/combine_and_format_annotations.py b/crabs/bboxes_labelling/combine_and_format_annotations.py index 4740e66d..fc5001d6 100644 --- a/crabs/bboxes_labelling/combine_and_format_annotations.py +++ b/crabs/bboxes_labelling/combine_and_format_annotations.py @@ -1,3 +1,5 @@ +"""Script to combine and format annotations.""" + from pathlib import Path from typing import Optional @@ -19,7 +21,7 @@ def combine_VIA_and_convert_to_COCO( via_default_dir: Optional[str] = None, via_project_name: Optional[str] = None, ) -> str: - """Combine a list of VIA JSON files into one and convert to COCO format + r"""Combine a list of VIA JSON files into one and convert to COCO format. Parameters ---------- @@ -40,8 +42,8 @@ def combine_VIA_and_convert_to_COCO( ------- str path to the COCO json file. By default, the file - """ + """ # Get list of all JSON files in directory all_files = Path(parent_dir_via_jsons).glob("*") list_input_json_files = [ @@ -61,6 +63,7 @@ def combine_VIA_and_convert_to_COCO( def app_wrapper(): + """Wrap function for the Typer app.""" app() diff --git a/crabs/bboxes_labelling/extract_frames_to_label_w_sleap.py b/crabs/bboxes_labelling/extract_frames_to_label_w_sleap.py old mode 100755 new mode 100644 index 58b5dd9d..9a71b737 --- a/crabs/bboxes_labelling/extract_frames_to_label_w_sleap.py +++ b/crabs/bboxes_labelling/extract_frames_to_label_w_sleap.py @@ -1,14 +1,4 @@ -r""" -A script to extract frames for labelling using SLEAP's algorith,. - -Example usage: - python bboxes\ labelling/extract_frames_to_label_w_sleap.py - 'crab_sample_data/sample_clips/' - --initial_samples 5 - --n_components 2 - --n_clusters 2 - --per_cluster 1 - --compute_features_per_video +"""A script to extract frames for labelling using SLEAP's algorithm. TODO: can I make it deterministic? TODO: check https://github.com/talmolab/sleap-io/tree/main/sleap_io @@ -16,7 +6,6 @@ https://www.geeksforgeeks.org/python-copy-directory-structure-without-files/ """ - import copy import json import logging @@ -38,12 +27,11 @@ app = typer.Typer(rich_markup_mode="rich") -def get_list_of_sleap_videos( +def get_list_of_sleap_videos( # noqa: C901 list_video_locations, - list_video_extensions_in=["mp4"], + video_extensions_in=("mp4"), ): - """ - Generate list of SLEAP videos. + """Generate list of SLEAP videos. The locations in which we look for videos can be expressed as paths to files or @@ -55,7 +43,7 @@ def get_list_of_sleap_videos( list of video locations. These may be paths to video files or paths to their parent directories (only one level deep is searched). - list_video_extensions_in : list[str] + video_extensions_in : tuple[str] list of video extensions to look for in the directories. By default, mp4 videos. 
@@ -63,8 +51,10 @@ def get_list_of_sleap_videos( ------- list_sleap_videos : list[sleap.io.video.Video] list of SLEAP videos + """ # Make list of extensions case insensitive + list_video_extensions_in = list(video_extensions_in) list_video_extensions = copy.deepcopy(list_video_extensions_in) for ext in list_video_extensions_in: if ext.isupper(): @@ -89,8 +79,7 @@ def get_list_of_sleap_videos( # If the path is a file with the relevant extension: # append path directly to list elif location_path.is_file() and ( - location_path.suffix[1:] - in list_video_extensions + location_path.suffix[1:] in list_video_extensions # suffix includes dot ): list_video_paths.append(location_path) @@ -123,9 +112,7 @@ def get_list_of_sleap_videos( def get_map_videos_to_extracted_frames(list_sleap_videos, suggestions): - """ - Compute dictionary that maps videos to - their frame indices selected for labelling. + """Compute dictionary mapping videos to frame indices for labelling. Parameters ---------- @@ -159,7 +146,7 @@ def get_map_videos_to_extracted_frames(list_sleap_videos, suggestions): def compute_suggested_sleap_frames( list_video_locations, - video_extensions=["mp4"], + video_extensions=("mp4"), initial_samples=200, sample_method="stride", scale=1.0, @@ -169,9 +156,7 @@ def compute_suggested_sleap_frames( per_cluster=5, compute_features_per_video=True, ): - """ - Compute suggested frames for labelling using SLEAP's - FeatureSuggestionPipeline. + """Compute frames for labelling using SLEAP's FeatureSuggestionPipeline. See https://sleap.ai/guides/gui.html#labeling-suggestions @@ -180,9 +165,9 @@ def compute_suggested_sleap_frames( list_video_locations : list[str] list of video locations. These may be paths to video files or paths to their parent directories (only one level deep is searched). - video_extensions : list[str] - list of video extensions to look for in the directories. - Default: ["mp4"] + video_extensions : tuple[str] + tuple of video extensions to look for in the directories. + Default: ("mp4") initial_samples : int initial number of frames to extract per video Default: 200 @@ -217,6 +202,7 @@ def compute_suggested_sleap_frames( dictionary that maps each video path to a list of frames indices extracted for labelling. The frame indices are sorted in ascending order. + """ # Transform list of input videos to list of SLEAP Video instances list_sleap_videos = get_list_of_sleap_videos( @@ -265,9 +251,7 @@ def extract_frames_to_label_from_video( output_subdir_path, flag_parent_dir_subdir_in_output=False, ): - """ - Extract suggested frames for labelling from - corresponding videos using OpenCV. + """Extract frames for labelling from corresponding videos using OpenCV. The png files for each frame are named with the following format: @@ -291,6 +275,7 @@ def extract_frames_to_label_from_video( ------ KeyError If a frame from a video is not correctly read by openCV + """ for vid_str in map_videos_to_extracted_frames: # Initialise video capture @@ -354,7 +339,7 @@ def compute_and_extract_frames_to_label( list_video_locations: list[str], output_path: str = ".", output_subdir: Optional[str] = None, - video_extensions: list[str] = ["mp4"], + video_extensions: tuple[str] = ("mp4",), initial_samples: int = 200, sample_method: str = "stride", # choices=["random", "stride"], scale: float = 1.0, @@ -364,9 +349,7 @@ def compute_and_extract_frames_to_label( per_cluster: int = 5, compute_features_per_video: bool = True, ): - """Compute suggested frames to label and - extract them as png files. 
- + """Compute frames to label and extract them as png files. We use SLEAP's image feature method to select the frames for labelling and export them as png @@ -385,9 +368,9 @@ def compute_and_extract_frames_to_label( output_subdir : str, optional name of output subdirectory in which to put extracted frames, by default the timestamp in the format YYYMMDD_HHMMSS. - video_extensions : list, optional + video_extensions : tuple, optional extensions to search for when looking for video files, - by default ["mp4"] + by default ("mp4") initial_samples : int, optional initial number of frames to extract per video, by default 200 sample_method : str, optional @@ -408,6 +391,7 @@ def compute_and_extract_frames_to_label( compute_features_per_video : bool, optional whether to compute the (PCA?) features per video, or across all videos, by default True + """ # Compute list of suggested frames using SLEAP map_videos_to_extracted_frames = compute_suggested_sleap_frames( @@ -449,7 +433,8 @@ def compute_and_extract_frames_to_label( indent=4, ) logging.info( - f"Existing json file with extracted frames updated at {json_output_file}", + "Existing json file with " + f"extracted frames updated at {json_output_file}", ) # else: start a new file else: @@ -473,6 +458,7 @@ def compute_and_extract_frames_to_label( def app_wrapper(): + """Wrap function for the Typer app.""" app() diff --git a/crabs/detector/datamodules.py b/crabs/detector/datamodules.py index 2be2744f..9b4ea5c3 100644 --- a/crabs/detector/datamodules.py +++ b/crabs/detector/datamodules.py @@ -1,3 +1,5 @@ +"""DataModule for the crabs data.""" + from typing import Optional import torch @@ -26,6 +28,7 @@ def __init__( split_seed: Optional[int] = None, no_data_augmentation: bool = False, ): + """Initialise the CrabsDataModule.""" super().__init__() self.list_img_dirs = list_img_dirs self.list_annotation_files = list_annotation_files @@ -34,7 +37,7 @@ def __init__( self.no_data_augmentation = no_data_augmentation def _transform_str_to_operator(self, transform_str): - """Get transform operator from its name in snake case""" + """Get transform operator from its name in snake case.""" def snake_to_camel_case(snake_str): return "".join( @@ -48,8 +51,7 @@ def snake_to_camel_case(snake_str): return transform_callable(**self.config[transform_str]) def _compute_list_of_transforms(self) -> list[torchvision.transforms.v2]: - """Read transforms from config and add to list""" - + """Read transforms from config and add to list.""" # Initialise list train_data_augm: list[torchvision.transforms.v2] = [] @@ -142,6 +144,7 @@ def _collate_fn(self, batch: tuple) -> tuple: tuple a tuple of length = batch size, made up of (image, annotations) tuples. + """ return tuple(zip(*batch)) @@ -167,8 +170,8 @@ def _compute_splits( ------- tuple A tuple with the train, test and validation datasets - """ + """ # Optionally fix the random number generators for reproducible # splits of data rng_train_split, rng_val_split = None, None @@ -207,14 +210,17 @@ def _compute_splits( return train_dataset, test_dataset, val_dataset def prepare_data(self): - """ + """Prepare dataset. + + Unused for now. + To download data, IO, etc. Useful with shared filesystems, only called on 1 GPU/TPU in distributed. """ pass def setup(self, stage: str): - """Setup the data for training, testing and validation. + """Set up the data for training, testing and validation. Define the transforms for each split of the data and compute them. 
""" @@ -232,16 +238,14 @@ def setup(self, stage: str): ) def train_dataloader(self) -> DataLoader: - """Define dataloader for the training set""" + """Define dataloader for the training set.""" return DataLoader( self.train_dataset, batch_size=self.config["batch_size_train"], shuffle=True, # a shuffled sampler will be constructed num_workers=self.config["num_workers"], collate_fn=self._collate_fn, - persistent_workers=True - if self.config["num_workers"] > 0 - else False, + persistent_workers=bool(self.config["num_workers"] > 0), multiprocessing_context="fork" if self.config["num_workers"] > 0 and torch.backends.mps.is_available() @@ -249,16 +253,14 @@ def train_dataloader(self) -> DataLoader: ) def val_dataloader(self) -> DataLoader: - """Define dataloader for the validation set""" + """Define dataloader for the validation set.""" return DataLoader( self.val_dataset, batch_size=self.config["batch_size_val"], shuffle=False, num_workers=self.config["num_workers"], collate_fn=self._collate_fn, - persistent_workers=True - if self.config["num_workers"] > 0 - else False, + persistent_workers=bool(self.config["num_workers"] > 0), multiprocessing_context="fork" if self.config["num_workers"] > 0 and torch.backends.mps.is_available() @@ -266,7 +268,7 @@ def val_dataloader(self) -> DataLoader: ) def test_dataloader(self) -> DataLoader: - """Define dataloader for the test set""" + """Define dataloader for the test set.""" return DataLoader( self.test_dataset, batch_size=self.config["batch_size_test"], diff --git a/crabs/detector/datasets.py b/crabs/detector/datasets.py index 58b2ed05..03f5929e 100644 --- a/crabs/detector/datasets.py +++ b/crabs/detector/datasets.py @@ -1,3 +1,5 @@ +"""Dataset classes for the crabs COCO dataset.""" + import json import os import tempfile @@ -9,6 +11,12 @@ class CrabsCocoDetection(torch.utils.data.ConcatDataset): + """Class for crabs' COCO dataset. + + The dataset is built by concatenating CocoDetection datasets, each wrapped + for use with `transforms_v2`. + """ + def __init__( self, list_img_dirs: list[str], @@ -16,9 +24,7 @@ def __init__( transforms: Optional[Callable] = None, list_exclude_files: Optional[list[str]] = None, ): - """ - A class for concatenated CocoDetection datasets wrapped for - transforms_v2. + """Construct a concatenated dataset of CocoDetection datasets. If a list of files to exclude from the dataset is passed, a new annotation file is generated without the data to exclude. @@ -36,9 +42,10 @@ def __init__( Each individual dataset in the concatenated set should be equivalent to one obtained with: - > dataset = wrap_dataset_for_transforms_v2(CocoDetection([IMAGES_PATH], [ANNOTATIONS_PATH])) + > dataset = wrap_dataset_for_transforms_v2( + > CocoDetection([IMAGES_PATH], [ANNOTATIONS_PATH]) + > ) """ - # Create list of transformed-COCO datasets list_datasets = [] for img_dir, annotation_file in zip( @@ -103,7 +110,7 @@ def save_filt_annotations( self, annotation_file: str, list_files_to_exclude: list[str], - out_filename, + out_filename: str, ) -> str: """Remove selected images from annotation file and save new file. 
@@ -116,14 +123,17 @@ def save_filt_annotations( path to file with annotations list_files_to_exclude : list[str] list of filenames to exclude from the dataset + out_filename : str + path to save new annotation file Returns ------- str path to new annotation file + """ # Read annotation file as a dataset dict - with open(annotation_file, "r") as f: + with open(annotation_file) as f: dataset = json.load(f) # Determine images to exclude diff --git a/crabs/detector/evaluate_model.py b/crabs/detector/evaluate_model.py index 53889319..37c3cf01 100644 --- a/crabs/detector/evaluate_model.py +++ b/crabs/detector/evaluate_model.py @@ -1,3 +1,5 @@ +"""Script to evaluate a trained object detector.""" + import argparse import logging import os @@ -23,8 +25,7 @@ class DetectorEvaluate: - """ - A class for evaluating an object detector. + """Interface for evaluating an object detector. Parameters ---------- @@ -34,6 +35,7 @@ class DetectorEvaluate: """ def __init__(self, args: argparse.Namespace) -> None: + """Initialise the evaluation interface with the given arguments.""" # CLI inputs self.args = args @@ -77,10 +79,7 @@ def __init__(self, args: argparse.Namespace) -> None: logging.info(f"Seed: {self.seed_n}") def setup_trainer(self): - """ - Setup trainer object with logging for testing. - """ - + """Set up trainer object with logging for testing.""" # Assign run name self.run_name = set_mlflow_run_name() @@ -101,9 +100,7 @@ def setup_trainer(self): ) def evaluate_model(self) -> None: - """ - Evaluate the trained model on the test dataset. - """ + """Evaluate the trained model on the test dataset.""" # Create datamodule data_module = CrabsDataModule( list_img_dirs=self.images_dirs, @@ -141,8 +138,7 @@ def evaluate_model(self) -> None: def main(args) -> None: - """ - Main function to orchestrate the testing process. + """Run detector testing. Parameters ---------- @@ -152,12 +148,14 @@ def main(args) -> None: Returns ------- None + """ evaluator = DetectorEvaluate(args) evaluator.evaluate_model() def evaluate_parse_args(args): + """Parse command-line arguments for evaluation.""" parser = argparse.ArgumentParser() parser.add_argument( "--trained_model_path", @@ -171,7 +169,8 @@ def evaluate_parse_args(args): default="", help=( "Location of YAML config to control evaluation. " - " If None is povided, the config used to train the model is used (recommended)." + "If none is povided, the config used to train " + "the model is used (recommended)." ), ) parser.add_argument( @@ -190,9 +189,11 @@ def evaluate_parse_args(args): default=[], help=( "List of paths to annotation files. " - "If none are provided (recommended), the annotations from the dataset of the trained model are used." + "If none are provided (recommended), the annotations " + "from the dataset of the trained model are used." "The full path or the filename can be provided. " - "If only filename is provided, it is assumed to be under dataset/annotations." + "If only filename is provided, it is assumed to be " + "under dataset/annotations." ), ) parser.add_argument( @@ -210,9 +211,10 @@ def evaluate_parse_args(args): type=str, default="gpu", help=( - "Accelerator for Pytorch Lightning. Valid inputs are: cpu, gpu, tpu, ipu, auto, mps. Default: gpu." - "See https://lightning.ai/docs/pytorch/stable/common/trainer.html#accelerator " - "and https://lightning.ai/docs/pytorch/stable/accelerators/mps_basic.html#run-on-apple-silicon-gpus" + "Accelerator for Pytorch Lightning. " + "Valid inputs are: cpu, gpu, tpu, ipu, auto, mps. Default: gpu." 
+ "See https://lightning.ai/docs/pytorch/stable/common/trainer.html#accelerator " # noqa: E501 + "and https://lightning.ai/docs/pytorch/stable/accelerators/mps_basic.html#run-on-apple-silicon-gpus" # noqa: E501 ), ) parser.add_argument( @@ -220,8 +222,10 @@ def evaluate_parse_args(args): type=str, default="Sept2023_evaluation", help=( - "Name of the experiment in MLflow, under which the current run will be logged. " - "For example, the name of the dataset could be used, to group runs using the same data. " + "Name of the experiment in MLflow, under which the current run " + "will be logged. " + "For example, the name of the dataset could be used, to group " + "runs using the same data. " "Default: Sept2023_evaluation" ), ) @@ -235,7 +239,8 @@ def evaluate_parse_args(args): type=float, default=1.0, help=( - "Debugging option to run training on a fraction of the training set." + "Debugging option to run training on a fraction of " + "the training set." "Default: 1.0 (all the training set)" ), ) @@ -255,7 +260,8 @@ def evaluate_parse_args(args): type=float, default=0.5, help=( - "Score threshold for visualising detections on output frames. Default: 0.5" + "Score threshold for visualising detections on output frames. " + "Default: 0.5" ), ) parser.add_argument( @@ -264,7 +270,8 @@ def evaluate_parse_args(args): default="", help=( "Output directory for the exported frames. " - "By default, the frames are saved in a `results_ folder " + "By default, the frames are saved in a " + "`results_ folder " "under the current working directory." ), ) @@ -272,6 +279,7 @@ def evaluate_parse_args(args): def app_wrapper(): + """Wrap function to run the evaluation.""" torch.set_float32_matmul_precision("medium") eval_args = evaluate_parse_args(sys.argv[1:]) diff --git a/crabs/detector/models.py b/crabs/detector/models.py index e76b89f4..a6bc8a4b 100644 --- a/crabs/detector/models.py +++ b/crabs/detector/models.py @@ -1,5 +1,7 @@ +"""LightningModule for Faster R-CNN for object detection.""" + import logging -from typing import Any, Tuple, Union +from typing import Any, Union import torch from lightning import LightningModule @@ -12,8 +14,7 @@ class FasterRCNN(LightningModule): - """ - LightningModule implementation of Faster R-CNN for object detection. + """LightningModule implementation of Faster R-CNN for object detection. Parameters ---------- @@ -41,9 +42,11 @@ class FasterRCNN(LightningModule): Dictionary to store validation metrics. test_step_outputs : dict Dictionary to store test metrics. + """ def __init__(self, config: dict[str, Any], optuna_log=False): + """Initialise the Faster R-CNN model with the given configuration.""" super().__init__() self.config = config self.model = self.configure_model() @@ -69,8 +72,9 @@ def __init__(self, config: dict[str, Any], optuna_log=False): } def configure_model(self) -> torch.nn.Module: - """ - Configures the Faster R-CNN model with default weights, + """Configure Faster R-CNN model. + + Use default weights, specified backbone, and box predictor. """ model = fasterrcnn_resnet50_fpn_v2(weights="DEFAULT") @@ -81,9 +85,7 @@ def configure_model(self) -> torch.nn.Module: return model def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the model. - """ + """Forward pass of the model.""" return self.model(x) def accumulate_epoch_metrics( @@ -91,26 +93,21 @@ def accumulate_epoch_metrics( batch_output: dict, dataset_str: str, ) -> None: - """ - Accumulates precision and recall metrics per epoch. 
- """ - getattr(self, f"{dataset_str}_step_outputs")[ - "precision_epoch" - ] += batch_output["precision"] + """Accumulates precision and recall metrics per epoch.""" + getattr(self, f"{dataset_str}_step_outputs")["precision_epoch"] += ( + batch_output["precision"] + ) - getattr(self, f"{dataset_str}_step_outputs")[ - "recall_epoch" - ] += batch_output["recall"] + getattr(self, f"{dataset_str}_step_outputs")["recall_epoch"] += ( + batch_output["recall"] + ) getattr(self, f"{dataset_str}_step_outputs")["num_batches"] += 1 def compute_precision_recall_epoch( self, step_outputs: dict[str, Union[float, int]], log_str: str - ) -> Tuple[float, float]: - """ - Computes and logs mean precision and recall for the current epoch. - """ - + ) -> tuple[float, float]: + """Compute and log mean precision and recall for the current epoch.""" # compute mean precision and recall mean_precision = ( step_outputs["precision_epoch"] / step_outputs["num_batches"] @@ -136,8 +133,9 @@ def compute_precision_recall_epoch( return mean_precision, mean_recall def on_train_epoch_end(self) -> None: - """ - Hook called after each training epoch to perform tasks such as logging and resetting metrics. + """Define hook called after each training epoch. + + Used to perform tasks such as logging and resetting metrics. """ # compute average loss avg_loss = ( @@ -157,8 +155,9 @@ def on_train_epoch_end(self) -> None: } def on_validation_epoch_end(self) -> None: - """ - Hook called after each validation epoch to compute metrics and logging. + """Define hook called after each validation epoch. + + Used to compute metrics and logging. """ (val_precision, val_recall) = self.compute_precision_recall_epoch( self.validation_step_outputs, "val" @@ -177,8 +176,9 @@ def on_validation_epoch_end(self) -> None: } def on_test_epoch_end(self) -> None: - """ - Hook called after each testing epoch to compute metrics and logging. + """Define hook called after each testing epoch. + + Used to compute metrics and logging. """ (test_precision, test_recall) = self.compute_precision_recall_epoch( self.test_step_outputs, "test" @@ -194,9 +194,7 @@ def on_test_epoch_end(self) -> None: def training_step( self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int ) -> torch.Tensor: - """ - Defines the training step for the model. - """ + """Define training step for the model.""" images, targets = batch loss_dict = self.model(images, targets) total_loss = sum(loss for loss in loss_dict.values()) @@ -209,8 +207,9 @@ def training_step( def val_test_step( self, batch: tuple[torch.Tensor, torch.Tensor] ) -> dict[str, Union[float, int]]: - """ - Performs inference on a validation or test batch and computes precision and recall. + """Perform inference on a validation or test batch. + + Computes precision and recall. """ images, targets = batch predictions = self.model(images) @@ -223,9 +222,7 @@ def val_test_step( def validation_step( self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int ) -> dict[str, Union[float, int]]: - """ - Defines the validation step for the model. - """ + """Define the validation step for the model.""" outputs = self.val_test_step(batch) self.accumulate_epoch_metrics(outputs, "validation") return outputs @@ -233,17 +230,13 @@ def validation_step( def test_step( self, batch: tuple[torch.Tensor, torch.Tensor], batch_idx: int ) -> dict[str, Union[float, int]]: - """ - Defines the test step for the model. 
- """ + """Define the test step for the model.""" outputs = self.val_test_step(batch) self.accumulate_epoch_metrics(outputs, "test") return outputs def configure_optimizers(self) -> dict[str, torch.optim.Optimizer]: - """ - Configures the optimizer for training. - """ + """Configure the optimizer for training.""" optimizer = torch.optim.Adam( self.parameters(), lr=self.config["learning_rate"], diff --git a/crabs/detector/train_model.py b/crabs/detector/train_model.py index 735f4213..4ef3d4ea 100644 --- a/crabs/detector/train_model.py +++ b/crabs/detector/train_model.py @@ -1,3 +1,5 @@ +"""Train FasterRCNN model for object detection.""" + import argparse import os import sys @@ -26,15 +28,17 @@ class DectectorTrain: - """Training class for detector algorithm + """Training class for detector algorithm. Parameters ---------- args: argparse.Namespace An object containing the parsed command-line arguments. + """ def __init__(self, args: argparse.Namespace): + """Initialise the training class with the given arguments.""" # inputs self.args = args self.config_file = args.config_file @@ -62,16 +66,12 @@ def __init__(self, args: argparse.Namespace): self.checkpoint_path = args.checkpoint_path def load_config_yaml(self): - """ - Load yaml file that contains config parameters. - """ - with open(self.config_file, "r") as f: + """Load yaml file that contains config parameters.""" + with open(self.config_file) as f: self.config = yaml.safe_load(f) def setup_trainer(self): - """ - Setup trainer with logging and checkpointing. - """ + """Set up trainer with logging and checkpointing.""" self.run_name = set_mlflow_run_name() # Setup logger with checkpointing @@ -91,7 +91,9 @@ def setup_trainer(self): filename="checkpoint-{epoch}", every_n_epochs=config_ckpt["every_n_epochs"], save_top_k=config_ckpt["keep_last_n_ckpts"], - monitor="epoch", # monitor the metric "epoch" for selecting which checkpoints to save + monitor="epoch", + # monitor the metric "epoch" for selecting which checkpoints + # to save mode="max", # get the max of the monitored metric save_last=config_ckpt["save_last"], save_weights_only=config_ckpt["save_weights_only"], @@ -127,6 +129,7 @@ def optuna_objective_fn(self, trial: optuna.Trial) -> float: ------- float The value to maximise. + """ # Sample hyperparameters from the search space for this trial optuna_config = self.config["optuna"] @@ -161,6 +164,7 @@ def core_training(self) -> lightning.Trainer: ------- lightning.Trainer The trainer object used for training. + """ # Create data module data_module = CrabsDataModule( @@ -182,9 +186,11 @@ def core_training(self) -> lightning.Trainer: if checkpoint_type == "weights": lightning_model = FasterRCNN.load_from_checkpoint( self.checkpoint_path, - config=self.config, # overwrite hparams from ckpt with config + config=self.config, + # overwrite hparams from ckpt with config optuna_log=self.args.optuna, - ) # a 'weights' checkpoint is one saved with `save_weights_only=True` + ) + # a 'weights' checkpoint is saved with `save_weights_only=True` # Get trainer trainer = self.setup_trainer() @@ -199,13 +205,15 @@ def core_training(self) -> lightning.Trainer: self.checkpoint_path if checkpoint_type == "full" else None ), # a 'full' checkpoint is one saved with `save_weights_only=False` - # (automatically restores model, epoch, step, LR schedulers, etc...) - # see https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters + # (automatically restores model, epoch, step, LR schedulers, etc.) 
+ # see + # https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters ) return trainer def train_model(self): + """Train detector.""" # Run hyperparameter sweep with Optuna if required if self.args.optuna: # Optimize hyperparameters in config @@ -228,8 +236,7 @@ def train_model(self): def main(args) -> None: - """ - Main function to orchestrate the training process. + """Run training process. Parameters ---------- @@ -239,12 +246,14 @@ def main(args) -> None: Returns ------- None + """ trainer = DectectorTrain(args) trainer.train_model() def train_parse_args(args): + """Parse command-line arguments for training.""" parser = argparse.ArgumentParser() parser.add_argument( "--dataset_dirs", @@ -257,8 +266,10 @@ def train_parse_args(args): nargs="+", default=[], help=( - "list of paths to annotation files. The full path or the filename can be provided. " - "If only filename is provided, it is assumed to be under dataset/annotations." + "list of paths to annotation files. The full path or the filename " + "can be provided. " + "If only filename is provided, it is assumed to be under " + "dataset/annotations." ), ) parser.add_argument( @@ -267,7 +278,8 @@ def train_parse_args(args): default=str(Path(__file__).parent / "config" / "faster_rcnn.yaml"), help=( "Location of YAML config to control training. " - "Default: crabs-exploration/crabs/detection_tracking/config/faster_rcnn.yaml" + "Default: " + "crabs-exploration/crabs/detection_tracking/config/faster_rcnn.yaml" ), ) parser.add_argument( @@ -275,9 +287,10 @@ def train_parse_args(args): type=str, default="gpu", help=( - "Accelerator for Pytorch Lightning. Valid inputs are: cpu, gpu, tpu, ipu, auto, mps. Default: gpu" - "See https://lightning.ai/docs/pytorch/stable/common/trainer.html#accelerator " - "and https://lightning.ai/docs/pytorch/stable/accelerators/mps_basic.html#run-on-apple-silicon-gpus" + "Accelerator for Pytorch Lightning. " + "Valid inputs are: cpu, gpu, tpu, ipu, auto, mps. Default: gpu. " + "See https://lightning.ai/docs/pytorch/stable/common/trainer.html#accelerator " # noqa: E501 + "and https://lightning.ai/docs/pytorch/stable/accelerators/mps_basic.html#run-on-apple-silicon-gpus" # noqa: E501 ), ) parser.add_argument( @@ -285,8 +298,10 @@ def train_parse_args(args): type=str, default="Sept2023", help=( - "Name of the experiment in MLflow, under which the current run will be logged. " - "For example, the name of the dataset could be used, to group runs using the same data. " + "Name of the experiment in MLflow, under which " + "the current run will be logged. " + "For example, the name of the dataset could be used, " + "to group runs using the same data. " "Default: Sep2023" ), ) @@ -306,8 +321,9 @@ def train_parse_args(args): type=float, default=1.0, help=( - "Debugging option to run training on a fraction of the training set." - "Default: 1.0 (all the training set)" + "Debugging option to run training on a fraction of the " + "training set. 
" + "Default: 1.0 (the full training set)" ), ) parser.add_argument( @@ -325,22 +341,29 @@ def train_parse_args(args): parser.add_argument( "--optuna", action="store_true", - help="Run a hyperparameter optimisation using Optuna prior to training the model", + help=( + "Run a hyperparameter optimisation using Optuna prior to training " + "the model" + ), ) parser.add_argument( "--no_data_augmentation", action="store_true", - help="Ignore the data augmentation transforms defined in config file", + help=( + "Ignore the data augmentation transforms " + "defined in the config file" + ), ) parser.add_argument( "--log_data_augmentation", action="store_true", - help="Log data augmentation transforms linked to datamodule as MLflow artifacts", + help=("Log data augmentation transforms to " "MLflow as artifacts"), ) return parser.parse_args(args) def app_wrapper(): + """Wrap function to run the training application.""" torch.set_float32_matmul_precision("medium") train_args = train_parse_args(sys.argv[1:]) diff --git a/crabs/detector/utils/detection.py b/crabs/detector/utils/detection.py index dd51bda3..096facda 100644 --- a/crabs/detector/utils/detection.py +++ b/crabs/detector/utils/detection.py @@ -13,19 +13,20 @@ def prep_img_directories(dataset_dirs: list[str]) -> list[str]: - """ - Derive list of input image directories from a list of dataset directories. + """Get list of input image directories from a list of dataset directories. + We assume a specific structure for the dataset directories. - Parameters: - ----------- + Parameters + ---------- dataset_dirs : List[str] List of directories containing dataset folders. - Returns: - -------- + Returns + ------- List[str]: List of directories containing image frames. + """ images_dirs = [] for dataset in dataset_dirs: @@ -36,20 +37,20 @@ def prep_img_directories(dataset_dirs: list[str]) -> list[str]: def prep_annotation_files( input_annotation_files: list[str], dataset_dirs: list[str] ) -> list[str]: - """ - Prepares annotation files for processing. + """Prepare annotation files for processing. - Parameters: - ----------- + Parameters + ---------- input_annotation_files : List[str] List of annotation files or filenames. dataset_dirs : List[str] List of directories containing dataset folders. - Returns: - -------- + Returns + ------- List[str]: List of annotation file paths. + """ # prepare list of annotation files annotation_files = [] @@ -86,9 +87,9 @@ def log_metadata_to_logger( mlf_logger: MLFlowLogger, cli_args: argparse.Namespace, ) -> MLFlowLogger: - """ - Log metadata to MLflow logger. - Add CLI arguments and, if available, SLURM job information. + """Log metadata to MLflow logger. + + Adds CLI arguments to logger and, if available, SLURM job information. Parameters ---------- @@ -101,8 +102,8 @@ def log_metadata_to_logger( ------- MLFlowLogger An MLflow logger instance with metadata logged. - """ + """ # Log CLI arguments mlf_logger.log_hyperparams({"cli_args": cli_args}) @@ -126,8 +127,7 @@ def log_metadata_to_logger( def set_mlflow_run_name() -> str: - """ - Set MLflow run name. + """Set MLflow run name. Use the slurm job ID if it is a SLURM job, else use a timestamp. For SLURM jobs: @@ -158,10 +158,9 @@ def setup_mlflow_logger( run_name: str, mlflow_folder: str, cli_args: argparse.Namespace, - ckpt_config: dict[str, Any] = {}, + ckpt_config: dict[str, Any] = {}, # noqa: B006 ) -> MLFlowLogger: - """ - Setup MLflow logger and log job metadata, with optional checkpointing. + """Set up MLflow logger and log job metadata. 
Setup MLflow logger for a given experiment and run name. If a checkpointing config is passed, it will setup the logger with a @@ -186,8 +185,8 @@ def setup_mlflow_logger( ------- MLFlowLogger A logger to record data for MLflow - """ + """ # Setup MLflow logger for a given experiment and run name # (with checkpointing if required) mlf_logger = MLFlowLogger( @@ -217,11 +216,11 @@ def setup_mlflow_logger( def slurm_logs_as_artifacts(logger: MLFlowLogger, slurm_job_id: str): - """ - Add slurm logs as an MLflow artifacts of the current run. - The filenaming convention from the training scripts at crabs-exploration/bash_scripts/ is assumed. - """ + """Add slurm logs as an MLflow artifacts of the current run. + The filenaming convention from the training scripts at + crabs-exploration/bash_scripts/ is assumed. + """ # Get slurm env variables: slurm and array job ID slurm_node = os.environ.get("SLURMD_NODENAME") slurm_array_job_id = os.environ.get("SLURM_ARRAY_JOB_ID") @@ -267,6 +266,7 @@ def bbox_tensors_to_COCO_dict( ------- dict COCO format dictionary with bounding boxes. + """ # Create list of image filenames if not provided if list_img_filenames is None: diff --git a/crabs/detector/utils/evaluate.py b/crabs/detector/utils/evaluate.py index d985fd4e..e2633291 100644 --- a/crabs/detector/utils/evaluate.py +++ b/crabs/detector/utils/evaluate.py @@ -1,4 +1,4 @@ -"""Utils used in evaluation""" +"""Utils used in evaluation.""" import argparse import ast @@ -18,8 +18,7 @@ def compute_precision_recall(class_stats: dict) -> tuple[float, float, dict]: - """ - Compute precision and recall. + """Compute precision and recall. Parameters ---------- @@ -27,9 +26,10 @@ def compute_precision_recall(class_stats: dict) -> tuple[float, float, dict]: Statistics or information about different classes. Returns - ---------- + ------- Tuple[float, float] precision and recall + """ for _, stats in class_stats.items(): precision = stats["tp"] / max(stats["tp"] + stats["fp"], 1) @@ -41,8 +41,9 @@ def compute_precision_recall(class_stats: dict) -> tuple[float, float, dict]: def compute_confusion_matrix_elements( targets: list, detections: list, ious_threshold: float ) -> tuple[float, float, dict]: - """ - Compute metrics (true positive, false positive, false negative) for object detection. + """Compute detection metrics. + + Compute true positive, false positive, and false negative values. Parameters ---------- @@ -58,9 +59,10 @@ def compute_confusion_matrix_elements( Statistics or information about different classes. 
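# Worked example of the precision/recall computation above, with made-up counts
# and the same max(..., 1) clipping to avoid division by zero:
class_stats = {"crab": {"tp": 8, "fp": 2, "fn": 1}}
stats = class_stats["crab"]
precision = stats["tp"] / max(stats["tp"] + stats["fp"], 1)  # 8 / 10 = 0.8
recall = stats["tp"] / max(stats["tp"] + stats["fn"], 1)  # 8 / 9 ≈ 0.89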
Returns - ---------- + ------- Tuple[float, float] precision and recall + """ class_stats = {"crab": {"tp": 0, "fp": 0, "fn": 0}} for target, detection in zip(targets, detections): @@ -85,24 +87,27 @@ def compute_confusion_matrix_elements( else: class_stats["crab"]["fp"] += 1 - for target_box_index, target_box in enumerate(gt_boxes): + for target_box_index, _target_box in enumerate(gt_boxes): found_match = False for idx, iou in enumerate(max_ious): if ( - iou.item() - > ious_threshold # we need this condition because the max overlap is not necessarily above the threshold - and max_indices[idx] - == target_box_index # the matching index is the index of the GT box with which it has max overlap + iou.item() > ious_threshold + # we need this condition because the max overlap + # is not necessarily above the threshold + and max_indices[idx] == target_box_index + # the matching index is the index of the GT + # box with which it has max overlap ): - # There's an IoU match and the matched index corresponds to the current target_box_index + # There's an IoU match and the matched index corresponds + # to the current target_box_index found_match = True break # Exit loop, a match was found if not found_match: # print(found_match) - class_stats["crab"][ - "fn" - ] += 1 # Ground truth box has no corresponding detection + class_stats["crab"]["fn"] += ( + 1 # Ground truth box has no corresponding detection + ) precision, recall, class_stats = compute_precision_recall(class_stats) @@ -120,7 +125,7 @@ def get_mlflow_parameters_from_ckpt(trained_model_path: str) -> dict: try: assert ( Path(trained_model_path).parent.stem == "checkpoints" - ), "The parent directory to an MLflow checkpoint is expected to be called 'checkpoints'" + ), "The parent directory to an MLflow checkpoint is expected to be called 'checkpoints'" # noqa: E501 except AssertionError as e: print(f"Assertion failed: {e}") sys.exit(1) @@ -144,10 +149,9 @@ def get_mlflow_parameters_from_ckpt(trained_model_path: str) -> dict: def get_config_from_ckpt(config_file: str, trained_model_path: str) -> dict: """Get config from checkpoint if config is not passed as a CLI argument.""" - # If config in CLI arguments: used passed config if config_file: - with open(config_file, "r") as f: + with open(config_file) as f: config_dict = yaml.safe_load(f) # If not: used config from ckpt @@ -202,7 +206,6 @@ def get_img_directories_from_ckpt( args: argparse.Namespace, trained_model_path: str ) -> list[str]: """Get image directories from checkpoint if not passed as CLI argument.""" - # Get dataset directories from ckpt if not defined dataset_dirs = get_cli_arg_from_ckpt( args=args, @@ -220,7 +223,6 @@ def get_annotation_files_from_ckpt( args: argparse.Namespace, trained_model_path: str ) -> list[str]: """Get annotation files from checkpoint if not passed as CLI argument.""" - # Get path to input annotation files from ckpt if not defined input_annotation_files = get_cli_arg_from_ckpt( args=args, diff --git a/crabs/detector/utils/hpo.py b/crabs/detector/utils/hpo.py index 24dec5dc..801a3181 100644 --- a/crabs/detector/utils/hpo.py +++ b/crabs/detector/utils/hpo.py @@ -1,12 +1,12 @@ -"""Utils for hyperparameter optimisation""" +"""Utils for hyperparameter optimisation with Optuna.""" -from typing import Callable, Dict +from typing import Callable import optuna def compute_optimal_hyperparameters( - objective_fn: Callable, config_optuna: Dict, direction: str = "maximize" + objective_fn: Callable, config_optuna: dict, direction: str = "maximize" ) -> dict: """Compute 
hyperparameters that optimize the objective function. @@ -28,6 +28,7 @@ def compute_optimal_hyperparameters( ------- dict The optimal parameters computed by Optuna. + """ # Create an study study = optuna.create_study(direction=direction) diff --git a/crabs/detector/utils/train.py b/crabs/detector/utils/train.py index 11594625..499ae7db 100644 --- a/crabs/detector/utils/train.py +++ b/crabs/detector/utils/train.py @@ -1,4 +1,4 @@ -"""Utils used in training""" +"""Utils used in training.""" import logging from typing import Optional diff --git a/crabs/detector/utils/visualization.py b/crabs/detector/utils/visualization.py index e6fc98bc..784eb76a 100644 --- a/crabs/detector/utils/visualization.py +++ b/crabs/detector/utils/visualization.py @@ -1,3 +1,5 @@ +"""Utilities for visualizing object detection results.""" + import os from datetime import datetime from typing import Any, Optional @@ -23,17 +25,18 @@ def draw_bbox( colour: tuple, label_text: Optional[str] = None, ) -> None: - """ - Draw bounding boxes on the image based on detection results. + """Draw bounding boxes on the image based on detection results. Parameters ---------- frame : np.ndarray Image with bounding boxes drawn on it. top_left : tuple[float, float] - Tuple containing (x, y) coordinates of the top-left corner of the bounding box. + Tuple containing (x, y) coordinates of the top-left corner of the + bounding box. bottom_right : tuple[float, float] - Tuple containing (x, y) coordinates of the bottom-right corner of the bounding box. + Tuple containing (x, y) coordinates of the bottom-right corner of the + bounding box. colour : tuple Color of the bounding box in BGR format. label_text : str, optional @@ -42,6 +45,7 @@ def draw_bbox( Returns ------- None + """ # Draw bounding box cv2.rectangle( @@ -72,8 +76,7 @@ def draw_detection( detections: Optional[dict[Any, Any]] = None, score_threshold: Optional[float] = None, ) -> np.ndarray: - """ - Draw the results based on the detection. + """Draw the results based on the detection. Parameters ---------- @@ -90,9 +93,9 @@ def draw_detection( ------- np.ndarray Image(s) with bounding boxes drawn on them. + """ coco_list = COCO_INSTANCE_CATEGORY_NAMES - image_with_boxes = None for image, label, prediction in zip( imgs, annotations, detections or [None] * len(imgs) @@ -156,8 +159,7 @@ def save_images_with_boxes( output_dir: str, score_threshold: float, ) -> None: - """ - Save images with bounding boxes drawn around detected objects. + """Save images with bounding boxes drawn around detected objects. Parameters ---------- @@ -165,12 +167,15 @@ def save_images_with_boxes( DataLoader for the test dataset. trained_model : torch.nn.Module The trained object detection model. + output_dir : str + Directory to save the images with bounding boxes. score_threshold : float Threshold for object detection. Returns - ---------- + ------- None + """ device = ( torch.device("cuda") @@ -189,7 +194,7 @@ def save_images_with_boxes( with torch.no_grad(): imgs_id = 0 for imgs, annotations in test_dataloader: - imgs_id += 1 + imgs_id += 1 # noqa: SIM113 imgs = list(img.to(device) for img in imgs) detections = trained_model(imgs) @@ -201,9 +206,10 @@ def save_images_with_boxes( cv2.imwrite(f"{output_dir}/imgs{imgs_id}.jpg", image_with_boxes) -def plot_sample(imgs: list, row_title: Optional[str] = None, **imshow_kwargs): - """ - Plot a sample (image & annotations) from a dataset. 
+def plot_sample( # noqa: C901 + imgs: list, row_title: Optional[str] = None, **imshow_kwargs +): + """Plot a sample (image & annotations) from a dataset. Example usage: > full_dataset = CrabsCocoDetection([IMAGES_PATH],[ANNOTATIONS_PATH]) @@ -211,7 +217,8 @@ def plot_sample(imgs: list, row_title: Optional[str] = None, **imshow_kwargs): > plt.figure() > plot_sample([sample]) - From https://github.com/pytorch/vision/blob/main/gallery/transforms/helpers.py + From: + https://github.com/pytorch/vision/blob/main/gallery/transforms/helpers.py """ if not isinstance(imgs[0], list): # Make a 2d grid even if there's just 1 row diff --git a/crabs/stereo_calibration/extract_pairs_of_frames.py b/crabs/stereo_calibration/extract_pairs_of_frames.py index f1d363ae..bb96efe6 100644 --- a/crabs/stereo_calibration/extract_pairs_of_frames.py +++ b/crabs/stereo_calibration/extract_pairs_of_frames.py @@ -1,15 +1,16 @@ +"""Script to extract pairs of frames for stereo calibration.""" + import logging from pathlib import Path import cv2 -import ffmpeg +import ffmpeg # type: ignore import typer from timecode import Timecode def compute_timecode_params_per_video(list_paths: list[Path]) -> dict: - """ - Compute timecode parameters per video + """Compute timecode parameters per video. We assume the timecode data is logged in the timecode stream ("tmcd"), since we are expecting MOV files (see Notes for further details). @@ -116,9 +117,7 @@ def compute_timecode_params_per_video(list_paths: list[Path]) -> dict: # compute end timecode end_timecode_tuple = start_timecode.frames_to_tc( - start_timecode.frames - + n_frames - - 1 + start_timecode.frames + n_frames - 1 # do not count the first frame twice! ) end_timecode_str = start_timecode.tc_to_string(*end_timecode_tuple) @@ -138,9 +137,7 @@ def compute_timecode_params_per_video(list_paths: list[Path]) -> dict: def compute_synching_timecodes( timecodes_dict: dict, ) -> tuple[Timecode, Timecode]: - """ - Determine the timecodes for the first and last frame - in common across all videos + """Determine timecodes for first and last common frames across all videos. We assume all videos in timecodes_dict were timecode-synched (i.e., their timecode streams will overlap from the common start frame @@ -189,8 +186,9 @@ def compute_opencv_start_idx( timecodes_dict: dict, max_min_timecode: tuple[Timecode, Timecode], ) -> dict: - """ - Compute the start and end indices of a set of videos + """Compute start and end indices of a set of videos. + + Compute start and end indices of a set of videos for opencv tools, based on their common starting frame (max_start_timecode) and their common end frame (min_end_timecode). @@ -233,6 +231,7 @@ def compute_opencv_start_idx( an extension to the timecodes_dict, with the openCV start and end indices per video. Both are 0-based indices relative to the start of the video. + """ (max_start_timecode, min_end_timecode) = max_min_timecode @@ -265,9 +264,7 @@ def extract_chessboard_frames_from_video( chessboard_config: dict, output_parent_dir: str = "./calibration_pairs", ): - """ - Extract frames with a chessboard pattern between the selected indices - and save them to an output directory + """Extract frames with a chessboard pattern between the selected indices. TODO: detecting the checkerboard is very slow with open-cv if no board is present. 
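# A minimal sketch of chessboard detection with the fast-check flag (listed in
# the commented-out flags below), which returns early when no board is visible.
# The image path and pattern size are made-up values.
import cv2

frame_gray = cv2.imread("frame_00001.png", cv2.IMREAD_GRAYSCALE)
found, corners = cv2.findChessboardCorners(
    frame_gray,
    (6, 9),  # pattern size, as given by chessboard_config["rows"]/["cols"]
    None,
    cv2.CALIB_CB_FAST_CHECK,
)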
See issue here: @@ -283,11 +280,14 @@ def extract_chessboard_frames_from_video( - n_frames: number of frames from ffmpeg - opencv_start_idx: start index for synced period - opencv_end_idx: end index for synced period + chessboard_config : dict + A dictionary specifying the number of rows and columns of the + chessboard pattern. output_parent_dir : str, optional directory to which save the extracted synced frames, by default "./calibration_pairs" - """ + """ # initialise capture cap = cv2.VideoCapture(video_path_str) if cap.get(cv2.CAP_PROP_FRAME_COUNT) != video_dict["n_frames"]: @@ -316,7 +316,8 @@ def extract_chessboard_frames_from_video( if success: # --------------- # Find the chessboard corners - # If desired number of corners are found in the image then ret = true + # If desired number of corners are found in the image then + # ret = true # TODO: append 2d coords of corners? frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) @@ -325,7 +326,9 @@ def extract_chessboard_frames_from_video( frame_gray, (chessboard_config["rows"], chessboard_config["cols"]), None, - ) # cv2.CALIB_CB_ADAPTIVE_THRESH + cv2.CALIB_CB_FAST_CHECK + cv2.CALIB_CB_NORMALIZE_IMAGE) + ) + # cv2.CALIB_CB_ADAPTIVE_THRESH + cv2.CALIB_CB_FAST_CHECK + + # cv2.CALIB_CB_NORMALIZE_IMAGE # ------------- if ret: # filepath @@ -363,7 +366,7 @@ def main( video_extensions: list, output_calibration_dir: str = "./calibration_pairs", ): - """_summary_ + """Extract pairs of frames for stereo calibration. Parameters ---------- @@ -376,7 +379,6 @@ def main( by default "./calibration_pairs" """ - # Transform extensions to file_types regular expressions file_types = tuple(f"**/*.{ext}" for ext in video_extensions) @@ -421,8 +423,5 @@ def main( ) -# ---------- -# Main -# ------------ if __name__ == "__main__": typer.run(main) diff --git a/crabs/tracker/evaluate_tracker.py b/crabs/tracker/evaluate_tracker.py index e531581a..e40bfe5d 100644 --- a/crabs/tracker/evaluate_tracker.py +++ b/crabs/tracker/evaluate_tracker.py @@ -1,6 +1,8 @@ +"""Evaluate tracker using the Multi-Object Tracking Accuracy (MOTA) metric.""" + import csv import logging -from typing import Any, Dict, Optional, Tuple +from typing import Any, Optional import numpy as np @@ -8,45 +10,52 @@ class TrackerEvaluate: + """Interface to evaluate tracker.""" + def __init__( self, gt_dir: str, predicted_boxes_id: list[np.ndarray], iou_threshold: float, ): - """ - Initialize the TrackerEvaluate class with ground truth directory, tracked list, and IoU threshold. + """Initialize the TrackerEvaluate class. + + Initialised with ground truth directory, tracked list, and IoU + threshold. Parameters ---------- gt_dir : str Directory path of the ground truth CSV file. - tracked_list : List[np.ndarray] - A list where each element is a numpy array representing tracked objects in a frame. - Each numpy array has shape (N, 5), where N is the number of objects. - The columns are [x1, y1, x2, y2, id], where (x1, y1) and (x2, y2) - define the bounding box and id is the object ID. + predicted_boxes_id : list[np.ndarray] + List of numpy arrays containing predicted bounding boxes and IDs. iou_threshold : float - Intersection over Union (IoU) threshold for evaluating tracking performance. + Intersection over Union (IoU) threshold for evaluating + tracking performance. 
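# Hypothetical usage sketch of the evaluation interface above; the ground-truth
# path and boxes are made up. Each per-frame array has one row per object in
# the format [x1, y1, x2, y2, id].
import numpy as np

predicted_boxes_id = [
    np.array([[10.0, 10.0, 50.0, 50.0, 1.0]]),  # frame 0
    np.array([[12.0, 11.0, 52.0, 51.0, 1.0]]),  # frame 1, same ID persists
]
evaluator = TrackerEvaluate(
    gt_dir="path/to/ground_truth.csv",
    predicted_boxes_id=predicted_boxes_id,
    iou_threshold=0.1,
)
evaluator.run_evaluation()  # logs the overall MOTA across frames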
+ """ self.gt_dir = gt_dir self.predicted_boxes_id = predicted_boxes_id self.iou_threshold = iou_threshold - self.last_known_predicted_ids: Dict = {} + self.last_known_predicted_ids: dict = {} - def get_predicted_data(self) -> Dict[int, Dict[str, Any]]: - """ - Convert predicted bounding box and ID into a dictionary organized by frame number. + def get_predicted_data(self) -> dict[int, dict[str, Any]]: + """Format predicted bounding box and ID as dictionary. + + Dictionary keys are frame numbers. Returns ------- - Dict[int, Dict[str, Any]]: - A dictionary where the key is the frame number and the value is another dictionary containing: - - 'bbox': A numpy array with shape (N, 4) containing coordinates of the bounding boxes - [x, y, x + width, y + height] for every object in the frame. + dict[int, dict[str, Any]]: + A dictionary where the key is the frame number and the value is + another dictionary containing: + - 'bbox': A numpy array with shape (N, 4) containing coordinates + of the bounding boxes [x, y, x + width, y + height] for every + object in the frame. - 'id': A numpy array containing the IDs of the tracked objects. + """ - predicted_dict: Dict[int, Dict[str, Any]] = {} + predicted_dict: dict[int, dict[str, Any]] = {} for frame_number, frame_data in enumerate(self.predicted_boxes_id): if frame_data.size == 0: @@ -59,19 +68,21 @@ def get_predicted_data(self) -> Dict[int, Dict[str, Any]]: return predicted_dict - def get_ground_truth_data(self) -> Dict[int, Dict[str, Any]]: - """ - Extract ground truth bounding box data from a CSV file and organize it by frame number. + def get_ground_truth_data(self) -> dict[int, dict[str, Any]]: + """Fromat ground truth bounding box data as dict with key frame number. Returns ------- - Dict[int, Dict[str, Any]]: - A dictionary where the key is the frame number and the value is another dictionary containing: - - 'bbox': A numpy arrays with shape of (N, 4) containing coordinates of the bounding box - [x, y, x + width, y + height] for every crabs in the frame. + dict[int, dict[str, Any]]: + A dictionary where the key is the frame number and the value is + another dictionary containing: + - 'bbox': A numpy arrays with shape of (N, 4) containing + coordinates of the bounding box [x, y, x + width, y + height] + for every crabs in the frame. - 'id': The ground truth ID + """ - with open(self.gt_dir, "r") as csvfile: + with open(self.gt_dir) as csvfile: csvreader = csv.reader(csvfile) next(csvreader) # Skip the header row ground_truth_data = [ @@ -110,22 +121,24 @@ def get_ground_truth_data(self) -> Dict[int, Dict[str, Any]]: return ground_truth_dict def calculate_iou(self, box1: np.ndarray, box2: np.ndarray) -> float: - """ - Calculate IoU (Intersection over Union) of two bounding boxes. + """Calculate IoU (Intersection over Union) of two bounding boxes. Parameters ---------- - box1 (np.ndarray): + box1 : np.ndarray Coordinates [x1, y1, x2, y2] of the first bounding box. - Here, (x1, y1) represents the top-left corner, and (x2, y2) represents the bottom-right corner. - box2 (np.ndarray): + Here, (x1, y1) represents the top-left corner, and (x2, y2) + represents the bottom-right corner. + box2 : np.ndarray Coordinates [x1, y1, x2, y2] of the second bounding box. - Here, (x1, y1) represents the top-left corner, and (x2, y2) represents the bottom-right corner. + Here, (x1, y1) represents the top-left corner, and (x2, y2) + represents the bottom-right corner. Returns ------- - float: + float IoU value. 
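# Worked example of the IoU above with made-up boxes:
#   box1 = [0, 0, 10, 10], box2 = [5, 5, 15, 15]
#   intersection = 5 * 5 = 25
#   union = 100 + 100 - 25 = 175
#   IoU = 25 / 175 ≈ 0.143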
+ """ x1_box1, y1_box1, x2_box1, y2_box1 = box1 x1_box2, y1_box2, x2_box2, y2_box2 = box2 @@ -149,25 +162,27 @@ def calculate_iou(self, box1: np.ndarray, box2: np.ndarray) -> float: return iou - def count_identity_switches( + def count_identity_switches( # noqa: C901 self, - gt_to_tracked_id_previous_frame: Optional[Dict[int, int]], - gt_to_tracked_id_current_frame: Dict[int, int], + gt_to_tracked_id_previous_frame: Optional[dict[int, int]], + gt_to_tracked_id_current_frame: dict[int, int], ) -> int: - """ - Count the number of identity switches between two sets of object IDs. + """Count the number of identity switches between two sets of IDs. Parameters ---------- - gt_to_tracked_id_previous_frame : Optional[Dict[int, int]] - A dictionary mapping ground truth IDs to predicted IDs from the previous frame. - gt_to_tracked_id_current_frame : Dict[int, int] - A dictionary mapping ground truth IDs to predicted IDs for the current frame. + gt_to_tracked_id_previous_frame : Optional[dict[int, int]] + A dictionary mapping ground truth IDs to predicted IDs from the + previous frame. + gt_to_tracked_id_current_frame : dict[int, int] + A dictionary mapping ground truth IDs to predicted IDs for the + current frame. Returns ------- int The number of identity switches between the two sets of object IDs. + """ if gt_to_tracked_id_previous_frame is None: for gt_id, pred_id in gt_to_tracked_id_current_frame.items(): @@ -176,17 +191,21 @@ def count_identity_switches( return 0 switch_counter = 0 - # Filter sets of ground truth IDs for current and previous frames to exclude NaN predicted IDs + # Filter sets of ground truth IDs for current and previous frames + # to exclude NaN predicted IDs gt_ids_current_frame = set(gt_to_tracked_id_current_frame.keys()) gt_ids_prev_frame = set(gt_to_tracked_id_previous_frame.keys()) - # Compute lists of ground truth IDs that continue, disappear, and appear + # Compute lists of ground truth IDs that continue, disappear, + # and appear gt_ids_cont = list(gt_ids_current_frame & gt_ids_prev_frame) gt_ids_disappear = list(gt_ids_prev_frame - gt_ids_current_frame) gt_ids_appear = list(gt_ids_current_frame - gt_ids_prev_frame) # Store used predicted IDs to avoid double counting - # In `used_pred_ids` we log IDs from either the current or the previous frame that have been involved in an already counted ID switch. + # In `used_pred_ids` we log IDs from either the current or the + # previous frame that have been involved in an already + # counted ID switch. 
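# Worked example with made-up IDs: if the previous frame maps ground-truth IDs
# to predicted IDs as {1: 11, 2: 12} and the current frame maps them as
# {1: 11, 2: 13}, ground-truth object 2 continues to exist but its predicted
# ID changes from 12 to 13, so one identity switch is counted.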
used_pred_ids = set() # Case 1: Objects that continue to exist according to GT @@ -194,13 +213,15 @@ def count_identity_switches( previous_pred_id = gt_to_tracked_id_previous_frame.get(gt_id) current_pred_id = gt_to_tracked_id_current_frame.get(gt_id) if all( - not np.isnan(x) for x in [previous_pred_id, current_pred_id] + not np.isnan(x) # type: ignore + for x in [previous_pred_id, current_pred_id] ): # Exclude if missed detection in previous AND current frame if current_pred_id != previous_pred_id: switch_counter += 1 used_pred_ids.add(current_pred_id) - # if the object was a missed detection in the previous frame: check if current prediction matches historical - elif np.isnan(previous_pred_id) and not np.isnan(current_pred_id): + # if the object was a missed detection in the previous frame: + # check if current prediction matches historical + elif np.isnan(previous_pred_id) and not np.isnan(current_pred_id): # type: ignore # noqa: SIM102 if gt_id in self.last_known_predicted_ids: last_known_predicted_id = self.last_known_predicted_ids[ gt_id @@ -213,10 +234,10 @@ def count_identity_switches( # Case 2: Objects that disappear according to GT for gt_id in gt_ids_disappear: previous_pred_id = gt_to_tracked_id_previous_frame.get(gt_id) - if not np.isnan( - previous_pred_id + if not np.isnan( # noqa: SIM102 + previous_pred_id # type: ignore ): # Exclude if missed detection in previous frame - if previous_pred_id in gt_to_tracked_id_current_frame.values(): + if previous_pred_id in gt_to_tracked_id_current_frame.values(): # noqa: SIM102 if previous_pred_id not in used_pred_ids: switch_counter += 1 used_pred_ids.add(previous_pred_id) @@ -225,7 +246,7 @@ def count_identity_switches( for gt_id in gt_ids_appear: current_pred_id = gt_to_tracked_id_current_frame.get(gt_id) if not np.isnan( - current_pred_id + current_pred_id # type: ignore ): # Exclude if missed detection in current frame # check if there was and ID switch wrt previous frame if current_pred_id in gt_to_tracked_id_previous_frame.values(): @@ -245,35 +266,38 @@ def count_identity_switches( def evaluate_mota( self, - gt_data: Dict[str, np.ndarray], - pred_data: Dict[str, np.ndarray], + gt_data: dict[str, np.ndarray], + pred_data: dict[str, np.ndarray], iou_threshold: float, - gt_to_tracked_id_previous_frame: Optional[Dict[int, int]], - ) -> Tuple[float, Dict[int, int]]: - """ - Evaluate MOTA (Multiple Object Tracking Accuracy). + gt_to_tracked_id_previous_frame: Optional[dict[int, int]], + ) -> tuple[float, dict[int, int]]: + """Evaluate MOTA (Multiple Object Tracking Accuracy). Parameters ---------- - gt_data : Dict[str, np.ndarray] + gt_data : dict[str, np.ndarray] Dictionary containing ground truth bounding boxes and IDs. - 'bbox': Bounding boxes with shape (N, 4). - 'id': Ground truth IDs with shape (N,). - pred_data : Dict[str, np.ndarray] + pred_data : dict[str, np.ndarray] Dictionary containing predicted bounding boxes and IDs. - 'bbox': Bounding boxes with shape (N, 4). - 'id': Predicted IDs with shape (N,). iou_threshold : float Intersection over Union (IoU) threshold for considering a match. - gt_to_tracked_id_previous_frame : Optional[Dict[int, int]] - A dictionary mapping ground truth IDs to predicted IDs from the previous frame. + gt_to_tracked_id_previous_frame : Optional[dict[int, int]] + A dictionary mapping ground truth IDs to predicted IDs from the + previous frame. Returns ------- float - The computed MOTA (Multi-Object Tracking Accuracy) score for the tracking performance. 
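# For reference, MOTA is typically defined per frame as
#   mota = 1 - (missed_detections + false_positives + id_switches) / n_gt_boxes
# e.g. 10 ground-truth boxes with 1 miss, 2 false positives and 1 identity
# switch give mota = 1 - 4 / 10 = 0.6 (standard definition; made-up counts).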
- Dict[int, int] - A dictionary mapping ground truth IDs to predicted IDs for the current frame. + The computed MOTA (Multi-Object Tracking Accuracy) score for the + tracking performance. + dict[int, int] + A dictionary mapping ground truth IDs to predicted IDs for the + current frame. + """ total_gt = len(gt_data["bbox"]) false_positive = 0 @@ -286,7 +310,7 @@ def evaluate_mota( gt_boxes = gt_data["bbox"] gt_ids = gt_data["id"] - for i, (pred_box, pred_id) in enumerate(zip(pred_boxes, pred_ids)): + for _i, (pred_box, pred_id) in enumerate(zip(pred_boxes, pred_ids)): best_iou = 0.0 index_gt_best_match = None index_gt_not_match = None @@ -301,7 +325,8 @@ def evaluate_mota( index_gt_not_match = j if index_gt_best_match is not None: - # Successfully found a matching ground truth box for the tracked box. + # Successfully found a matching ground truth box for the + # tracked box. indices_of_matched_gt_boxes.add(index_gt_best_match) # Map ground truth ID to tracked ID gt_to_tracked_id_current_frame[ @@ -312,7 +337,7 @@ def evaluate_mota( if index_gt_not_match is not None: gt_to_tracked_id_current_frame[ int(gt_ids[index_gt_not_match]) - ] = np.nan + ] = np.nan # type: ignore missed_detections = total_gt - len(indices_of_matched_gt_boxes) num_switches = self.count_identity_switches( @@ -327,23 +352,26 @@ def evaluate_mota( def evaluate_tracking( self, - ground_truth_dict: Dict[int, Dict[str, Any]], - predicted_dict: Dict[int, Dict[str, Any]], + ground_truth_dict: dict[int, dict[str, Any]], + predicted_dict: dict[int, dict[str, Any]], ) -> list[float]: - """ - Evaluate tracking performance using the Multi-Object Tracking Accuracy (MOTA) metric. + """Evaluate tracking with the Multi-Object Tracking Accuracy metric. Parameters ---------- ground_truth_dict : dict - Dictionary containing ground truth bounding boxes and IDs for each frame, organized by frame number. + Dictionary containing ground truth bounding boxes and IDs for each + frame, organized by frame number. predicted_dict : dict - Dictionary containing predicted bounding boxes and IDs for each frame, organized by frame number. + Dictionary containing predicted bounding boxes and IDs for each + frame, organized by frame number. Returns ------- list[float]: - The computed MOTA (Multi-Object Tracking Accuracy) score for the tracking performance. + The computed MOTA (Multi-Object Tracking Accuracy) score for the + tracking performance. + """ mota_values = [] prev_frame_id_map: Optional[dict] = None @@ -364,12 +392,10 @@ def evaluate_tracking( return mota_values def run_evaluation(self) -> None: - """ - Run evaluation of tracking based on tracking ground truth. - """ + """Run evaluation of tracking based on tracking ground truth.""" predicted_dict = self.get_predicted_data() ground_truth_dict = self.get_ground_truth_data() mota_values = self.evaluate_tracking(ground_truth_dict, predicted_dict) overall_mota = np.mean(mota_values) - logging.info("Overall MOTA: %f" % overall_mota) + logging.info("Overall MOTA: %f" % overall_mota) # noqa: UP031 diff --git a/crabs/tracker/sort.py b/crabs/tracker/sort.py index a38635c9..d0ae2d79 100644 --- a/crabs/tracker/sort.py +++ b/crabs/tracker/sort.py @@ -1,5 +1,5 @@ -""" -SORT: A Simple, Online and Realtime Tracker +"""SORT: A Simple, Online and Realtime Tracker. 
+ Copyright (C) 2016-2020 Alex Bewley alex@bewley.ai This program is free software: you can redistribute it and/or modify @@ -26,10 +26,8 @@ ) -class KalmanBoxTracker(object): - """ - This class represents the internal state of individual tracked objects - observed as bbox. +class KalmanBoxTracker: + """Class for the internal state of individual tracked objects. Parameters ---------- @@ -41,9 +39,7 @@ class KalmanBoxTracker(object): count = 0 def __init__(self, bbox): - """ - Initialises a tracker using initial bounding box. - """ + """Initialise a tracker using initial bounding box.""" # define constant velocity model self.kf = KalmanFilter(dim_x=7, dim_z=4) self.kf.F = np.array( @@ -67,9 +63,10 @@ def __init__(self, bbox): ) self.kf.R[2:, 2:] *= 10.0 - self.kf.P[ - 4:, 4: - ] *= 1000.0 # give high uncertainty to the unobservable initial velocities + self.kf.P[4:, 4:] *= ( + 1000.0 + # give high uncertainty to the unobservable initial velocities + ) self.kf.P *= 10.0 self.kf.Q[-1, -1] *= 0.01 self.kf.Q[4:, 4:] *= 0.01 @@ -84,13 +81,13 @@ def __init__(self, bbox): self.age = 0 def update(self, bbox: np.ndarray) -> None: - """ - Updates the state vector with an observed bounding box. + """Update the state vector with an observed bounding box. Parameters ---------- bbox : np.ndarray Observed bounding box coordinates in the format [x1, y1, x2, y2]. + """ self.time_since_update = 0 self.history = [] @@ -99,13 +96,13 @@ def update(self, bbox: np.ndarray) -> None: self.kf.update(convert_bbox_to_z(bbox)) def predict(self) -> np.ndarray: - """ - Advances the state vector and returns the predicted bounding box estimate. + """Advance the state vector and return predicted bounding box estimate. Returns ------- np.ndarray Predicted bounding box coordinates in the format [x1, y1, x2, y2]. + """ if (self.kf.x[6] + self.kf.x[2]) <= 0: self.kf.x[6] *= 0.0 @@ -118,32 +115,35 @@ def predict(self) -> np.ndarray: return self.history[-1] def get_state(self) -> np.ndarray: - """ - Returns the current bounding box estimate. + """Return the current bounding box estimate. Returns ------- np.ndarray Current bounding box coordinates in the format [x1, y1, x2, y2]. + """ return convert_x_to_bbox(self.kf.x) -class Sort(object): +class Sort: # noqa: D101 def __init__( self, max_age: int = 1, min_hits: int = 3, iou_threshold: float = 0.3 ): - """ - Sets key parameters for SORT. + """Set key parameters for SORT. Parameters ---------- max_age : int, optional - Maximum number of frames to keep a tracker alive without an update. Default is 1. + Maximum number of frames to keep a tracker alive without an update. + Default is 1. min_hits : int, optional - Minimum number of consecutive hits to consider a tracker valid. Default is 3. + Minimum number of consecutive hits to consider a tracker valid. + Default is 3. iou_threshold : float, optional - IOU threshold for associating detections with trackers. Default is 0.3. + IOU threshold for associating detections with trackers. + Default is 0.3. + """ self.max_age = max_age self.min_hits = min_hits @@ -151,23 +151,27 @@ def __init__( self.trackers: list = [] self.frame_count = 0 - def update(self, dets: np.ndarray = np.empty((0, 5))) -> np.ndarray: - """ - Updates the SORT tracker with new detections. + def update( + self, + dets: np.ndarray = np.empty((0, 5)), # noqa: B008 + ) -> np.ndarray: + """Update the SORT tracker with new detections. 
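# Hypothetical usage sketch of the tracker above; detections are made-up values
# in the format [x1, y1, x2, y2, score], one call to update() per frame.
import numpy as np

tracker = Sort(max_age=1, min_hits=3, iou_threshold=0.3)
detections = np.array([[10.0, 10.0, 50.0, 50.0, 0.9]])
tracked = tracker.update(detections)  # rows are [x1, y1, x2, y2, id]
tracked_empty = tracker.update(np.empty((0, 5)))  # frame without detections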
Parameters ---------- dets : np.ndarray, optional - Array of shape (N, 5) representing N detection bounding boxes in format [x1, y1, x2, y2, score]. - Use np.empty((0, 5)) for frames without detections. + Array of shape (N, 5) representing N detection bounding boxes in + format [x1, y1, x2, y2, score]. Use np.empty((0, 5)) for frames + without detections. Returns ------- np.ndarray Array of tracked objects with object IDs added as the last column. - The shape of the array is (M, 5), where M is the number of tracked objects. - """ + The shape of the array is (M, 5), where M is the number of tracked + objects. + """ self.frame_count += 1 # get predicted locations from existing trackers. trks = np.zeros((len(self.trackers), 5)) diff --git a/crabs/tracker/track_video.py b/crabs/tracker/track_video.py index e65fd07c..3dc295f3 100644 --- a/crabs/tracker/track_video.py +++ b/crabs/tracker/track_video.py @@ -1,3 +1,5 @@ +"""Track crabs in a video using a trained detector.""" + import argparse import logging import os @@ -24,9 +26,7 @@ class Tracking: - """ - A class for performing crabs tracking on a video - using a trained model. + """Interface for tracking crabs on a video using a trained detector. Parameters ---------- @@ -41,14 +41,16 @@ class Tracking: The path to the input video. sort_tracker : Sort An instance of the sorting algorithm used for tracking. + """ def __init__(self, args: argparse.Namespace) -> None: + """Initialise the tracking interface with the given arguments.""" self.args = args self.config_file = args.config_file self.video_path = args.video_path self.trained_model_path = self.args.trained_model_path - self.device = self.args.device + self.device = "cuda" if self.args.accelerator == "gpu" else "cpu" self.setup() self.prep_outputs() @@ -60,10 +62,8 @@ def __init__(self, args: argparse.Namespace) -> None: ) def setup(self): - """ - Load tracking config, trained model and input video path. - """ - with open(self.config_file, "r") as f: + """Load tracking config, trained model and input video path.""" + with open(self.config_file) as f: self.config = yaml.safe_load(f) # Get trained model @@ -80,9 +80,7 @@ def setup(self): self.video_file_root = f"{Path(self.video_path).stem}" def prep_outputs(self): - """ - Prepare csv writer and if required, video writer. - """ + """Prepare csv writer and if required, video writer.""" ( self.csv_writer, self.csv_file, @@ -96,6 +94,7 @@ def prep_outputs(self): self.video_output = prep_video_writer( self.tracking_output_dir, + self.video_file_root, frame_width, frame_height, cap_fps, @@ -104,8 +103,7 @@ def prep_outputs(self): self.video_output = None def get_prediction(self, frame: np.ndarray) -> torch.Tensor: - """ - Get prediction from the trained model for a given frame. + """Get prediction from the trained model for a given frame. Parameters ---------- @@ -116,6 +114,7 @@ def get_prediction(self, frame: np.ndarray) -> torch.Tensor: ------- torch.Tensor: The prediction tensor from the trained model. + """ transform = transforms.Compose( [ @@ -129,9 +128,8 @@ def get_prediction(self, frame: np.ndarray) -> torch.Tensor: prediction = self.trained_model(img) return prediction - def update_tracking(self, prediction: dict) -> list[list[float]]: - """ - Update the tracking system with the latest prediction. + def update_tracking(self, prediction: dict) -> np.ndarray: + """Update the tracking system with the latest prediction. 
Parameters ---------- @@ -140,22 +138,26 @@ def update_tracking(self, prediction: dict) -> list[list[float]]: Returns ------- - list[list[float]]: - list of tracked bounding boxes after updating the tracking system. + np.ndarray: + tracked bounding boxes after updating the tracking system. + """ pred_sort = prep_sort(prediction, self.config["score_threshold"]) tracked_boxes_id_per_frame = self.sort_tracker.update(pred_sort) self.tracked_bbox_id.append(tracked_boxes_id_per_frame) + return tracked_boxes_id_per_frame def run_tracking(self): - """ - Run object detection + tracking on the video frames. - """ + """Run object detection + tracking on the video frames.""" # If we pass ground truth: check the path exist - if self.args.gt_path and not os.path.exists(self.args.gt_path): + if self.args.annotations_file and not os.path.exists( + self.args.annotations_file + ): logging.info( - f"Ground truth file {self.args.gt_path} does not exist. Exiting..." + f"Ground truth file {self.args.annotations_file} " + "does not exist." + "Exiting..." ) return @@ -182,7 +184,8 @@ def run_tracking(self): break elif not ret: logging.info( - f"Cannot read frame {frame_idx+1}/{total_frames}. Exiting..." + f"Cannot read frame {frame_idx+1}/{total_frames}. " + "Exiting..." ) break @@ -208,9 +211,9 @@ def run_tracking(self): # update frame number frame_idx += 1 - if self.args.gt_path: + if self.args.annotations_file: evaluation = TrackerEvaluate( - self.args.gt_path, + self.args.annotations_file, self.tracked_bbox_id, self.config["iou_threshold"], ) @@ -228,8 +231,7 @@ def run_tracking(self): def main(args) -> None: - """ - Main function to run the inference on video based on the trained model. + """Run detection+tracking inference on video. Parameters ---------- @@ -239,76 +241,101 @@ def main(args) -> None: Returns ------- None - """ + """ inference = Tracking(args) inference.run_tracking() def tracking_parse_args(args): + """Parse command-line arguments for tracking.""" parser = argparse.ArgumentParser() parser.add_argument( "--trained_model_path", type=str, required=True, - help="location of checkpoint of the trained model", + help="Location of trained model (a .ckpt file). ", ) parser.add_argument( "--video_path", type=str, required=True, - help="location of video to be tracked", + help="Location of the video to be tracked.", ) parser.add_argument( - "--config_file", + "--annotations_file", type=str, - default=str(Path(__file__).parent / "config" / "tracking_config.yaml"), + default=None, help=( - "Location of YAML config to control tracking. " - "Default: crabs-exploration/crabs/tracking/config/tracking_config.yaml" + "Location of JSON file containing ground truth annotations " + "(optional). " + "If passed, the evaluation metrics for the tracker are computed." ), ) parser.add_argument( "--output_dir", type=str, default="tracking_output", - help="Directory to save the track output", # is this a csv or a video? (or both) - ) - parser.add_argument( - "--max_frames_to_read", - type=int, - default=None, - help="Maximum number of frames to read (mostly for debugging).", + help=( + "Root name of the directory to save the tracking output. " + "The name of the output directory is appended with a timestamp. " + "Default: ./tracking_output_. " + ), ) parser.add_argument( - "--gt_path", + "--config_file", type=str, - default=None, + default=str(Path(__file__).parent / "config" / "tracking_config.yaml"), help=( - "Location of json file containing ground truth annotations (optional)." 
- "If passed, evaluation metrics are computed." + "Location of YAML config to control tracking. " + "Default: " + "crabs-exploration/crabs/tracking/config/tracking_config.yaml. " ), ) parser.add_argument( "--save_video", action="store_true", - help="Save video inference with tracking output", + help=( + "Add a video with tracked bounding boxes " + "to the tracking output directory. " + "The tracked video is called _tracks.mp4. " + ), ) parser.add_argument( "--save_frames", action="store_true", - help="Save frame to be used in correcting track labelling", + help=( + "Add all frames to the tracking output. " + "The frames are saved as-is, without bounding boxes, to " + "support their visualisation and correction using the VIA tool. " + ), ) parser.add_argument( - "--device", + "--accelerator", type=str, - default="cuda", - help="device for pytorch either cpu or cuda", + default="gpu", + help=( + "Accelerator for Pytorch. " + "Valid inputs are: cpu or gpu. Default: gpu." + ), + ) + parser.add_argument( + "--max_frames_to_read", + type=int, + default=None, + help=( + "Debugging option to limit " + "the maximum number of frames to read in the video. " + "It affects all the tracking outputs (csv, frames and video) " + "and the MOTA computation, which will be restricted to just " + "the first N frames. " + ), ) return parser.parse_args(args) def app_wrapper(): + """Wrap function to run the tracking application.""" torch.set_float32_matmul_precision("medium") tracking_args = tracking_parse_args(sys.argv[1:]) diff --git a/crabs/tracker/utils/io.py b/crabs/tracker/utils/io.py index b01b9750..16c82cda 100644 --- a/crabs/tracker/utils/io.py +++ b/crabs/tracker/utils/io.py @@ -1,3 +1,5 @@ +"""Utility functions for handling input and output operations.""" + import csv import os from datetime import datetime @@ -8,14 +10,13 @@ from crabs.detector.utils.visualization import draw_bbox from crabs.tracker.utils.tracking import ( - save_output_frames, + save_output_frame, write_tracked_bbox_to_csv, ) def prep_csv_writer(output_dir: str, video_file_root: str): - """ - Prepare csv writer to output tracking results. + """Prepare csv writer to output tracking results. Parameters ---------- @@ -27,16 +28,18 @@ def prep_csv_writer(output_dir: str, video_file_root: str): Returns ------- Tuple - A tuple containing the CSV writer, the CSV file object, and the tracking output directory path. - """ + A tuple containing the CSV writer, the CSV file object, and the + tracking output directory path. + """ + # Create a timestamped directory for the tracking output timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - tracking_output_dir = Path(output_dir + f"_{timestamp}") / video_file_root - # Create the subdirectory for the specific video file root + tracking_output_dir = Path(output_dir + f"_{timestamp}") tracking_output_dir.mkdir(parents=True, exist_ok=True) - csv_file = open( - f"{str(tracking_output_dir)}/predicted_tracks.csv", + # Initialise csv file + csv_file = open( # noqa: SIM115 + f"{str(tracking_output_dir)}/{video_file_root}_tracks.csv", "w", ) csv_writer = csv.writer(csv_file) @@ -60,12 +63,12 @@ def prep_csv_writer(output_dir: str, video_file_root: str): def prep_video_writer( output_dir: str, + video_file_root: str, frame_width: int, frame_height: int, cap_fps: float, ) -> cv2.VideoWriter: - """ - Prepare video writer to output processed video. + """Prepare video writer to output processed video. 
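# A minimal sketch of the writer set up below, with made-up output directory,
# video file root, frame size and fps:
import os

import cv2

output_file = os.path.join(
    "tracking_output_20240101_120000", "clip01_tracks.mp4"
)
fourcc = cv2.VideoWriter_fourcc("m", "p", "4", "v")
video_output = cv2.VideoWriter(output_file, fourcc, 25.0, (1920, 1080))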
Parameters ---------- @@ -84,10 +87,11 @@ def prep_video_writer( ------- cv2.VideoWriter The video writer object for writing video frames. + """ output_file = os.path.join( output_dir, - "tracked_video.mp4", + f"{video_file_root}_tracks.mp4", ) output_codec = cv2.VideoWriter_fourcc("m", "p", "4", "v") video_output = cv2.VideoWriter( @@ -109,15 +113,14 @@ def save_required_output( frame_number: int, pred_scores: np.ndarray, ) -> None: - """ - Handle the output based on argument options. + """Handle the output based on argument options. Parameters ---------- video_file_root : Path The root name of the video file. - save_csv_and_frames : bool - Flag to save CSV and frames. + save_frames : bool + Flag to save frames. tracking_output_dir : Path Directory to save tracking output. csv_writer : Any @@ -134,18 +137,24 @@ def save_required_output( The frame number. pred_scores : np.ndarray The prediction score from detector + """ - frame_name = f"{video_file_root}_frame_{frame_number:08d}.png" + frame_name = f"frame_{frame_number:08d}.png" for bbox, pred_score in zip(tracked_boxes, pred_scores): write_tracked_bbox_to_csv( - bbox, frame, frame_name, csv_writer, pred_score + np.array(bbox), frame, frame_name, csv_writer, pred_score ) if save_frames: - save_output_frames( + # create subdirectory of frames + frames_subdir = tracking_output_dir / f"{video_file_root}_frames" + frames_subdir.mkdir(parents=True, exist_ok=True) + + # save frame (without bounding boxes) + save_output_frame( frame_name, - tracking_output_dir, + frames_subdir, frame, frame_number, ) @@ -165,16 +174,12 @@ def save_required_output( def close_csv_file(csv_file) -> None: - """ - Close the CSV file if it's open. - """ + """Close the CSV file if it's open.""" if csv_file: csv_file.close() def release_video(video_output) -> None: - """ - Release the video file if it's open. - """ + """Release the video file if it's open.""" if video_output: video_output.release() diff --git a/crabs/tracker/utils/sort.py b/crabs/tracker/utils/sort.py index 8b611d0e..0fe67c49 100644 --- a/crabs/tracker/utils/sort.py +++ b/crabs/tracker/utils/sort.py @@ -1,5 +1,5 @@ -""" -SORT: A Simple, Online and Realtime Tracker +"""SORT: A Simple, Online and Realtime Tracker. + Copyright (C) 2016-2020 Alex Bewley alex@bewley.ai This program is free software: you can redistribute it and/or modify @@ -16,24 +16,29 @@ along with this program. If not, see . """ -from typing import Optional, Tuple +from typing import Optional import numpy as np def linear_assignment(cost_matrix: np.ndarray) -> np.ndarray: - """ - Perform linear assignment using LAPJV algorithm if available, otherwise fallback to scipy's linear_sum_assignment. + """Perform linear assignment. + + Uses LAPJV algorithm if available, otherwise falls back to scipy's + linear_sum_assignment. Parameters ---------- cost_matrix : np.ndarray - The cost matrix representing the assignment costs between tracks and detections. + The cost matrix representing the assignment costs between + tracks and detections. Returns ------- np.ndarray - An array containing the assignment indices. Each row corresponds to a pair (track index, detection index). + An array containing the assignment indices. Each row corresponds to a + pair (track index, detection index). 
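# Worked example of the scipy fallback mentioned above, with a made-up
# cost matrix:
import numpy as np
from scipy.optimize import linear_sum_assignment

cost_matrix = np.array([[0.9, 0.1], [0.2, 0.8]])
x, y = linear_sum_assignment(cost_matrix)
matches = np.array(list(zip(x, y)))  # [[0, 1], [1, 0]]: the lowest-cost pairing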
+ """ try: import lap @@ -48,22 +53,26 @@ def linear_assignment(cost_matrix: np.ndarray) -> np.ndarray: def iou_batch(bb_test: np.ndarray, bb_gt: np.ndarray) -> np.ndarray: - """ - From SORT: Computes IOU between two bboxes in the form [x1,y1,x2,y2] - Calculate Intersection over Union (IoU) between two batches of bounding boxes. + """Compute IOU between two bboxes in the form [x1,y1,x2,y2]. + + Calculate Intersection over Union (IoU) between two batches of + bounding boxes. Parameters ---------- bb_test : np.ndarray - Bounding boxes of shape (N, 4) representing N test boxes in format [x1, y1, x2, y2]. + Bounding boxes of shape (N, 4) representing N test boxes + in format [x1, y1, x2, y2]. bb_gt : np.ndarray - Bounding boxes of shape (M, 4) representing M ground truth boxes in format [x1, y1, x2, y2]. + Bounding boxes of shape (M, 4) representing M ground truth + boxes in format [x1, y1, x2, y2]. Returns ------- np.ndarray IoU values between each pair of bounding boxes in bb_test and bb_gt. The shape of the returned array is (N, M). + """ bb_gt = np.expand_dims(bb_gt, 0) bb_test = np.expand_dims(bb_test, 1) @@ -85,8 +94,9 @@ def iou_batch(bb_test: np.ndarray, bb_gt: np.ndarray) -> np.ndarray: def convert_bbox_to_z(bbox: np.ndarray) -> np.ndarray: - """ - Convert a bounding box from [x1, y1, x2, y2] to a representation [x, y, s, r]. + """Convert a bounding box from corner form to center form. + + Corner form is [x1, y1, x2, y2] and center form is [x, y, s, r]. Parameters ---------- @@ -98,6 +108,7 @@ def convert_bbox_to_z(bbox: np.ndarray) -> np.ndarray: np.ndarray Converted representation of the bounding box as [x, y, s, r]. T + """ w = bbox[2] - bbox[0] h = bbox[3] - bbox[1] @@ -111,8 +122,9 @@ def convert_bbox_to_z(bbox: np.ndarray) -> np.ndarray: def convert_x_to_bbox( x: np.ndarray, score: Optional[float] = None ) -> np.ndarray: - """ - Convert a bounding box from center form [x, y, s, r] to corner form [x1, y1, x2, y2]. + """Convert a bounding box from center form to corner form. + + Center form is [x, y, s, r] and corner form is [x1, y1, x2, y2]. Parameters ---------- @@ -124,8 +136,11 @@ def convert_x_to_bbox( Returns ------- np.ndarray - Converted representation of the bounding box as [x1, y1, x2, y2] (and score, if provided). - The shape of the returned array is (1, 4) or (1, 5) if score is provided. + Converted representation of the bounding box as [x1, y1, x2, y2] + (and score, if provided). + The shape of the returned array is (1, 4) or (1, 5) + if score is provided. + """ w = np.sqrt(x[2] * x[3]) h = x[2] / w @@ -145,28 +160,35 @@ def convert_x_to_bbox( ).reshape((1, 5)) -def associate_detections_to_trackers( +def associate_detections_to_trackers( # noqa: C901 detections: np.ndarray, trackers: np.ndarray, iou_threshold: float = 0.3 -) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: - """ - Assigns detections to tracked objects (both represented as bounding boxes). +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Assign detections to tracked objects. + + Both detections and tracked objects are represented as bounding boxes. Parameters ---------- detections : np.ndarray - Array of shape (N, 4) representing N detection bounding boxes in format [x1, y1, x2, y2]. + Array of shape (N, 4) representing N detection bounding boxes in + format [x1, y1, x2, y2]. trackers : np.ndarray - Array of shape (M, 4) representing M tracker bounding boxes in format [x1, y1, x2, y2]. + Array of shape (M, 4) representing M tracker bounding boxes in + format [x1, y1, x2, y2]. 
iou_threshold : float, optional IOU threshold for associating detections with trackers. Default is 0.3. Returns ------- - Tuple[np.ndarray, np.ndarray, np.ndarray] + tuple[np.ndarray, np.ndarray, np.ndarray] Three arrays: - - matches: Array of shape (K, 2) containing indices of matched detections and trackers. - - unmatched_detections: Array of indices of detections that were not matched. - - unmatched_trackers: Array of indices of trackers that were not matched. + - matches: Array of shape (K, 2) containing indices of matched + detections and trackers. + - unmatched_detections: Array of indices of detections that were not + matched. + - unmatched_trackers: Array of indices of trackers that were not + matched. + """ if len(trackers) == 0: return ( @@ -184,29 +206,29 @@ def associate_detections_to_trackers( else: matched_indices = linear_assignment(-iou_matrix) else: - matched_indices = np.empty(shape=(0, 2)) + matched_indices = np.empty(shape=(0, 2), dtype=int) unmatched_detections = [] - for d, det in enumerate(detections): + for d, _det in enumerate(detections): if d not in matched_indices[:, 0]: unmatched_detections.append(d) unmatched_trackers = [] - for t, trk in enumerate(trackers): + for t, _trk in enumerate(trackers): if t not in matched_indices[:, 1]: unmatched_trackers.append(t) # filter out matched with low IOU - matches = [] + list_matches = [] # before: matches for m in matched_indices: if iou_matrix[m[0], m[1]] < iou_threshold: unmatched_detections.append(m[0]) unmatched_trackers.append(m[1]) else: - matches.append(m.reshape(1, 2)) - if len(matches) == 0: + list_matches.append(m.reshape(1, 2)) + if len(list_matches) == 0: matches = np.empty((0, 2), dtype=int) else: - matches = np.concatenate(matches, axis=0) + matches = np.concatenate(list_matches, axis=0) return ( matches, diff --git a/crabs/tracker/utils/tracking.py b/crabs/tracker/utils/tracking.py index b18e9045..ca83ea94 100644 --- a/crabs/tracker/utils/tracking.py +++ b/crabs/tracker/utils/tracking.py @@ -1,25 +1,28 @@ +"""Utility functions for tracking.""" + import json import logging from pathlib import Path -from typing import Any, Dict +from typing import Any import cv2 import numpy as np -def extract_bounding_box_info(row: list[str]) -> Dict[str, Any]: - """ - Extracts bounding box information from a row of data. +def extract_bounding_box_info(row: list[str]) -> dict[str, Any]: + """Extract bounding box information from a row of data. Parameters ---------- row : list[str] - A list representing a row of data containing information about a bounding box. + A list representing a row of data containing information about a + bounding box. Returns ------- - Dict[str, Any]: + dict[str, Any]: A dictionary containing the extracted bounding box information. + """ filename = row[0] region_shape_attributes = json.loads(row[5]) @@ -49,8 +52,7 @@ def write_tracked_bbox_to_csv( csv_writer: Any, pred_score: np.ndarray, ) -> None: - """ - Write bounding box annotation to a CSV file. + """Write bounding box annotation to a CSV file. Parameters ---------- @@ -65,6 +67,7 @@ def write_tracked_bbox_to_csv( The CSV writer object to write the annotation. pred_score : np.ndarray The prediction score from detector. 
+ """ # Bounding box geometry xmin, ymin, xmax, ymax, id = bbox @@ -79,46 +82,37 @@ def write_tracked_bbox_to_csv( '{{"clip":{}}}'.format("123"), 1, 0, - '{{"name":"rect","x":{},"y":{},"width":{},"height":{}}}'.format( - xmin, ymin, width_box, height_box - ), - '{{"track":"{}", "confidence":"{}"}}'.format(int(id), pred_score), + f'{{"name":"rect","x":{xmin},"y":{ymin},"width":{width_box},"height":{height_box}}}', + f'{{"track":"{int(id)}", "confidence":"{pred_score}"}}', ) ) -def save_output_frames( +def save_output_frame( frame_name: str, tracking_output_dir: Path, frame: np.ndarray, frame_number: int, ) -> None: - """ - Save tracked bounding boxes as frames. + """Save tracked bounding boxes as frames. Parameters ---------- - video_file_root : str - The root path of the video file. + frame_name : str + The name of the image file to save frame in. tracking_output_dir : Path The directory where tracked frames and CSV file will be saved. - tracked_boxes : list[list[float]] - List of bounding boxes to be saved. frame : np.ndarray The frame image. frame_number : int The frame number. - csv_writer : Any - CSV writer object for writing bounding box data. - pred_scores : np.ndarray - The prediction score from detector Returns ------- None - """ - # Save frame as PNG - once as per frame + """ + # Save frame as PNG frame_path = tracking_output_dir / frame_name img_saved = cv2.imwrite(str(frame_path), frame) if not img_saved: @@ -128,25 +122,28 @@ def save_output_frames( def prep_sort(prediction: dict, score_threshold: float) -> np.ndarray: - """ - Put predictions in format expected by SORT + """Put predictions in format expected by SORT. Parameters ---------- prediction : dict The dictionary containing predicted bounding boxes, scores, and labels. + score_threshold : float + The threshold score for filtering out low-confidence predictions. + Returns ------- np.ndarray: An array containing sorted bounding boxes of detected objects. 
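# Sketch of the array handed to SORT, with made-up detector outputs: boxes
# whose score clears the threshold keep their score as a fifth column.
import numpy as np

pred_boxes = np.array([[10.0, 20.0, 50.0, 60.0], [5.0, 5.0, 15.0, 15.0]])
pred_scores = np.array([0.9, 0.05])
score_threshold = 0.1
pred_sort = np.asarray(
    [
        np.concatenate((box, [score]))
        for box, score in zip(pred_boxes, pred_scores)
        if score > score_threshold
    ]
)
# -> array([[10., 20., 50., 60., 0.9]]), i.e. rows of [x1, y1, x2, y2, score]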
+ """ pred_boxes = prediction[0]["boxes"].detach().cpu().numpy() pred_scores = prediction[0]["scores"].detach().cpu().numpy() pred_labels = prediction[0]["labels"].detach().cpu().numpy() pred_sort = [] - for box, score, label in zip(pred_boxes, pred_scores, pred_labels): + for box, score, _label in zip(pred_boxes, pred_scores, pred_labels): if score > score_threshold: bbox = np.concatenate((box, [score])) pred_sort.append(bbox) diff --git a/guides/ManualLabellingSteps.md b/guides/ManualLabellingSteps.md old mode 100755 new mode 100644 diff --git a/notebooks/notebook_data_augm.py b/notebooks/notebook_data_augm.py index 54c0b51a..ca8ba223 100644 --- a/notebooks/notebook_data_augm.py +++ b/notebooks/notebook_data_augm.py @@ -13,7 +13,7 @@ # %%%%%%%%%%%%%%%%%%%% # Read config as dict -with open(CONFIG, "r") as f: +with open(CONFIG) as f: config_dict = yaml.safe_load(f) # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% diff --git a/notebooks/notebook_detect_chessboard_in_sampled_frames.py b/notebooks/notebook_detect_chessboard_in_sampled_frames.py index 7f0a1ad8..1eb42e53 100644 --- a/notebooks/notebook_detect_chessboard_in_sampled_frames.py +++ b/notebooks/notebook_detect_chessboard_in_sampled_frames.py @@ -1,6 +1,4 @@ -""" Extract frames in calibration video - -""" +"""Extract frames in calibration video""" # %% from pathlib import Path diff --git a/notebooks/notebook_detect_chessboard_in_video.py b/notebooks/notebook_detect_chessboard_in_video.py index 5e15e0b4..9ae968c5 100644 --- a/notebooks/notebook_detect_chessboard_in_video.py +++ b/notebooks/notebook_detect_chessboard_in_video.py @@ -1,6 +1,4 @@ -""" Extract frames in calibration video - -""" +"""Extract frames in calibration video""" # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%% from pathlib import Path diff --git a/notebooks/notebook_overlay_video.py b/notebooks/notebook_overlay_video.py index 94ccc300..2bec0dc3 100644 --- a/notebooks/notebook_overlay_video.py +++ b/notebooks/notebook_overlay_video.py @@ -70,7 +70,6 @@ def plot_trajectories( Frame trajectory is the frame up to which plot the trajectory of the individuals. If none is specified, all frames are plotted. """ - fig, ax = plt.subplots(1, 1, figsize=(10, 10)) # add color cycler to axes diff --git a/pyproject.toml b/pyproject.toml index 172d773c..300419d6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,12 +40,17 @@ dev = [ "pytest", "pytest-cov", "coverage", - "tox", # >=4 ? 
- "black", + "tox", "mypy", "pre-commit", "ruff", "setuptools_scm", + "check-manifest", + # "codespell", + # "pandas-stubs", + # "types-attrs", + # "types-PyYAML", + # "types-requests", ] [project.scripts] @@ -67,15 +72,15 @@ include-package-data = true [tool.setuptools.packages.find] include = ["crabs*"] -exclude = ["tests*"] +exclude = ["tests"] [tool.pytest.ini_options] addopts = "--cov=crabs" -[tool.black] -target-version = ['py39', 'py310'] -skip-string-normalization = false -line-length = 79 +# [tool.black] +# target-version = ['py39', 'py310'] +# skip-string-normalization = false +# line-length = 79 [tool.setuptools_scm] @@ -91,26 +96,39 @@ ignore = [ [tool.ruff] line-length = 79 exclude = ["__init__.py", "build", ".eggs"] -select = ["I", "E", "F"] fix = true + +[tool.ruff.lint] +# See https://docs.astral.sh/ruff/rules/ ignore = [ - 'E501', # line too long: should be handled by black + "D203", # one blank line before class + "D213", # multi-line-summary second line +] +select = [ + "E", # pycodestyle errors + "F", # Pyflakes + "UP", # pyupgrade + "I", # isort + "B", # flake8 bugbear + "SIM", # flake8 simplify + "C90", # McCabe complexity + "D", # pydocstyle + "NPY201", # checks for syntax that was deprecated in numpy2.0 ] -# force-exclude = true -# ignore = [ -# "D203", # no-blank-line-before-class -# "D212", # multi-line-summary-first-line -# "D417", # argument description in docstring (unreliable) -# ] -# per-file-ignores = {"test_*" = [ -# "S101", -# ]} -# [tool.tomlsort] -# all = true -# spaces_indent_inline_array = 4 -# trailing_comma_inline_array = true -# overrides."project.classifiers".inline_arrays = false -# overrides."tool.coverage.paths.source".inline_arrays = false +per-file-ignores = { "tests/*" = [ + "D100", # missing docstring in public module + "D205", # missing blank line between summary and description + "D103", # missing docstring in public function +], "examples/*" = [ + "B018", # Found useless expression + "D103", # Missing docstring in public function + "D400", # first line should end with a period. + "D415", # first line should end with a period, question mark... 
+ "D205", # missing blank line between summary and description +] } + +[tool.ruff.format] +docstring-code-format = true # Also format code in docstrings [tool.tox] diff --git a/scripts/output_video.py b/scripts/output_video.py index 55b4fc8c..cd044064 100644 --- a/scripts/output_video.py +++ b/scripts/output_video.py @@ -1,3 +1,5 @@ +"""Script to create a video with tracked bounding boxes.""" + from pathlib import Path import cv2 @@ -10,6 +12,7 @@ def create_opencv_video( list_individuals_idcs=None, list_frame_idcs=None, ): + """Create a video with bounding boxes around the selected individuals.""" # Open the video file cap = cv2.VideoCapture(input_video) @@ -138,11 +141,11 @@ def create_opencv_video( pred_csv = ( "/Users/sofia/arc/project_Zoo_crabs/escape_clips/" - "crabs_track_output_selected_clips/04.09.2023-04-Right_RE_test/predicted_tracks.csv" + "crabs_track_output_selected_clips/04.09.2023-04-Right_RE_test/predicted_tracks.csv" # noqa ) input_video = ( - "/Users/sofia/arc/project_Zoo_crabs/escape_clips/crabs_track_output_selected_clips/" + "/Users/sofia/arc/project_Zoo_crabs/escape_clips/crabs_track_output_selected_clips/" # noqa "04.09.2023-04-Right_RE_test/04.09.2023-04-Right_RE_test.mp4" ) diff --git a/tests/data/COCO_VIA_JSONS/VIA_JSON_1.json b/tests/data/COCO_VIA_JSONS/VIA_JSON_1.json old mode 100755 new mode 100644 diff --git a/tests/data/COCO_VIA_JSONS/VIA_JSON_2.json b/tests/data/COCO_VIA_JSONS/VIA_JSON_2.json old mode 100755 new mode 100644 diff --git a/tests/fixtures/frame_extraction.py b/tests/fixtures/frame_extraction.py index 65d381c8..9fb983d1 100644 --- a/tests/fixtures/frame_extraction.py +++ b/tests/fixtures/frame_extraction.py @@ -8,8 +8,7 @@ def list_files_in_dir(input_dir: str) -> list: - """Lists files in input directory""" - + """List files in input directory.""" return [ f for f in Path(input_dir).glob("*") @@ -19,7 +18,7 @@ def list_files_in_dir(input_dir: str) -> list: @pytest.fixture() def cli_inputs_dict(tmp_path: Path) -> dict: - """Returns the command line input arguments as a dictionary. + """Return the command line input arguments as a dictionary. These command line arguments are passed to the extract frames CLI command. The output path is @@ -40,10 +39,10 @@ def cli_inputs_dict(tmp_path: Path) -> dict: @pytest.fixture() def cli_inputs_list(cli_inputs_dict: dict) -> list: - """Returns the command line input arguments as a list.""" + """Return the command line input arguments as a list.""" def cli_inputs_dict_to_list(input_params: dict) -> list: - """Transforms a dictionary of parameters into a list of CLI arguments. + """Transform a dictionary of parameters into a list of CLI arguments. If for an item in the dictionary its value is True, the key is taken as a CLI boolean argument (i.e., a flag). @@ -63,9 +62,10 @@ def cli_inputs_dict_to_list(input_params: dict) -> list: ------- list a list of command line arguments to pass to `subprocess.run()`. + """ list_kys_modified = [] - for k in input_params.keys(): + for k in input_params: if input_params[k] is False: list_kys_modified.append("--no-" + k) else: @@ -96,7 +96,7 @@ def cli_inputs_list_empty(): def mock_extract_frames_app( cli_inputs_dict: dict, ) -> typer.main.Typer: - """Monkeypatches the extract-frames app to modify its default values. + """Monkeypatch the extract-frames app to modify its default values. We modify the defaults with values that are more convenient for testing. 
""" @@ -144,7 +144,7 @@ def mock_combine_and_format_annotations( @pytest.fixture() def video_extensions_flipped() -> list: - """Extracts the extensions of video files in INPUT_DATA_DIR + """Extract the extensions of video files in INPUT_DATA_DIR and flips their case (uppercase -> lowercase and viceversa). The file extensions would be provided by the user in the diff --git a/tests/test_integration/test_annotations.py b/tests/test_integration/test_annotations.py index 4e77ffe1..6b0a8e34 100644 --- a/tests/test_integration/test_annotations.py +++ b/tests/test_integration/test_annotations.py @@ -20,6 +20,7 @@ def via_json_1() -> str: ------- str path to a sample VIA JSON file 1. + """ # Return path to sample VIA (Visual Image Annotator) JSON file 1 return str( @@ -37,6 +38,7 @@ def via_json_2() -> str: ------- str path to a sample VIA JSON file 2. + """ # Return path to sample VIA JSON file 2 return str( @@ -64,6 +66,7 @@ def test_via_json_combine( path to a sample VIA JSON file 2. tmp_path : Path Pytest fixture providing a temporary directory path + """ # Check if the combination of 2 VIA JSON files has the same data # as the separate JSONS @@ -148,6 +151,7 @@ def test_via_json_combine_default_dir( path to a sample VIA JSON file 2. tmp_path : Path Pytest fixture providing a temporary directory path + """ # Set default directory via_default_dir = "/sample/VIA/project/directory" @@ -193,6 +197,7 @@ def test_via_json_combine_non_full_default_dir( path to a sample VIA JSON file 2. tmp_path : Path Pytest fixture providing a temporary directory path + """ # Set default directory as a non-full path via_default_dir = tmp_path.stem @@ -228,6 +233,7 @@ def test_via_json_combine_project_name( path to a sample VIA JSON file 2. tmp_path : Path Pytest fixture providing a temporary directory path + """ # Set project name via_project_name = "TEST" @@ -274,6 +280,7 @@ def test_coco_generated_from_via_json( Pytest fixture providing a temporary directory path request: pytest.FixtureRequest to request a parametrized fixture + """ # Define category attributes of the annotations coco_category_ID = 1 @@ -378,7 +385,7 @@ def test_coco_generated_from_via_json( def test_exclude_pattern(via_json_1: str, via_json_2: str, tmp_path: Path): - """Tests if exclude pattern works when combining annotation files + """Test exclude pattern when combining annotation files. Parameters ---------- @@ -388,6 +395,7 @@ def test_exclude_pattern(via_json_1: str, via_json_2: str, tmp_path: Path): path to second VIA JSON file tmp_path : Path Pytest fixture with a path to a temporary directory + """ # combine input json files, excluding those that end with _2.json json_out_fullpath = combine_multiple_via_jsons( diff --git a/tests/test_integration/test_frame_extraction.py b/tests/test_integration/test_frame_extraction.py index ee88df4b..c385ef15 100644 --- a/tests/test_integration/test_frame_extraction.py +++ b/tests/test_integration/test_frame_extraction.py @@ -23,6 +23,7 @@ def assert_output_files(list_input_videos: list, cli_dict: dict) -> None: List of videos used for frame extraction. cli_dict : dict A validation dictionary with the parameters of the frame extraction. 
+ """ # check name of directory with extracted frames list_subfolders = [ @@ -30,13 +31,9 @@ def assert_output_files(list_input_videos: list, cli_dict: dict) -> None: ] extracted_frames_dir = Path(list_subfolders[0]) assert len(list_subfolders) == 1 - assert ( - type( - datetime.datetime.strptime( - extracted_frames_dir.name, "%Y%m%d_%H%M%S" - ) - ) - == datetime.datetime + assert isinstance( + datetime.datetime.strptime(extracted_frames_dir.name, "%Y%m%d_%H%M%S"), + datetime.datetime, ) # check there is an extracted_frames.json file @@ -74,7 +71,7 @@ def assert_output_files(list_input_videos: list, cli_dict: dict) -> None: ) # only one must match # check n_elements in json file matches n of files generated - with open((extracted_frames_dir / "extracted_frames.json")) as js: + with open(extracted_frames_dir / "extracted_frames.json") as js: extracted_frames_dict = json.load(js) n_extracted_frames = sum( [len(list_idcs) for list_idcs in extracted_frames_dict.values()] @@ -172,11 +169,9 @@ def test_frame_extraction_one_dir( def test_extension_case_insensitive(video_extensions_flipped: list) -> None: - """ - Tests that the function that computes the list of SLEAP videos + """Tests that the function that computes the list of SLEAP videos is case-insensitive for the user-provided extension. """ - # build list of video files in dir list_files = list_files_in_dir(INPUT_DATA_DIR) diff --git a/tests/test_unit/test_datamodules.py b/tests/test_unit/test_datamodules.py index 2a823605..74cff89d 100644 --- a/tests/test_unit/test_datamodules.py +++ b/tests/test_unit/test_datamodules.py @@ -24,7 +24,7 @@ @pytest.fixture def default_train_config(): config_file = DEFAULT_CONFIG - with open(config_file, "r") as f: + with open(config_file) as f: return yaml.safe_load(f) @@ -86,7 +86,6 @@ def expected_no_data_augm_transforms(): def compare_transforms_attrs_excluding(transform1, transform2, keys_to_skip): """Compare the attributes of two transforms excluding those in list.""" - transform1_attrs_without_fns = { key: val for key, val in transform1.__dict__.items() @@ -106,16 +105,16 @@ def compare_transforms_attrs_excluding(transform1, transform2, keys_to_skip): def create_dummy_dataset(): """Return a factory of dummy images and annotations for testing. - The created datasets consist of N images, with a random number of bounding boxes - per image. The bounding boxes have fixed width and height, but their location - is randomized. Both images and annotations are torch tensors. + The created datasets consist of N images, with a random number of bounding + boxes per image. The bounding boxes have fixed width and height, but their + location is randomized. Both images and annotations are torch tensors. """ def _create_dummy_dataset(n_images): """Create a dataset with N images and random bounding boxes per image. - The number of images in the dataset needs to be > 5 to avoid floating point errors - in the dataset split. + The number of images in the dataset needs to be > 5 to avoid floating + point errors in the dataset split. 
""" img_size = 256 fixed_width_height = 10 @@ -127,7 +126,8 @@ def _create_dummy_dataset(n_images): n_bboxes = random.randint(1, 5) boxes = [] for _ in range(n_bboxes): - # Randomise the location of the top left corner of the bounding box + # Randomise the location of the top left corner of the + # bounding box x_min = random.randint(0, img_size - fixed_width_height) y_min = random.randint(0, img_size - fixed_width_height) @@ -193,7 +193,7 @@ def _create_dummy_dataset_dirs(n_images): def test_get_train_transform( crabs_data_module, expected_train_transforms, request ): - """Test transforms linked to training set are as expected""" + """Test transforms linked to training set are as expected.""" crabs_data_module = request.getfixturevalue(crabs_data_module) expected_train_transforms = request.getfixturevalue( expected_train_transforms @@ -209,7 +209,8 @@ def test_get_train_transform( expected_train_transforms.transforms, ): # we skip the attribute `_labels_getter` of `SanitizeBoundingBoxes` - # because it points to a lambda function, which does not have a comparison defined. + # because it points to a lambda function, which does not have a + # comparison defined. assert compare_transforms_attrs_excluding( transform1=train_tr, transform2=expected_train_tr, @@ -233,7 +234,7 @@ def test_get_train_transform( def test_get_test_val_transform( crabs_data_module, expected_test_val_transforms, request ): - """Test transforms linked to test and validation sets are as expected""" + """Test transforms linked to test and validation sets are as expected.""" crabs_data_module = request.getfixturevalue(crabs_data_module) expected_test_val_transforms = request.getfixturevalue( expected_test_val_transforms @@ -279,7 +280,10 @@ def test_collate_fn(crabs_data_module, create_dummy_dataset, request): @pytest.mark.parametrize( - "dataset_size, seed, train_fraction, val_over_test_fraction, expected_img_ids_per_split", + ( + "dataset_size, seed, train_fraction, " + "val_over_test_fraction, expected_img_ids_per_split" + ), [ ( 50, @@ -317,9 +321,9 @@ def test_compute_splits( create_dummy_dataset_dirs, default_train_config, ): - """Test dataset splits are reproducible and according to the requested - fraction""" - + """Test dataset splits are reproducible and match + the requested fraction. + """ # Create a dummy dataset and get paths to its directories dataset_dirs = create_dummy_dataset_dirs(n_images=dataset_size) diff --git a/tests/test_unit/test_datasets.py b/tests/test_unit/test_datasets.py index f9b11eee..53139bd9 100644 --- a/tests/test_unit/test_datasets.py +++ b/tests/test_unit/test_datasets.py @@ -32,8 +32,7 @@ ) @pytest.mark.parametrize("n_files_to_exclude", [0, 1, 20, 97]) def test_exclude_files(list_datasets, list_annotations, n_files_to_exclude): - """ - Test if the required files are excluded correctly from the + """Test if the required files are excluded correctly from the dataset defined by a list of annotation files, and a list of corresponding image directories. 
diff --git a/tests/test_unit/test_evaluate_tracker.py b/tests/test_unit/test_evaluate_tracker.py index 6dba2d4f..a08ec8d0 100644 --- a/tests/test_unit/test_evaluate_tracker.py +++ b/tests/test_unit/test_evaluate_tracker.py @@ -141,12 +141,14 @@ def test_ground_truth_data_from_csv(evaluation): {1: 11, 2: 12, 3: 13, 4: np.nan}, {1: 11, 2: 12, 3: 13, 5: np.nan}, 0, - ), # crab disappears but was missed detection in frame f-1, with a new missed crab in frame f + ), # crab disappears but was missed detection in frame f-1, + # with a new missed crab in frame f ( {1: 11, 2: 12, 3: 13, 4: np.nan}, {1: 11, 2: 12, 3: np.nan}, 0, - ), # crab disappears but was missed detection in frame f-1, and existing crab was missed in frame f + ), # crab disappears but was missed detection in frame f-1, + # and existing crab was missed in frame f # ----- a crab (GT=4) appears --------- ( {1: 11, 2: 12, 3: 13}, @@ -172,34 +174,49 @@ def test_ground_truth_data_from_csv(evaluation): {1: 11, 2: 12, 3: 13, 5: np.nan}, {1: 11, 2: 12, 3: 13, 4: np.nan}, 0, - ), # crab that appears is missed detection in current frame, and another missed detection in previous frame disappears + ), # crab that appears is missed detection in current frame, + # and another missed detection in previous frame disappears ( {1: 11, 2: 12, 3: np.nan}, {1: 11, 2: 12, 3: 13, 4: np.nan}, 0, - ), # crab that appears is missed detection in current frame, and a pre-existing crab is missed detection in previous frame - # ---------- Test consistency with last predicted ID if a crab (GT=3) that continues to exist is not detected for a few frames (>= 1) ------------ + ), # crab that appears is missed detection in current frame, + # and a pre-existing crab is missed detection in previous frame + # ---------- + # Test consistency with last predicted ID if a crab (GT=3) + # that continues to exist is not detected for a few frames (>= 1) + # ------------ ( {1: 11, 2: 12, 3: np.nan}, {1: 11, 2: 12, 3: 13}, 0, - ), # crab that continues to exist, and the current predicted ID is consistent with last_known_predicted_ids={1: 11, 2: 12, 3: 13, 4: 14} + ), # crab that continues to exist, and the current predicted ID is + # consistent with last_known_predicted_ids={1: 11, 2: 12, 3: 13, 4: 14} ( {1: 11, 2: 12, 3: np.nan}, {1: 11, 2: 12, 3: 14}, 1, - ), # crab that continues to exist, and the current predicted ID is NOT consistent with the last_known_predicted_ids={1: 11, 2: 12, 3: 13, 4: 14} - # ---------- Test consistency with last predicted ID if a crab (GT=3) re-appears after a few frames (>= 1) ------------ + ), # crab that continues to exist, and the current predicted ID + # is NOT consistent with the + # last_known_predicted_ids={1: 11, 2: 12, 3: 13, 4: 14} + # ---------- + # Test consistency with last predicted ID if a crab (GT=3) + # re-appears after a few frames (>= 1) + # ------------ ( {1: 11, 2: 12}, {1: 11, 2: 12, 3: 13}, 0, - ), # crab whose GT ID is in last_known_predicted_ids, appears in the current frame, and the current predicted ID is consistent with last_known_predicted_ids + ), # crab whose GT ID is in last_known_predicted_ids, appears + # in the current frame, and the current predicted ID is consistent + # with last_known_predicted_ids ( {1: 11, 2: 12}, {1: 11, 2: 12, 3: 14}, 1, - ), # crab whose GT ID is in last_known_predicted_ids, appears in the current frame, and the current predicted ID is NOT consistent with last_known_predicted_ids + ), # crab whose GT ID is in last_known_predicted_ids, appears + # in the current frame, and the current 
predicted ID is NOT consistent + # with last_known_predicted_ids ], ) def test_count_identity_switches( diff --git a/tests/test_unit/test_track_video.py b/tests/test_unit/test_track_video.py index 6833242b..7d7ffa89 100644 --- a/tests/test_unit/test_track_video.py +++ b/tests/test_unit/test_track_video.py @@ -17,8 +17,8 @@ def mock_args(): video_path="/path/to/video.mp4", trained_model_path="/path/to/model.ckpt", output_dir=temp_dir, - device="cuda", - gt_path=None, + accelerator="gpu", + annotations_file=None, save_video=None, ) diff --git a/tests/test_unit/test_train_model.py b/tests/test_unit/test_train_model.py index 7c7729d8..95ce6f91 100644 --- a/tests/test_unit/test_train_model.py +++ b/tests/test_unit/test_train_model.py @@ -18,9 +18,7 @@ [["/home/data/dataset1"], ["/home/data/dataset1", "/home/data/dataset2"]], ) def test_prep_img_directories(dataset_dirs: list): - """ - Test parsing of image directories when training a model. - """ + """Test parsing of image directories when training a model.""" from crabs.detector.train_model import DectectorTrain # prepare parser @@ -46,10 +44,9 @@ def test_prep_img_directories(dataset_dirs: list): ], ) def test_prep_annotation_files_single_dataset(annotation_files, expected): + """Test parsing of annotation files when training a model on a single + dataset. """ - Test parsing of annotation files when training a model on a single dataset. - """ - from crabs.detector.train_model import DectectorTrain # prepare CLI arguments @@ -89,10 +86,9 @@ def test_prep_annotation_files_single_dataset(annotation_files, expected): ], ) def test_prep_annotation_files_multiple_datasets(annotation_files, expected): + """Test parsing of annotation files when training + a model on two datasets. """ - Test parsing of annotation files when training a model on two datasets. - """ - from crabs.detector.train_model import DectectorTrain # prepare CLI arguments considering multiple dataset