Skip to content

Commit

Permalink
Data augmentation (#141)
Browse files Browse the repository at this point in the history
* Move checkpoint type computation to utils

* Refactor checkpointing in training script

* Get ckpt type if ckpt is passed

* optionally apply a data augmentation method (WIP)

* fix config syntax in code

* add data augmentation notebook

* notebook to explore params of individual transformations

* add transforms from config

* Add keywords to datamodule params

* Optionally skip data augmentation

* If data augmentation key in config, apply

* Update tests

* Change tests to read default config

* Refactor transform functions and clean up

* update notebook

* Fix data augmentation default config

* Optionally log data augmentation transforms as artifacts

* Rename skip to 'no_data_augmentation'
  • Loading branch information
sfmig authored Jun 28, 2024
1 parent 87babb5 commit 81db31e
Show file tree
Hide file tree
Showing 7 changed files with 354 additions and 116 deletions.
21 changes: 17 additions & 4 deletions crabs/detection_tracking/config/faster_rcnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,16 +37,29 @@ batch_size_test: 4
# -------------------
# Data augmentation
# -------------------
transform_brightness: 0.5
transform_hue: 0.3
gaussian_blur_params:
gaussian_blur:
kernel_size:
- 5
- 9
sigma:
- 0.1
- 5.0

color_jitter:
brightness: 0.5
hue: 0.3
random_horizontal_flip:
p: 0.5
random_rotation:
degrees: [-10.0, 10.0]
random_adjust_sharpness:
p: 0.5
sharpness_factor: 0.5
random_autocontrast:
p: 0.5
random_equalize:
p: 0.5
clamp_and_sanitize_bboxes:
min_size: 1.0
# ----------------------------
# Hyperparameter optimisation
# -----------------------------
Expand Down
78 changes: 68 additions & 10 deletions crabs/detection_tracking/datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,65 @@ def __init__(
list_annotation_files: list[str],
config: dict,
split_seed: Optional[int] = None,
no_data_augmentation: bool = False,
):
super().__init__()
self.list_img_dirs = list_img_dirs
self.list_annotation_files = list_annotation_files
self.split_seed = split_seed
self.config = config
self.no_data_augmentation = no_data_augmentation

def _transform_str_to_operator(self, transform_str):
"""Get transform operator from its name in snake case"""

def snake_to_camel_case(snake_str):
return "".join(
x.capitalize() for x in snake_str.lower().split("_")
)

transform_callable = getattr(
transforms, snake_to_camel_case(transform_str)
)

return transform_callable(**self.config[transform_str])

def _compute_list_of_transforms(self) -> list[torchvision.transforms.v2]:
"""Read transforms from config and add to list"""

# Initialise list
train_data_augm: list[torchvision.transforms.v2] = []

# Apply standard transforms if defined
for transform_str in [
"gaussian_blur",
"color_jitter",
"random_horizontal_flip",
"random_rotation",
"random_adjust_sharpness",
"random_autocontrast",
"random_equalize",
]:
if transform_str in self.config:
transform_operator = self._transform_str_to_operator(
transform_str
)
train_data_augm.append(transform_operator)

# Apply clamp and sanitize bboxes if defined
# See https://pytorch.org/vision/main/generated/torchvision.transforms.v2.SanitizeBoundingBoxes.html#torchvision.transforms.v2.SanitizeBoundingBoxes
if "clamp_and_sanitize_bboxes" in self.config:
# Clamp bounding boxes
train_data_augm.append(transforms.ClampBoundingBoxes())

# Sanitize
sanitize = transforms.SanitizeBoundingBoxes(
min_size=self.config["clamp_and_sanitize_bboxes"]["min_size"],
labels_getter=None, # only bboxes are sanitized
)
train_data_augm.append(sanitize)

return train_data_augm

def _get_train_transform(self) -> torchvision.transforms:
"""Define data augmentation transforms for the train set.
Expand All @@ -38,17 +91,22 @@ def _get_train_transform(self) -> torchvision.transforms:
https://pytorch.org/vision/stable/transforms.html#v1-or-v2-which-one-should-i-use
https://pytorch.org/vision/main/auto_examples/transforms/plot_transforms_e2e.html#transforms
ToDtype is the recommended replacement for ConvertImageDtype(dtype)
https://pytorch.org/vision/0.17/generated/torchvision.transforms.v2.ToDtype.html#torchvision.transforms.v2.ToDtype
"""
jitter = transforms.ColorJitter(
brightness=self.config["transform_brightness"],
hue=self.config["transform_hue"],
)
gauss = transforms.GaussianBlur(
kernel_size=self.config["gaussian_blur_params"]["kernel_size"],
sigma=self.config["gaussian_blur_params"]["sigma"],
)
todtype = transforms.ToDtype(torch.float32, scale=True)
train_transforms = [transforms.ToImage(), jitter, gauss, todtype]
# Compute list of transforms to apply
if self.no_data_augmentation:
train_data_augm = []
else:
train_data_augm = self._compute_list_of_transforms()

# Define a Compose transform with them
train_transforms = [
transforms.ToImage(),
*train_data_augm,
transforms.ToDtype(torch.float32, scale=True),
]
return transforms.Compose(train_transforms)

def _get_test_val_transform(self) -> torchvision.transforms:
Expand Down
36 changes: 35 additions & 1 deletion crabs/detection_tracking/detection_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import argparse
import datetime
import logging
import os
from pathlib import Path
from typing import Any
from typing import Any, Optional

import torch
from lightning.pytorch.loggers import MLFlowLogger

DEFAULT_ANNOTATIONS_FILENAME = "VIA_JSON_combined_coco_gen.json"
Expand Down Expand Up @@ -240,3 +242,35 @@ def slurm_logs_as_artifacts(logger, slurm_job_id):
logger.run_id,
f"{log_filename}.{ext}",
)


def log_data_augm_as_artifacts(logger, data_module):
"""Log data augmentation transforms as artifacts in MLflow."""
for transform_str in ["train_transform", "test_val_transform"]:
logger.experiment.log_text(
text=str(getattr(data_module, f"_get_{transform_str}")()),
artifact_file=f"{transform_str}.txt",
run_id=logger.run_id,
)


def get_checkpoint_type(checkpoint_path: Optional[str]) -> Optional[str]:
"""Get checkpoint type (full or weights) from the checkpoint path."""
checkpoint = torch.load(checkpoint_path) # fails if path doesn't exist
if all(
[
param in checkpoint
for param in ["optimizer_states", "lr_schedulers"]
]
):
checkpoint_type = "full" # for resuming training
logging.info(
f"Resuming training from checkpoint at: {checkpoint_path}"
)
else:
checkpoint_type = "weights" # for fine tuning
logging.info(
f"Fine-tuning training from checkpoint at: {checkpoint_path}"
)

return checkpoint_type
8 changes: 4 additions & 4 deletions crabs/detection_tracking/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,10 @@ def evaluate_model(self) -> None:
"""
# Create datamodule
data_module = CrabsDataModule(
self.images_dirs,
self.annotation_files,
self.config,
self.seed_n,
list_img_dirs=self.images_dirs,
list_annotation_files=self.annotation_files,
split_seed=self.seed_n,
config=self.config,
)

# Get trained model
Expand Down
91 changes: 40 additions & 51 deletions crabs/detection_tracking/train_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import argparse
import logging
import os
import sys
from pathlib import Path
Expand All @@ -12,6 +11,8 @@

from crabs.detection_tracking.datamodules import CrabsDataModule
from crabs.detection_tracking.detection_utils import (
get_checkpoint_type,
log_data_augm_as_artifacts,
prep_annotation_files,
prep_img_directories,
set_mlflow_run_name,
Expand Down Expand Up @@ -57,6 +58,7 @@ def __init__(self, args):
self.fast_dev_run = args.fast_dev_run
self.limit_train_batches = args.limit_train_batches

# Restart from checkpoint
self.checkpoint_path = args.checkpoint_path

def load_config_yaml(self):
Expand Down Expand Up @@ -158,67 +160,44 @@ def core_training(self) -> lightning.Trainer:
"""
# Create data module
data_module = CrabsDataModule(
self.images_dirs,
self.annotation_files,
self.config,
self.seed_n,
list_img_dirs=self.images_dirs,
list_annotation_files=self.annotation_files,
split_seed=self.seed_n,
config=self.config,
no_data_augmentation=self.args.no_data_augmentation,
)

# Get checkpoint type
if self.checkpoint_path and os.path.exists(self.checkpoint_path):
checkpoint = torch.load(self.checkpoint_path)
if all(
[
param in checkpoint
for param in ["optimizer_states", "lr_schedulers"]
]
):
checkpoint_type = "full" # for resuming training
logging.info(
f"Resuming training from checkpoint at: {self.checkpoint_path}"
)
else:
checkpoint_type = "weights" # for fine tuning
logging.info(
f"Fine-tuning training from checkpoint at: {self.checkpoint_path}"
)
else:
checkpoint_type = None

# Get model
if checkpoint_type == "weights":
# Note: weights-only checkpoint contains hyperparameters
# see https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
lightning_model = FasterRCNN.load_from_checkpoint(
self.checkpoint_path,
config=self.config,
optuna_log=self.args.optuna,
# overwrite checkpoint hyperparameters with config ones
# otherwise ckpt hyperparameters are logged to MLflow, but yaml hyperparameters are used
)
else:
if not self.checkpoint_path:
lightning_model = FasterRCNN(
self.config, optuna_log=self.args.optuna
)
checkpoint_type = None
else:
checkpoint_type = get_checkpoint_type(self.checkpoint_path)
if checkpoint_type == "weights":
lightning_model = FasterRCNN.load_from_checkpoint(
self.checkpoint_path,
config=self.config, # overwrite hparams from ckpt with config
optuna_log=self.args.optuna,
) # a 'weights' checkpoint is one saved with `save_weights_only=True`

# Get trainer
trainer = self.setup_trainer()
if self.args.log_data_augmentation:
log_data_augm_as_artifacts(trainer.logger, data_module)

# Run training
# Resume from full checkpoint if available
# (automatically restores model, epoch, step, LR schedulers, etc...)
# https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
if checkpoint_type == "full":
trainer.fit(
lightning_model,
data_module,
ckpt_path=self.checkpoint_path, # needs to having been saved with `save_weights_only=False`
)
else: # for "weights" or no checkpoint
trainer.fit(
lightning_model,
data_module,
)
trainer.fit(
lightning_model,
data_module,
ckpt_path=(
self.checkpoint_path if checkpoint_type == "full" else None
),
# a 'full' checkpoint is one saved with `save_weights_only=False`
# (automatically restores model, epoch, step, LR schedulers, etc...)
# see https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
)

return trainer

Expand Down Expand Up @@ -344,6 +323,16 @@ def train_parse_args(args):
action="store_true",
help="Run a hyperparameter optimisation using Optuna prior to training the model",
)
parser.add_argument(
"--no_data_augmentation",
action="store_true",
help="Ignore the data augmentation transforms defined in config file",
)
parser.add_argument(
"--log_data_augmentation",
action="store_true",
help="Log data augmentation transforms linked to datamodule as MLflow artifacts",
)
return parser.parse_args(args)


Expand Down
45 changes: 45 additions & 0 deletions notebooks/notebook_data_augm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# %%
import yaml # type: ignore

from crabs.detection_tracking.datamodules import CrabsDataModule
from crabs.detection_tracking.visualization import plot_sample

# %%%%%%%%%%%%%%%%%%%
# Input data
IMG_DIR = "/home/sminano/swc/project_crabs/data/sep2023-full/frames"
ANNOT_FILE = "/home/sminano/swc/project_crabs/data/sep2023-full/annotations/VIA_JSON_combined_coco_gen.json"
CONFIG = "/home/sminano/swc/project_crabs/crabs-exploration/crabs/detection_tracking/config/faster_rcnn.yaml"
SPLIT_SEED = 42

# %%%%%%%%%%%%%%%%%%%%
# Read config as dict
with open(CONFIG, "r") as f:
config_dict = yaml.safe_load(f)

# %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
# Create datamodule for the input data
dm = CrabsDataModule(
list_img_dirs=[IMG_DIR],
list_annotation_files=[ANNOT_FILE],
config=config_dict,
split_seed=SPLIT_SEED,
)
# %%%%%%%%%%%%%%%%%%%%%%%%
# Setup for train / test
dm.prepare_data()
dm.setup("fit")


# %%%%%%%%%%%%%%%%%%%%%%%%%%%
# after this: dm.train_dataset should have transforms, (but not dm.test_dataset)
print(dm.train_transform)
print(dm.val_transform)
print(dm.test_transform)

# %%%%%%%%%%%%%%%%%%%%%%%%%
# visualize
train_dataset = dm.train_dataset
train_sample = train_dataset[0]
plot_sample([train_sample])

# %%
Loading

0 comments on commit 81db31e

Please sign in to comment.