Skip to content

Commit

Permalink
Optionally log data augmentation transforms as artifacts
Browse files Browse the repository at this point in the history
  • Loading branch information
sfmig committed Jun 27, 2024
1 parent 43f852b commit 5119bd1
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
10 changes: 10 additions & 0 deletions crabs/detection_tracking/detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,16 @@ def slurm_logs_as_artifacts(logger, slurm_job_id):
)


def log_data_augm_as_artifacts(logger, data_module):
"""Log data augmentation transforms as artifacts in MLflow."""
for transform_str in ["train_transform", "test_val_transform"]:
logger.experiment.log_text(
text=str(getattr(data_module, f"_get_{transform_str}")()),
artifact_file=f"{transform_str}.txt",
run_id=logger.run_id,
)


def get_checkpoint_type(checkpoint_path: Optional[str]) -> Optional[str]:
"""Get checkpoint type (full or weights) from the checkpoint path."""
checkpoint = torch.load(checkpoint_path) # fails if path doesn't exist
Expand Down
8 changes: 8 additions & 0 deletions crabs/detection_tracking/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from crabs.detection_tracking.datamodules import CrabsDataModule
from crabs.detection_tracking.detection_utils import (
get_checkpoint_type,
log_data_augm_as_artifacts,
prep_annotation_files,
prep_img_directories,
set_mlflow_run_name,
Expand Down Expand Up @@ -183,6 +184,8 @@ def core_training(self) -> lightning.Trainer:

# Get trainer
trainer = self.setup_trainer()
if self.args.log_data_augmentation:
log_data_augm_as_artifacts(trainer.logger, data_module)

# Run training
trainer.fit(
Expand Down Expand Up @@ -325,6 +328,11 @@ def train_parse_args(args):
action="store_true",
help="Ignore the data augmentation transforms defined in config file",
)
parser.add_argument(
"--log_data_augmentation",
action="store_true",
help="Log data augmentation transforms linked to datamodule as MLflow artifacts",
)
return parser.parse_args(args)


Expand Down

0 comments on commit 5119bd1

Please sign in to comment.