Skip to content

Commit

Permalink
Fix draw_detection to work with batch_size > 1
Browse files Browse the repository at this point in the history
  • Loading branch information
sfmig committed Oct 28, 2024
1 parent 5d58d85 commit 83f61bb
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 34 deletions.
90 changes: 58 additions & 32 deletions crabs/detector/utils/visualization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from datetime import datetime
from typing import Any, Optional
from typing import Optional

import cv2
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -66,59 +66,67 @@ def draw_bbox(
)


def draw_detection(
imgs: list,
annotations: dict,
detections: Optional[dict[Any, Any]] = None,
score_threshold: Optional[float] = None,
) -> np.ndarray:
def draw_bboxes_on_images(
list_input_images, #: list[dict],
list_annotations, #: list[dict],
list_detections=None, #: Optional[list[dict[Any, Any]]] = None,
score_threshold=None, #: Optional[float] = None,
) -> list[np.ndarray]:
"""
Draw the results based on the detection.
Parameters
----------
imgs : list
list_input_images : list
List of images.
annotations : dict
Ground truth annotations.
detections : dict, optional
Detected objects.
list_annotations : dict
List of ground truth annotations.
list_detections : dict, optional
List of detections
score_threshold : float, optional
The confidence threshold for detection scores.
The confidence threshold for the detection scores.
Returns
-------
np.ndarray
Image(s) with bounding boxes drawn on them.
list[np.ndarray]
List of images with bounding boxes drawn on them.
"""
coco_list = COCO_INSTANCE_CATEGORY_NAMES
image_with_boxes = None
# coco_list = COCO_INSTANCE_CATEGORY_NAMES

list_images_with_boxes = []
for image, label, prediction in zip(
imgs, annotations, detections or [None] * len(imgs)
list_input_images,
list_annotations,
list_detections or [None] * len(list_input_images),
):
# prepare image
image = image.cpu().numpy().transpose(1, 2, 0)
image = (image * 255).astype("uint8")
image_with_boxes = image.copy()

# prepare annotations
target_boxes = [
[(i[0], i[1]), (i[2], i[3])]
for i in list(label["boxes"].detach().cpu().numpy())
]

# plot annotations
for i in range(len(target_boxes)):
draw_bbox(
image_with_boxes,
((target_boxes[i][0])[0], (target_boxes[i][0])[1]),
((target_boxes[i][1])[0], (target_boxes[i][1])[1]),
colour=(0, 255, 0),
)

# plot predictions
if prediction:
pred_score = list(prediction["scores"].detach().cpu().numpy())
pred_t = [pred_score.index(x) for x in pred_score][-1]

pred_class = [
coco_list[i]
COCO_INSTANCE_CATEGORY_NAMES[i]
for i in list(prediction["labels"].detach().cpu().numpy())
]

Expand Down Expand Up @@ -147,7 +155,11 @@ def draw_detection(
(0, 0, 255),
label_text,
)
return image_with_boxes

# append result to list
list_images_with_boxes.append(image_with_boxes)

return list_images_with_boxes


def save_images_with_boxes(
Expand Down Expand Up @@ -183,22 +195,36 @@ def save_images_with_boxes(

if not output_dir:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = f"results_{timestamp}"
output_dir = f"results_{timestamp}" # _score_th_{score_threshold}"
os.makedirs(output_dir, exist_ok=True)

with torch.no_grad():
imgs_id = 0
for imgs, annotations in test_dataloader:
imgs_id += 1
imgs = list(img.to(device) for img in imgs)

detections = trained_model(imgs)

image_with_boxes = draw_detection(
imgs, annotations, detections, score_threshold
for img_batch, annotation_batch in test_dataloader:
# pass images to device
img_batch_device = list(img.to(device) for img in img_batch)

# compute detections
detection_batch = trained_model(img_batch_device)
# breakpoint()

# draw detections and annotations on images
img_with_boxes_batch = draw_bboxes_on_images(
list_input_images=img_batch_device, # tuple?
list_annotations=annotation_batch, # tuple?
list_detections=detection_batch,
score_threshold=score_threshold,
)

cv2.imwrite(f"{output_dir}/imgs{imgs_id}.jpg", image_with_boxes)
# breakpoint()

# draw detections and bboxesc
for image_with_boxes, annotation in zip(
img_with_boxes_batch, annotation_batch
):
# save one image
cv2.imwrite(
f"{output_dir}/img_{annotation['image_id']}.jpg",
image_with_boxes,
)


def plot_sample(imgs: list, row_title: Optional[str] = None, **imshow_kwargs):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_unit/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from crabs.detector.utils.visualization import (
draw_bbox,
draw_detection,
draw_bboxes_on_images,
save_images_with_boxes,
)

Expand Down Expand Up @@ -146,7 +146,7 @@ def test_draw_bbox_green(sample_image, top_left, bottom_right, color):
)
def test_draw_detection(annotations, detections):
imgs = [torch.rand(3, 100, 100)]
image_with_boxes = draw_detection(imgs, annotations, detections)
image_with_boxes = draw_bboxes_on_images(imgs, annotations, detections)
assert image_with_boxes is not None


Expand Down

0 comments on commit 83f61bb

Please sign in to comment.