Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small edits to training #235

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions crabs/detector/train_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Train FasterRCNN model for object detection."""

import argparse
import logging
import os
import sys
from pathlib import Path
Expand All @@ -27,7 +28,7 @@
)


class DectectorTrain:
class DetectorTrain:
"""Training class for detector algorithm.

Parameters
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(self, args: argparse.Namespace):

# MLflow
self.experiment_name = args.experiment_name
self.run_name = set_mlflow_run_name()
self.mlflow_folder = args.mlflow_folder

# Debugging
Expand All @@ -65,15 +67,28 @@ def __init__(self, args: argparse.Namespace):
# Restart from checkpoint
self.checkpoint_path = args.checkpoint_path

# Log dataset and MLflow details to screen
# log_job_metadata_to_screen(self)
logging.info("Dataset")
logging.info(f"Images directories: {self.images_dirs}")
logging.info(f"Annotation files: {self.annotation_files}")
logging.info(f"Seed: {self.seed_n}")
logging.info("---------------------------------")

# Log MLflow information to screen
logging.info("MLflow logs for current job")
logging.info(f"Experiment name: {self.experiment_name}")
logging.info(f"Run name: {self.run_name}")
logging.info(f"Folder: {Path(self.mlflow_folder).resolve()}")
logging.info("---------------------------------")

def load_config_yaml(self):
"""Load yaml file that contains config parameters."""
with open(self.config_file) as f:
self.config = yaml.safe_load(f)

def setup_trainer(self):
"""Set up trainer with logging and checkpointing."""
self.run_name = set_mlflow_run_name()

# Setup logger with checkpointing
mlf_logger = setup_mlflow_logger(
experiment_name=self.experiment_name,
Expand Down Expand Up @@ -248,7 +263,7 @@ def main(args) -> None:
None

"""
trainer = DectectorTrain(args)
trainer = DetectorTrain(args)
trainer.train_model()


Expand Down
6 changes: 3 additions & 3 deletions tests/test_unit/test_optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from crabs.detector.train_model import DectectorTrain
from crabs.detector.train_model import DetectorTrain
from crabs.detector.utils.hpo import compute_optimal_hyperparameters


Expand Down Expand Up @@ -37,8 +37,8 @@ def args():

@pytest.fixture
def detector_train(args, config):
with patch.object(DectectorTrain, "load_config_yaml", MagicMock()):
train_instance = DectectorTrain(args=args)
with patch.object(DetectorTrain, "load_config_yaml", MagicMock()):
train_instance = DetectorTrain(args=args)
print(config)
train_instance.config = config
return train_instance
Expand Down
12 changes: 6 additions & 6 deletions tests/test_unit/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
)
def test_prep_img_directories(dataset_dirs: list):
"""Test parsing of image directories when training a model."""
from crabs.detector.train_model import DectectorTrain
from crabs.detector.train_model import DetectorTrain

# prepare parser
train_args = train_parse_args(["--dataset_dirs"] + dataset_dirs)

# instantiate detector
detector = DectectorTrain(train_args)
detector = DetectorTrain(train_args)

# check image directories are parsed correctly
list_imgs_dirs = [str(Path(d) / "frames") for d in dataset_dirs]
Expand All @@ -47,7 +47,7 @@ def test_prep_annotation_files_single_dataset(annotation_files, expected):
"""Test parsing of annotation files when training a model on a single
dataset.
"""
from crabs.detector.train_model import DectectorTrain
from crabs.detector.train_model import DetectorTrain

# prepare CLI arguments
cli_inputs = ["--dataset_dirs", DATASET_1]
Expand All @@ -59,7 +59,7 @@ def test_prep_annotation_files_single_dataset(annotation_files, expected):
train_args = train_parse_args(cli_inputs + annotation_files)

# instantiate detector
detector = DectectorTrain(train_args)
detector = DetectorTrain(train_args)

# check annotation files are as expected
assert detector.annotation_files == expected
Expand Down Expand Up @@ -89,7 +89,7 @@ def test_prep_annotation_files_multiple_datasets(annotation_files, expected):
"""Test parsing of annotation files when training
a model on two datasets.
"""
from crabs.detector.train_model import DectectorTrain
from crabs.detector.train_model import DetectorTrain

# prepare CLI arguments considering multiple dataset
cli_inputs = ["--dataset_dirs", DATASET_1, DATASET_2]
Expand All @@ -101,7 +101,7 @@ def test_prep_annotation_files_multiple_datasets(annotation_files, expected):
train_args = train_parse_args(cli_inputs + annotation_files)

# instantiate detector
detector = DectectorTrain(train_args)
detector = DetectorTrain(train_args)

# check annotation files are as expected
assert detector.annotation_files == expected