Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MOTA revisited #181

Merged
merged 57 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from 56 commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
73ea420
modify id switches
nikk-nikaznan Jun 5, 2024
768b81e
change the mota and test
nikk-nikaznan Jun 5, 2024
3036b06
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 5, 2024
f18a44c
change the variable
nikk-nikaznan Jun 14, 2024
6c0b0ed
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 18, 2024
a851e9f
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 21, 2024
611bd60
some bug fixed, load from checkpoint
nikk-nikaznan Jun 24, 2024
11bded7
Merge branch 'nikkna/id_switches' of github.com:SainsburyWellcomeCent…
nikk-nikaznan Jun 24, 2024
4b9a794
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 24, 2024
cfe67ca
change list type, add gt_ids
nikk-nikaznan Jun 25, 2024
7de7561
fixed the error id switches
nikk-nikaznan Jun 25, 2024
25f36c7
changes id switches
nikk-nikaznan Jun 25, 2024
de0b9cd
fixing some test and type hint
nikk-nikaznan Jun 27, 2024
5e56e1a
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 27, 2024
6aa1b63
fixing test, parametrize the test with additional test
nikk-nikaznan Jun 28, 2024
05faa18
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jun 28, 2024
a592c1a
cleane dup
nikk-nikaznan Jun 28, 2024
eefcf33
checking some test
nikk-nikaznan Jun 28, 2024
f0bcf65
rebase
nikk-nikaznan Jun 28, 2024
da0c62b
cleaned up
nikk-nikaznan Jun 28, 2024
01cb0f4
test works
nikk-nikaznan Jun 28, 2024
3ca1872
test works
nikk-nikaznan Jun 28, 2024
4d32f77
aded specific example
nikk-nikaznan Jun 28, 2024
ff059ea
some more test
nikk-nikaznan Jun 28, 2024
1195854
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 3, 2024
67fe295
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 3, 2024
9bf3a73
combine gt functions, fix test
nikk-nikaznan Jul 3, 2024
93915e1
rename test
nikk-nikaznan Jul 3, 2024
e1a8537
cleaned up linting
nikk-nikaznan Jul 3, 2024
d6401e1
adding some more description
nikk-nikaznan Jul 3, 2024
64ae8e8
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jul 3, 2024
73ae064
change the nested folder structure for output
nikk-nikaznan Jul 3, 2024
163cc06
adding device to cli
nikk-nikaznan Jul 4, 2024
cc6bbcf
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jul 5, 2024
f042b2b
attempt yesterday
nikk-nikaznan Jul 5, 2024
56ff81d
small changes in docstring
nikk-nikaznan Jul 5, 2024
73607b4
Update crabs/tracker/utils/io.py
nikk-nikaznan Jul 5, 2024
f8c91a9
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 5, 2024
21930ac
changes for gt dict
nikk-nikaznan Jul 5, 2024
0e8a687
predicted as dict
nikk-nikaznan Jul 5, 2024
6e12530
rename varibale, fix test
nikk-nikaznan Jul 5, 2024
36757a5
reviewing id switch
nikk-nikaznan Jul 5, 2024
5b5e1de
commented out the test that fail
nikk-nikaznan Jul 5, 2024
e8f4446
commented out the test that fail
nikk-nikaznan Jul 5, 2024
a464d15
seems working
nikk-nikaznan Jul 5, 2024
98e77c3
small modification for the test
nikk-nikaznan Jul 5, 2024
41ab8cc
cleaned up
nikk-nikaznan Jul 5, 2024
491ae68
cleaned up
nikk-nikaznan Jul 8, 2024
43e3c86
Merge branch 'main' into nikkna/id_switches
nikk-nikaznan Jul 8, 2024
5e11991
Update crabs/tracker/track_video.py
nikk-nikaznan Jul 9, 2024
8b8b9c9
Update crabs/tracker/track_video.py
nikk-nikaznan Jul 9, 2024
cc1e8c6
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 9, 2024
fcf6e66
Update crabs/tracker/track_video.py
nikk-nikaznan Jul 9, 2024
7a2ba9f
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 9, 2024
fa7fa08
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 9, 2024
24935ae
fixed frame_number vs frame_idx
nikk-nikaznan Jul 9, 2024
1d45dd6
Update crabs/tracker/evaluate_tracker.py
nikk-nikaznan Jul 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion crabs/detector/config/faster_rcnn.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ num_classes: 2
# -------------------------------
# Training & validation parameters
# -------------------------------
n_epochs: 250
n_epochs: 1
learning_rate: 0.00005
wdecay: 0.00005
batch_size_train: 4
Expand Down
348 changes: 207 additions & 141 deletions crabs/tracker/evaluate_tracker.py

Large diffs are not rendered by default.

98 changes: 49 additions & 49 deletions crabs/tracker/track_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
)
from crabs.tracker.utils.tracking import prep_sort

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Tracking:
"""
Expand All @@ -47,58 +45,49 @@ class Tracking:

def __init__(self, args: argparse.Namespace) -> None:
self.args = args

self.config_file = args.config_file
self.load_config_yaml() # TODO: load config from trained model (like in evaluation)?

self.video_path = args.video_path
self.video_file_root = f"{Path(self.video_path).stem}"
self.trained_model_path = self.args.trained_model_path
self.device = self.args.device

self.trained_model = self.load_trained_model()
self.setup()
self.prep_outputs()

self.sort_tracker = Sort(
max_age=self.config["max_age"],
min_hits=self.config["min_hits"],
iou_threshold=self.config["iou_threshold"],
)

(
self.csv_writer,
self.csv_file,
self.tracking_output_dir,
) = prep_csv_writer(self.args.output_dir, self.video_file_root)

def load_config_yaml(self):
def setup(self):
"""
Load yaml file that contains config parameters.
Load tracking config, trained model and input video path.
"""
with open(self.config_file, "r") as f:
self.config = yaml.safe_load(f)

def load_trained_model(self) -> torch.nn.Module:
"""
Load the trained model.

Returns
-------
torch.nn.Module
"""
# Get trained model
trained_model = FasterRCNN.load_from_checkpoint(
self.trained_model = FasterRCNN.load_from_checkpoint(
self.trained_model_path
)
trained_model.eval()
trained_model.to(DEVICE) # Should device be a CLI?
return trained_model
self.trained_model.eval()
self.trained_model.to(self.device)

def load_video(self) -> None:
"""
Load the input video, and prepare the output video if required.
"""
# Load the input video
self.video = cv2.VideoCapture(self.video_path)
if not self.video.isOpened():
raise Exception("Error opening video file")
self.video_file_root = f"{Path(self.video_path).stem}"

def prep_outputs(self):
nikk-nikaznan marked this conversation as resolved.
Show resolved Hide resolved
"""
Prepare csv writer and if required, video writer.
"""
(
self.csv_writer,
self.csv_file,
self.tracking_output_dir,
) = prep_csv_writer(self.args.output_dir, self.video_file_root)

if self.args.save_video:
frame_width = int(self.video.get(cv2.CAP_PROP_FRAME_WIDTH))
Expand All @@ -107,7 +96,6 @@ def load_video(self) -> None:

self.video_output = prep_video_writer(
self.tracking_output_dir,
self.video_file_root,
frame_width,
frame_height,
cap_fps,
Expand Down Expand Up @@ -135,7 +123,7 @@ def get_prediction(self, frame: np.ndarray) -> torch.Tensor:
transforms.ToDtype(torch.float32, scale=True),
]
)
img = transform(frame).to(DEVICE)
img = transform(frame).to(self.device)
img = img.unsqueeze(0)
with torch.no_grad():
prediction = self.trained_model(img)
Expand All @@ -156,9 +144,9 @@ def update_tracking(self, prediction: dict) -> list[list[float]]:
list of tracked bounding boxes after updating the tracking system.
"""
pred_sort = prep_sort(prediction, self.config["score_threshold"])
tracked_boxes = self.sort_tracker.update(pred_sort)
self.tracked_list.append(tracked_boxes)
return tracked_boxes
tracked_boxes_id_per_frame = self.sort_tracker.update(pred_sort)
self.tracked_bbox_id.append(tracked_boxes_id_per_frame)
return tracked_boxes_id_per_frame

def run_tracking(self):
"""
Expand All @@ -171,52 +159,59 @@ def run_tracking(self):
)
return

# In any case run inference
# initialisation
frame_number = 1
self.tracked_list = []
frame_idx = 0
self.tracked_bbox_id = []

# Loop through frames of the video in batches
while self.video.isOpened():
# Break if beyond end frame (mostly for debugging)
if (
self.args.max_frames_to_read
and frame_number > self.args.max_frames_to_read
and frame_idx + 1 > self.args.max_frames_to_read
):
break
nikk-nikaznan marked this conversation as resolved.
Show resolved Hide resolved

# get total n frames
total_frames = int(self.video.get(cv2.CAP_PROP_FRAME_COUNT))

# read frame
ret, frame = self.video.read()
if not ret:
print("No frame read. Exiting...")
if not ret and (frame_idx == total_frames):
logging.info(f"All {total_frames} frames processed")
break
elif not ret:
logging.info(
f"Cannot read frame {frame_idx+1}/{total_frames}. Exiting..."
)
break

# predict bounding boxes
prediction = self.get_prediction(frame)
pred_scores = prediction[0]["scores"].detach().cpu().numpy()

# run tracking
tracked_boxes = self.update_tracking(prediction)
tracked_boxes_id_per_frame = self.update_tracking(prediction)
save_required_output(
self.video_file_root,
self.args.save_frames,
self.tracking_output_dir,
self.csv_writer,
self.args.save_video,
self.video_output,
tracked_boxes,
tracked_boxes_id_per_frame,
frame,
frame_number,
frame_idx + 1,
pred_scores,
)

# update frame number
frame_number += 1
frame_idx += 1
nikk-nikaznan marked this conversation as resolved.
Show resolved Hide resolved

if self.args.gt_path:
evaluation = TrackerEvaluate(
self.args.gt_path,
self.tracked_list,
self.tracked_bbox_id,
self.config["iou_threshold"],
)
evaluation.run_evaluation()
Expand Down Expand Up @@ -247,7 +242,6 @@ def main(args) -> None:
"""

inference = Tracking(args)
inference.load_video()
inference.run_tracking()


Expand Down Expand Up @@ -277,7 +271,7 @@ def tracking_parse_args(args):
parser.add_argument(
"--output_dir",
type=str,
default="crabs_track_output",
default="tracking_output",
help="Directory to save the track output", # is this a csv or a video? (or both)
)
parser.add_argument(
Expand Down Expand Up @@ -305,6 +299,12 @@ def tracking_parse_args(args):
action="store_true",
help="Save frame to be used in correcting track labelling",
)
parser.add_argument(
"--device",
type=str,
default="cuda",
help="device for pytorch either cpu or cuda",
)
return parser.parse_args(args)


Expand Down
10 changes: 5 additions & 5 deletions crabs/tracker/utils/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import csv
import os
from datetime import datetime
from pathlib import Path

import cv2
Expand Down Expand Up @@ -29,13 +30,13 @@ def prep_csv_writer(output_dir: str, video_file_root: str):
A tuple containing the CSV writer, the CSV file object, and the tracking output directory path.
"""

crabs_tracks_label_dir = Path(output_dir) / "crabs_tracks_label"
tracking_output_dir = crabs_tracks_label_dir / video_file_root
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
tracking_output_dir = Path(output_dir + f"_{timestamp}") / video_file_root
# Create the subdirectory for the specific video file root
tracking_output_dir.mkdir(parents=True, exist_ok=True)

csv_file = open(
f"{str(tracking_output_dir / video_file_root)}.csv",
f"{str(tracking_output_dir)}/predicted_tracks.csv",
"w",
)
csv_writer = csv.writer(csv_file)
Expand All @@ -59,7 +60,6 @@ def prep_csv_writer(output_dir: str, video_file_root: str):

def prep_video_writer(
output_dir: str,
video_file_root: str,
frame_width: int,
frame_height: int,
cap_fps: float,
Expand Down Expand Up @@ -87,7 +87,7 @@ def prep_video_writer(
"""
output_file = os.path.join(
output_dir,
f"{os.path.basename(video_file_root)}_output_video.mp4",
"tracked_video.mp4",
)
output_codec = cv2.VideoWriter_fourcc("m", "p", "4", "v")
video_output = cv2.VideoWriter(
Expand Down
2 changes: 1 addition & 1 deletion crabs/tracker/utils/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def extract_bounding_box_info(row: list[str]) -> Dict[str, Any]:
height = region_shape_attributes["height"]
track_id = region_attributes["track"]

frame_number = int(filename.split("_")[-1].split(".")[0]) - 1
frame_number = int(filename.split("_")[-1].split(".")[0])
return {
"frame_number": frame_number,
"x": x,
Expand Down
6 changes: 3 additions & 3 deletions tests/data/gt_test.csv
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
filename,file_size,file_attributes,region_count,region_id,region_shape_attributes,region_attributes
frame_00000001.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":2894.860594987354,""y"":975.8516839863181,""width"":51,""height"":41}","{""track"":""2.0""}"
frame_00000001.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":940.6088870891139,""y"":1192.6369631796642,""width"":49,""height"":38}","{""track"":""1.0""}"
frame_00000002.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":940.6088870891139,""y"":1192.6369631796642,""width"":49,""height"":38}","{""track"":""2.0""}"
frame_00000011.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":2894.860594987354,""y"":975.8516839863181,""width"":51,""height"":41}","{""track"":""2.0""}"
frame_00000011.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":940.6088870891139,""y"":1192.6369631796642,""width"":49,""height"":38}","{""track"":""1.0""}"
frame_00000021.png,26542080,"{""clip"":123}",1,0,"{""name"":""rect"",""x"":940.6088870891139,""y"":1192.6369631796642,""width"":49,""height"":38}","{""track"":""2.0""}"
Loading