Skip to content

Commit

Permalink
MOTA revisited (#181)
Browse files Browse the repository at this point in the history
* modify id switches

* change the mota and test

* change the variable

* some bug fixed, load from checkpoint

* change list type, add gt_ids

* fixed the error id switches

* changes id switches

* fixing some test and type hint

* fixing test, parametrize the test with additional test

* cleane dup

* checking some test

* cleaned up

* test works

* test works

* aded specific example

* some more test

* Update crabs/tracker/evaluate_tracker.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

* Update crabs/tracker/evaluate_tracker.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

* combine gt functions, fix test

* rename test

* cleaned up linting

* adding some more description

* change the nested folder structure for output

* adding device to cli

* attempt yesterday

* small changes in docstring

* Update crabs/tracker/utils/io.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

* Update crabs/tracker/evaluate_tracker.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

* changes for gt dict

* predicted as dict

* rename varibale, fix test

* reviewing id switch

* commented out the test that fail

* commented out the test that fail

* seems working

* small modification for the test

* cleaned up

* cleaned up

* Update crabs/tracker/track_video.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

* Update crabs/tracker/track_video.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

* Update crabs/tracker/evaluate_tracker.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

* Update crabs/tracker/track_video.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

* Update crabs/tracker/evaluate_tracker.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

* Update crabs/tracker/evaluate_tracker.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

* fixed frame_number vs frame_idx

* Update crabs/tracker/evaluate_tracker.py

Co-authored-by: sfmig <[email protected]>
Signed-off-by: nikk-nikaznan <[email protected]>

---------

Signed-off-by: nikk-nikaznan <[email protected]>
Co-authored-by: sfmig <[email protected]>
  • Loading branch information
nikk-nikaznan and sfmig authored Jul 9, 2024
1 parent ff59d84 commit 99ed754
Show file tree
Hide file tree
Showing 10 changed files with 780 additions and 433 deletions.
2 changes: 1 addition & 1 deletion crabs/detector/config/faster_rcnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ num_classes: 2
# -------------------------------
# Training & validation parameters
# -------------------------------
n_epochs: 250
n_epochs: 1
learning_rate: 0.00005
wdecay: 0.00005
batch_size_train: 4
Expand Down
349 changes: 208 additions & 141 deletions crabs/tracker/evaluate_tracker.py

Large diffs are not rendered by default.

98 changes: 49 additions & 49 deletions crabs/tracker/track_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
)
from crabs.tracker.utils.tracking import prep_sort

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Tracking:
"""
Expand All @@ -47,58 +45,49 @@ class Tracking:

def __init__(self, args: argparse.Namespace) -> None:
self.args = args

self.config_file = args.config_file
self.load_config_yaml() # TODO: load config from trained model (like in evaluation)?

self.video_path = args.video_path
self.video_file_root = f"{Path(self.video_path).stem}"
self.trained_model_path = self.args.trained_model_path
self.device = self.args.device

self.trained_model = self.load_trained_model()
self.setup()
self.prep_outputs()

self.sort_tracker = Sort(
max_age=self.config["max_age"],
min_hits=self.config["min_hits"],
iou_threshold=self.config["iou_threshold"],
)

(
self.csv_writer,
self.csv_file,
self.tracking_output_dir,
) = prep_csv_writer(self.args.output_dir, self.video_file_root)

def load_config_yaml(self):
def setup(self):
"""
Load yaml file that contains config parameters.
Load tracking config, trained model and input video path.
"""
with open(self.config_file, "r") as f:
self.config = yaml.safe_load(f)

def load_trained_model(self) -> torch.nn.Module:
"""
Load the trained model.
Returns
-------
torch.nn.Module
"""
# Get trained model
trained_model = FasterRCNN.load_from_checkpoint(
self.trained_model = FasterRCNN.load_from_checkpoint(
self.trained_model_path
)
trained_model.eval()
trained_model.to(DEVICE) # Should device be a CLI?
return trained_model
self.trained_model.eval()
self.trained_model.to(self.device)

def load_video(self) -> None:
"""
Load the input video, and prepare the output video if required.
"""
# Load the input video
self.video = cv2.VideoCapture(self.video_path)
if not self.video.isOpened():
raise Exception("Error opening video file")
self.video_file_root = f"{Path(self.video_path).stem}"

def prep_outputs(self):
"""
Prepare csv writer and if required, video writer.
"""
(
self.csv_writer,
self.csv_file,
self.tracking_output_dir,
) = prep_csv_writer(self.args.output_dir, self.video_file_root)

if self.args.save_video:
frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH))
Expand All @@ -107,7 +96,6 @@ def load_video(self) -> None:

self.video_output = prep_video_writer(
self.tracking_output_dir,
self.video_file_root,
frame_width,
frame_height,
cap_fps,
Expand Down Expand Up @@ -135,7 +123,7 @@ def get_prediction(self, frame: np.ndarray) -> torch.Tensor:
transforms.ToDtype(torch.float32, scale=True),
]
)
img = transform(frame).to(DEVICE)
img = transform(frame).to(self.device)
img = img.unsqueeze(0)
with torch.no_grad():
prediction = self.trained_model(img)
Expand All @@ -156,9 +144,9 @@ def update_tracking(self, prediction: dict) -> list[list[float]]:
list of tracked bounding boxes after updating the tracking system.
"""
pred_sort = prep_sort(prediction, self.config["score_threshold"])
tracked_boxes = self.sort_tracker.update(pred_sort)
self.tracked_list.append(tracked_boxes)
return tracked_boxes
tracked_boxes_id_per_frame = self.sort_tracker.update(pred_sort)
self.tracked_bbox_id.append(tracked_boxes_id_per_frame)
return tracked_boxes_id_per_frame

def run_tracking(self):
"""
Expand All @@ -171,52 +159,59 @@ def run_tracking(self):
)
return

# In any case run inference
# initialisation
frame_number = 1
self.tracked_list = []
frame_idx = 0
self.tracked_bbox_id = []

# Loop through frames of the video in batches
while self.video.isOpened():
# Break if beyond end frame (mostly for debugging)
if (
self.args.max_frames_to_read
and frame_number > self.args.max_frames_to_read
and frame_idx + 1 > self.args.max_frames_to_read
):
break

# get total n frames
total_frames = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT))

# read frame
ret, frame = self.video.read()
if not ret:
print("No frame read. Exiting...")
if not ret and (frame_idx == total_frames):
logging.info(f"All {total_frames} frames processed")
break
elif not ret:
logging.info(
f"Cannot read frame {frame_idx+1}/{total_frames}. Exiting..."
)
break

# predict bounding boxes
prediction = self.get_prediction(frame)
pred_scores = prediction[0]["scores"].detach().cpu().numpy()

# run tracking
tracked_boxes = self.update_tracking(prediction)
tracked_boxes_id_per_frame = self.update_tracking(prediction)
save_required_output(
self.video_file_root,
self.args.save_frames,
self.tracking_output_dir,
self.csv_writer,
self.args.save_video,
self.video_output,
tracked_boxes,
tracked_boxes_id_per_frame,
frame,
frame_number,
frame_idx + 1,
pred_scores,
)

# update frame number
frame_number += 1
frame_idx += 1

if self.args.gt_path:
evaluation = TrackerEvaluate(
self.args.gt_path,
self.tracked_list,
self.tracked_bbox_id,
self.config["iou_threshold"],
)
evaluation.run_evaluation()
Expand Down Expand Up @@ -247,7 +242,6 @@ def main(args) -> None:
"""

inference = Tracking(args)
inference.load_video()
inference.run_tracking()


Expand Down Expand Up @@ -277,7 +271,7 @@ def tracking_parse_args(args):
parser.add_argument(
"--output_dir",
type=str,
default="crabs_track_output",
default="tracking_output",
help="Directory to save the track output", # is this a csv or a video? (or both)
)
parser.add_argument(
Expand Down Expand Up @@ -305,6 +299,12 @@ def tracking_parse_args(args):
action="store_true",
help="Save frame to be used in correcting track labelling",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="device for pytorch either cpu or cuda",
)
return parser.parse_args(args)


Expand Down
10 changes: 5 additions & 5 deletions crabs/tracker/utils/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import csv
import os
from datetime import datetime
from pathlib import Path

import cv2
Expand Down Expand Up @@ -29,13 +30,13 @@ def prep_csv_writer(output_dir: str, video_file_root: str):
A tuple containing the CSV writer, the CSV file object, and the tracking output directory path.
"""

crabs_tracks_label_dir = Path(output_dir) / "crabs_tracks_label"
tracking_output_dir = crabs_tracks_label_dir / video_file_root
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
tracking_output_dir = Path(output_dir + f"_{timestamp}") / video_file_root
# Create the subdirectory for the specific video file root
tracking_output_dir.mkdir(parents=True, exist_ok=True)

csv_file = open(
f"{str(tracking_output_dir / video_file_root)}.csv",
f"{str(tracking_output_dir)}/predicted_tracks.csv",
"w",
)
csv_writer = csv.writer(csv_file)
Expand All @@ -59,7 +60,6 @@ def prep_csv_writer(output_dir: str, video_file_root: str):

def prep_video_writer(
output_dir: str,
video_file_root: str,
frame_width: int,
frame_height: int,
cap_fps: float,
Expand Down Expand Up @@ -87,7 +87,7 @@ def prep_video_writer(
"""
output_file = os.path.join(
output_dir,
f"{os.path.basename(video_file_root)}_output_video.mp4",
"tracked_video.mp4",
)
output_codec = cv2.VideoWriter_fourcc("m", "p", "4", "v")
video_output = cv2.VideoWriter(
Expand Down
2 changes: 1 addition & 1 deletion crabs/tracker/utils/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def extract_bounding_box_info(row: list[str]) -> Dict[str, Any]:
height = region_shape_attributes["height"]
track_id = region_attributes["track"]

frame_number = int(filename.split("_")[-1].split(".")[0]) - 1
frame_number = int(filename.split("_")[-1].split(".")[0])
return {
"frame_number": frame_number,
"x": x,
Expand Down
6 changes: 3 additions & 3 deletions tests/data/gt_test.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
filename,file_size,file_attributes,region_count,region_id,region_shape_attributes,region_attributes
frame_00000001.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":2894.860594987354,""y"":975.8516839863181,""width"":51,""height"":41}","{""track"":""2.0""}"
frame_00000001.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":940.6088870891139,""y"":1192.6369631796642,""width"":49,""height"":38}","{""track"":""1.0""}"
frame_00000002.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":940.6088870891139,""y"":1192.6369631796642,""width"":49,""height"":38}","{""track"":""2.0""}"
frame_00000011.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":2894.860594987354,""y"":975.8516839863181,""width"":51,""height"":41}","{""track"":""2.0""}"
frame_00000011.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":940.6088870891139,""y"":1192.6369631796642,""width"":49,""height"":38}","{""track"":""1.0""}"
frame_00000021.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":940.6088870891139,""y"":1192.6369631796642,""width"":49,""height"":38}","{""track"":""2.0""}"
Loading

0 comments on commit 99ed754

Please sign in to comment.