Skip to content

Commit

Permalink
Small refactoring to detector evaluation (#240)
Browse files Browse the repository at this point in the history
* Small additions

* Update readme
  • Loading branch information
sfmig authored Nov 8, 2024
1 parent 420e0f7 commit 9b4722e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 6 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ If a file with ground-truth annotations is passed to the command (with the `--an

To see the full list of possible arguments to the `evaluate-detector` command, run it with the `--help` flag.


## Task-specific guides
For further information on specific tasks, such as launching a training job or evaluating a set of models in the HPC cluster, please see [our guides](guides).

<!-- ### Evaluate the tracking performance
Expand Down
11 changes: 7 additions & 4 deletions crabs/detector/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ def __init__(self, args: argparse.Namespace) -> None:
)
self.evaluation_split = "test" if self.args.use_test_set else "val"

# output directory for frames
self.frames_output_dir_root = str(
Path(self.args.frames_output_dir)
/ f"evaluation_output_{self.evaluation_split}"
)

# Hardware
self.accelerator = args.accelerator

Expand Down Expand Up @@ -172,10 +178,7 @@ def evaluate_model(self) -> None:
save_images_with_boxes(
dataloader=eval_dataloader,
trained_model=trained_model,
output_dir=str(
Path(self.args.frames_output_dir)
/ f"evaluation_output_{self.evaluation_split}"
),
output_dir=self.frames_output_dir_root,
score_threshold=self.args.frames_score_threshold,
)

Expand Down
5 changes: 4 additions & 1 deletion crabs/detector/utils/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import ast
import sys
from pathlib import Path
from typing import Optional

import torchvision
import yaml # type: ignore
Expand Down Expand Up @@ -145,7 +146,9 @@ def get_mlflow_parameters_from_ckpt(trained_model_path: str) -> dict:
return params


def get_config_from_ckpt(config_file: str, trained_model_path: str) -> dict:
def get_config_from_ckpt(
config_file: Optional[str], trained_model_path: str
) -> dict:
"""Get config from checkpoint if config is not passed as a CLI argument."""
# If config in CLI arguments: used passed config
if config_file:
Expand Down

0 comments on commit 9b4722e

Please sign in to comment.