Adding checkpoint_path for resume training #182

Merged (25 commits) on Jun 26, 2024

Commits
ecc750d
adding ckpt_path to fit to resume training
nikk-nikaznan Jun 7, 2024
68c0053
option to resume or fine tunning
nikk-nikaznan Jun 11, 2024
d412475
small changes
nikk-nikaznan Jun 12, 2024
839df35
add checkpoint option in the guide
nikk-nikaznan Jun 12, 2024
c9b64a3
cleaned up guide
nikk-nikaznan Jun 12, 2024
7938a6b
Merge branch 'main' into nikkna/resume_training
nikk-nikaznan Jun 18, 2024
562c20e
Merge branch 'main' into nikkna/resume_training
nikk-nikaznan Jun 20, 2024
05856d2
cleaned up
nikk-nikaznan Jun 20, 2024
933aeef
tring rename the ckpt
nikk-nikaznan Jun 21, 2024
154c1ec
Merge branch 'main' into nikkna/resume_training
nikk-nikaznan Jun 21, 2024
6db4b76
Merge branch 'nikkna/resume_training' of github.com:SainsburyWellcome…
nikk-nikaznan Jun 21, 2024
1d37237
cleaned up after rebase
nikk-nikaznan Jun 21, 2024
a553a20
some changes in the guide
nikk-nikaznan Jun 21, 2024
7210436
run pre-commit
nikk-nikaznan Jun 21, 2024
466265c
fixed test
nikk-nikaznan Jun 21, 2024
0666762
Merge branch 'main' into nikkna/resume_training
nikk-nikaznan Jun 24, 2024
6aa414f
parsing the config to model instance during fine-tunning
nikk-nikaznan Jun 25, 2024
ba5fb78
Merge branch 'nikkna/resume_training' of github.com:SainsburyWellcome…
nikk-nikaznan Jun 25, 2024
d717661
small changes on guide
nikk-nikaznan Jun 26, 2024
7f71743
changes based on the review
nikk-nikaznan Jun 26, 2024
5df0361
small changes
nikk-nikaznan Jun 26, 2024
35e08e5
Update crabs/detection_tracking/train_model.py
nikk-nikaznan Jun 26, 2024
6c2d99e
Update crabs/detection_tracking/train_model.py
nikk-nikaznan Jun 26, 2024
36b5cbe
Update crabs/detection_tracking/train_model.py
nikk-nikaznan Jun 26, 2024
7f425cc
cleaned up pre-commit
nikk-nikaznan Jun 26, 2024
70 changes: 59 additions & 11 deletions crabs/detection_tracking/train_model.py
@@ -1,4 +1,5 @@
import argparse
import logging
import os
import sys
from pathlib import Path
@@ -56,6 +57,8 @@ def __init__(self, args):
self.fast_dev_run = args.fast_dev_run
self.limit_train_batches = args.limit_train_batches

self.checkpoint_path = args.checkpoint_path

def load_config_yaml(self):
with open(self.config_file, "r") as f:
self.config = yaml.safe_load(f)
@@ -94,14 +97,17 @@ def setup_trainer(self):
enable_checkpointing = False

# Return trainer linked to callbacks and logger
return lightning.Trainer(
max_epochs=self.config["n_epochs"],
accelerator=self.accelerator,
logger=mlf_logger,
enable_checkpointing=enable_checkpointing,
callbacks=checkpoint_callback,
fast_dev_run=self.fast_dev_run,
limit_train_batches=self.limit_train_batches,
return (
lightning.Trainer(
max_epochs=self.config["n_epochs"],
accelerator=self.accelerator,
logger=mlf_logger,
enable_checkpointing=enable_checkpointing,
callbacks=[checkpoint_callback] if checkpoint_callback else [],
fast_dev_run=self.fast_dev_run,
limit_train_batches=self.limit_train_batches,
),
checkpoint_callback,
)
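For orientation, here is a minimal, hedged sketch (illustrative names, not the repo's code) of the pattern this hunk introduces: `setup_trainer` now returns the `Trainer` together with its `ModelCheckpoint` callback, so the caller can later read attributes such as `last_model_path` without digging through `trainer.callbacks`.

```python
import lightning
from lightning.pytorch.callbacks import ModelCheckpoint


def setup_trainer_sketch(enable_checkpointing: bool = True):
    """Illustrative only: build a Trainer and return it with its checkpoint callback."""
    checkpoint_callback = (
        ModelCheckpoint(save_last=True) if enable_checkpointing else None
    )
    trainer = lightning.Trainer(
        max_epochs=1,
        enable_checkpointing=enable_checkpointing,
        # Pass the callback in a list, or an empty list when checkpointing is off.
        callbacks=[checkpoint_callback] if checkpoint_callback else [],
    )
    return trainer, checkpoint_callback
```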

def optuna_objective_fn(self, trial: optuna.Trial) -> float:
@@ -163,10 +169,46 @@ def core_training(self) -> lightning.Trainer:

# Get model
lightning_model = FasterRCNN(self.config)

# Run training
trainer = self.setup_trainer()
trainer.fit(lightning_model, data_module)
trainer, checkpoint_callback = self.setup_trainer()

if self.checkpoint_path and os.path.exists(self.checkpoint_path):
logging.info(
f"Checking contents of checkpoint: {self.checkpoint_path}"
)
checkpoint = torch.load(self.checkpoint_path)

if (
"optimizer_states" in checkpoint
and "lr_schedulers" in checkpoint
):
logging.info(
"Checkpoint contains full training state. Resume training"
)
# Resume training from the checkpoint
trainer.fit(
lightning_model,
data_module,
ckpt_path=self.checkpoint_path,
)
else:
logging.info(
"Checkpoint contains only model weights. Load the weight from a trained model"
)
# Load model weights and start fine-tuning
model = FasterRCNN.load_from_checkpoint(self.checkpoint_path)
trainer.fit(model, data_module)

else:
trainer.fit(
lightning_model,
data_module,
)

if checkpoint_callback and checkpoint_callback.last_model_path:
logging.info(
f"Last checkpoint path: {checkpoint_callback.last_model_path}"
)

return trainer
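A hedged, standalone sketch of the decision made above: inspect the checkpoint file to see whether it carries the full training state (resume) or only model weights (fine-tune). The checkpoint path is a placeholder.

```python
import torch


def has_full_training_state(ckpt_path: str) -> bool:
    """Return True if the checkpoint also stores optimizer and LR-scheduler state."""
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    return "optimizer_states" in checkpoint and "lr_schedulers" in checkpoint


# If this returns True, pass ckpt_path to trainer.fit() to resume training;
# otherwise load the weights (e.g. via load_from_checkpoint) and fine-tune.
```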

@@ -281,6 +323,12 @@ def train_parse_args(args):
default="./ml-runs",
help=("Path to MLflow directory. Default: ./ml-runs"),
)
parser.add_argument(
"--checkpoint_path",
type=str,
default=None,
help=("Path to checkpoint for resume training"),
)
parser.add_argument(
"--optuna",
action="store_true",
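As a standalone illustration (not the repo's full parser), the new flag behaves like any optional argparse argument and defaults to `None` when omitted:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--checkpoint_path",
    type=str,
    default=None,
    help="Path to checkpoint for resuming training",
)

print(parser.parse_args(["--checkpoint_path", "/tmp/last.ckpt"]).checkpoint_path)  # /tmp/last.ckpt
print(parser.parse_args([]).checkpoint_path)  # None
```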
24 changes: 21 additions & 3 deletions guides/TrainingModelsHPC.md
@@ -92,7 +92,25 @@
>
> If we launch a job and then modify the config file _before_ the job has been able to read it, we may be using an undesired version of the config in our job! To avoid this, it is best to wait until you can verify in MLflow that the job has the expected config parameters (and then edit the file to launch a new job if needed).

6. **Optional argument - Optuna**
6. **Checkpoint**

The `--checkpoint_path` argument lets you start a training job from an existing checkpoint. There are two primary use cases:

- Resume training

  - This option is useful for resuming interrupted training sessions or for extending the training duration.
  - If training is disrupted and stops mid-way, you can resume it by adding `--checkpoint_path $CKPT_PATH \` to your bash script.
  - Training will pick up from the last saved epoch and continue until the specified `n_epochs`.
  - Similarly, if training completes but you want to extend it based on metric evaluations, you can increase the `n_epochs` value (e.g., from `n` to `n + y`). Again, add `--checkpoint_path $CKPT_PATH \` to your bash script, and training will resume from epoch `n` and continue to `n + y`.
  - Ensure the `save_weights_only` parameter under `checkpoint_saving` in the config file is set to `False`, as resuming requires loading both the weights and the training state.

- Fine-tuning
  - This option is useful for fine-tuning a pre-trained model on a different dataset.
  - It loads only the weights from a checkpoint, allowing you to leverage pre-trained weights from another dataset.
  - Add `--checkpoint_path $CKPT_PATH \` to your bash script to use this option.
  - Set the `save_weights_only` parameter under `checkpoint_saving` in the config file to `True`, as only the weights are needed for fine-tuning (see the sketch after this list).
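The two settings map onto Lightning's `ModelCheckpoint` callback; a hedged sketch (the exact wiring of `checkpoint_saving` in this repo's config is assumed, not confirmed):

```python
from lightning.pytorch.callbacks import ModelCheckpoint

# save_weights_only=False keeps the full training state (weights, optimizer,
# LR schedulers), which resuming with ckpt_path requires.
resume_checkpointing = ModelCheckpoint(save_last=True, save_weights_only=False)

# save_weights_only=True stores only the model weights, which is enough for
# fine-tuning and produces smaller checkpoint files.
finetune_checkpointing = ModelCheckpoint(save_last=True, save_weights_only=True)
```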

7. **Optional argument - Optuna**

We have the option to run [Optuna](https://optuna.org), a hyperparameter optimization framework that allows us to find the best hyperparameters for our model; a generic sketch is shown below the command example.

@@ -117,15 +135,15 @@
--optuna
```
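For orientation, a generic Optuna sketch (this is not the repo's `optuna_objective_fn`; `train_and_evaluate` is a hypothetical helper):

```python
import optuna


def objective(trial: optuna.Trial) -> float:
    # Suggest a hyperparameter and return the metric Optuna should maximise.
    learning_rate = trial.suggest_float("learning_rate", 1e-5, 1e-2, log=True)
    return train_and_evaluate(learning_rate)  # hypothetical training/eval helper


study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)
print(study.best_params)
```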

7. **Run the training job using the SLURM scheduler**
8. **Run the training job using the SLURM scheduler**

To launch a job, use the `sbatch` command with the relevant training script:

```
sbatch <path-to-training-bash-script>
```

8. **Check the status of the training job**
9. **Check the status of the training job**

To do this, we can:

1 change: 1 addition & 0 deletions tests/test_unit/test_optuna.py
@@ -34,6 +34,7 @@ def args():
mlflow_folder="/tmp/mlflow",
fast_dev_run=True,
limit_train_batches=False,
checkpoint_path=None,
)

