Skip to content

Commit

Permalink
Merge branch 'main' into smg/data-augm
Browse files Browse the repository at this point in the history
  • Loading branch information
sfmig authored Jun 27, 2024
2 parents 90da5cb + e18357e commit 1546145
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion crabs/detection_tracking/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,23 @@ def core_training(self) -> lightning.Trainer:
checkpoint_type = None

# Get model
lightning_model = FasterRCNN(self.config, optuna_log=self.args.optuna)
if checkpoint_type == "weights":
# Note: weights-only checkpoint contains hyperparameters
# see https://lightning.ai/docs/pytorch/stable/common/checkpointing_basic.html#save-hyperparameters
lightning_model = FasterRCNN.load_from_checkpoint(
self.checkpoint_path,
config=self.config,
optuna_log=self.args.optuna,
# overwrite checkpoint hyperparameters with config ones
# otherwise ckpt hyperparameters are logged to MLflow, but yaml hyperparameters are used
)
else:
lightning_model = FasterRCNN(
self.config, optuna_log=self.args.optuna
)

# Get trainer
trainer = self.setup_trainer()

# Get trainer
trainer = self.setup_trainer()
Expand Down

0 comments on commit 1546145

Please sign in to comment.