Adding checkpoint_path for resume training #182

Merged
merged 25 commits on Jun 26, 2024

Commits (25)
ecc750d
adding ckpt_path to fit to resume training
nikk-nikaznan Jun 7, 2024
68c0053
option to resume or fine tunning
nikk-nikaznan Jun 11, 2024
d412475
small changes
nikk-nikaznan Jun 12, 2024
839df35
add checkpoint option in the guide
nikk-nikaznan Jun 12, 2024
c9b64a3
cleaned up guide
nikk-nikaznan Jun 12, 2024
7938a6b
Merge branch 'main' into nikkna/resume_training
nikk-nikaznan Jun 18, 2024
562c20e
Merge branch 'main' into nikkna/resume_training
nikk-nikaznan Jun 20, 2024
05856d2
cleaned up
nikk-nikaznan Jun 20, 2024
933aeef
tring rename the ckpt
nikk-nikaznan Jun 21, 2024
154c1ec
Merge branch 'main' into nikkna/resume_training
nikk-nikaznan Jun 21, 2024
6db4b76
Merge branch 'nikkna/resume_training' of github.com:SainsburyWellcome…
nikk-nikaznan Jun 21, 2024
1d37237
cleaned up after rebase
nikk-nikaznan Jun 21, 2024
a553a20
some changes in the guide
nikk-nikaznan Jun 21, 2024
7210436
run pre-commit
nikk-nikaznan Jun 21, 2024
466265c
fixed test
nikk-nikaznan Jun 21, 2024
0666762
Merge branch 'main' into nikkna/resume_training
nikk-nikaznan Jun 24, 2024
6aa414f
parsing the config to model instance during fine-tunning
nikk-nikaznan Jun 25, 2024
ba5fb78
Merge branch 'nikkna/resume_training' of github.com:SainsburyWellcome…
nikk-nikaznan Jun 25, 2024
d717661
small changes on guide
nikk-nikaznan Jun 26, 2024
7f71743
changes based on the review
nikk-nikaznan Jun 26, 2024
5df0361
small changes
nikk-nikaznan Jun 26, 2024
35e08e5
Update crabs/detection_tracking/train_model.py
nikk-nikaznan Jun 26, 2024
6c2d99e
Update crabs/detection_tracking/train_model.py
nikk-nikaznan Jun 26, 2024
36b5cbe
Update crabs/detection_tracking/train_model.py
nikk-nikaznan Jun 26, 2024
7f425cc
cleaned up pre-commit
nikk-nikaznan Jun 26, 2024
73 changes: 64 additions & 9 deletions crabs/detection_tracking/train_model.py
@@ -1,4 +1,5 @@
 import argparse
+import logging
 import os
 import sys
 from pathlib import Path
@@ -56,6 +57,8 @@ def __init__(self, args):
         self.fast_dev_run = args.fast_dev_run
         self.limit_train_batches = args.limit_train_batches

+        self.checkpoint_path = args.checkpoint_path
+
     def load_config_yaml(self):
         with open(self.config_file, "r") as f:
             self.config = yaml.safe_load(f)
@@ -77,16 +80,16 @@ def setup_trainer(self):
         )

         # Define checkpointing callback for trainer
-        config = self.config.get("checkpoint_saving")
-        if config:
+        config_ckpt = self.config.get("checkpoint_saving")
> Reviewer comment (Collaborator): I like a different name 😁 ✨
+        if config_ckpt:
             checkpoint_callback = ModelCheckpoint(
                 filename="checkpoint-{epoch}",
-                every_n_epochs=config["every_n_epochs"],
-                save_top_k=config["keep_last_n_ckpts"],
+                every_n_epochs=config_ckpt["every_n_epochs"],
+                save_top_k=config_ckpt["keep_last_n_ckpts"],
                 monitor="epoch",  # monitor the metric "epoch" for selecting which checkpoints to save
                 mode="max",  # get the max of the monitored metric
-                save_last=config["save_last"],
-                save_weights_only=config["save_weights_only"],
+                save_last=config_ckpt["save_last"],
+                save_weights_only=config_ckpt["save_weights_only"],
             )
             enable_checkpointing = True
         else:
@@ -161,12 +164,58 @@ def core_training(self) -> lightning.Trainer:
             self.seed_n,
         )

+        # Get checkpoint type
+        if self.checkpoint_path and os.path.exists(self.checkpoint_path):
+            checkpoint = torch.load(self.checkpoint_path)
+            if all(
+                [
+                    param in checkpoint
+                    for param in ["optimizer_states", "lr_schedulers"]
+                ]
+            ):
+                checkpoint_type = "full"  # for resuming training
+                logging.info(
+                    f"Resuming training from checkpoint at: {self.checkpoint_path}"
+                )
+            else:
+                checkpoint_type = "weights"  # for fine tuning
+                logging.info(
+                    f"Fine-tuning training from checkpoint at: {self.checkpoint_path}"
+                )
+        else:
+            checkpoint_type = None
+
         # Get model
-        lightning_model = FasterRCNN(self.config)
+        if checkpoint_type == "weights":
+            # Note: weights-only checkpoint contains hyperparameters
+            # see https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
+            lightning_model = FasterRCNN.load_from_checkpoint(
+                self.checkpoint_path,
+                config=self.config,
+                # overwrite checkpoint hyperparameters with config ones
+                # otherwise ckpt hyperparameters are logged to MLflow, but yaml hyperparameters are used
+            )
+        else:
+            lightning_model = FasterRCNN(self.config)

-        # Run training
+        # Get trainer
         trainer = self.setup_trainer()
-        trainer.fit(lightning_model, data_module)
+
+        # Run training
+        # Resume from full checkpoint if available
+        # (automatically restores model, epoch, step, LR schedulers, etc...)
+        # https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
+        if checkpoint_type == "full":
+            trainer.fit(
+                lightning_model,
+                data_module,
+                ckpt_path=self.checkpoint_path,  # needs to have been saved with `save_weights_only=False`
+            )
+        else:  # for "weights" or no checkpoint
+            trainer.fit(
+                lightning_model,
+                data_module,
+            )

         return trainer

@@ -281,6 +330,12 @@ def train_parse_args(args):
default="./ml-runs",
help=("Path to MLflow directory. Default: ./ml-runs"),
)
parser.add_argument(
"--checkpoint_path",
type=str,
default=None,
help=("Path to checkpoint for resume training"),
)
parser.add_argument(
"--optuna",
action="store_true",
25 changes: 22 additions & 3 deletions guides/TrainingModelsHPC.md
@@ -92,7 +92,26 @@
>
> If we launch a job and then modify the config file _before_ the job has been able to read it, we may be using an undesired version of the config in our job! To avoid this, it is best to wait until you can verify in MLflow that the job has the expected config parameters (and then edit the file to launch a new job if needed).

-6. **Optional argument - Optuna**
+6. **Restarting training from a checkpoint**

The `--checkpoint_path` argument lets you restart training from a previously saved checkpoint. There are two primary options:

- Resume training

- This option is useful for interrupted training sessions or extending training duration.
- If training is disrupted and stops mid-way, you can resume it by adding `--checkpoint_path $CKPT_PATH \` to your bash script.
- The training will pick up from the last saved epoch and continue until the specified `n_epoch`.
- Similarly, if training completes but you want to extend it based on metric evaluations, you can increase the `n_epoch` value (e.g., from `n` to `n + y`). If `n_epoch` is unchanged, no further training will run, as the maximum number of epochs has already been reached.
Again, use `--checkpoint_path $CKPT_PATH \` in your bash script, and training will resume from epoch `n` and continue to `n + y`.
- Ensure the `save_weights_only` parameter under `checkpoint_saving` in the config file is set to `False`, as resuming training requires loading both the weights and the training state (see the checkpoint-inspection sketch below).

- Fine-tuning
- This option is useful for fine-tuning a pre-trained model on a different dataset.
- It loads the weights from a checkpoint, allowing you to leverage pre-trained weights from another dataset.
- Add `--checkpoint_path $CKPT_PATH \` to your bash script to use this option.
- Set the `save_weights_only` parameter under `checkpoint_saving` in the config file to `True`, as only the weights are needed for fine-tuning.
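
The two modes differ only in what the checkpoint file contains. As a minimal sketch (not part of the repository; the checkpoint path below is a placeholder), you can inspect a checkpoint with PyTorch to see which mode it supports; this mirrors the key check that `train_model.py` performs:

```python
import torch

# Placeholder path: point this at the checkpoint you want to inspect
ckpt_path = "path/to/last.ckpt"

# Load on CPU so no GPU is needed just to look inside the file
checkpoint = torch.load(ckpt_path, map_location="cpu")

# A "full" checkpoint (saved with save_weights_only=False) also stores the
# optimizer and LR-scheduler states, so it can be used to resume training.
if all(key in checkpoint for key in ("optimizer_states", "lr_schedulers")):
    print("Full checkpoint: suitable for resuming training.")
else:
    print("Weights-only checkpoint: suitable for fine-tuning only.")
```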

7. **Optional argument - Optuna**

We have the option to run [Optuna](https://optuna.org), a hyperparameter optimization framework that allows us to find the best hyperparameters for our model.
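
As a rough illustration of what Optuna does, here is a minimal, generic sketch; the objective function and the `learning_rate` search range are illustrative assumptions only, not the search space used by this repository:

```python
import optuna


def objective(trial):
    # Illustrative only: suggest a learning rate and score it with a dummy
    # function. In a real run, the detector would be trained here and a
    # validation metric returned instead of this placeholder value.
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    return -abs(learning_rate - 1e-3)


study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
print(study.best_params)
```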

@@ -117,15 +136,15 @@
--optuna
```

-7. **Run the training job using the SLURM scheduler**
+8. **Run the training job using the SLURM scheduler**

To launch a job, use the `sbatch` command with the relevant training script:

```
sbatch <path-to-training-bash-script>
```

-8. **Check the status of the training job**
+9. **Check the status of the training job**

To do this, we can:

1 change: 1 addition & 0 deletions tests/test_unit/test_optuna.py
@@ -34,6 +34,7 @@ def args():
         mlflow_folder="/tmp/mlflow",
         fast_dev_run=True,
         limit_train_batches=False,
+        checkpoint_path=None,
     )

