Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix draw_detection to work with batch_size > 1 #232

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 58 additions & 32 deletions crabs/detector/utils/visualization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from datetime import datetime
from typing import Any, Optional
from typing import Optional

import cv2
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -66,59 +66,67 @@ def draw_bbox(
)


def draw_detection(
imgs: list,
annotations: dict,
detections: Optional[dict[Any, Any]] = None,
score_threshold: Optional[float] = None,
) -> np.ndarray:
def draw_bboxes_on_images(
list_input_images, #: list[dict],
list_annotations, #: list[dict],
list_detections=None, #: Optional[list[dict[Any, Any]]] = None,
score_threshold=None, #: Optional[float] = None,
) -> list[np.ndarray]:
"""
Draw the results based on the detection.
Parameters
----------
imgs : list
list_input_images : list
List of images.
annotations : dict
Ground truth annotations.
detections : dict, optional
Detected objects.
list_annotations : dict
List of ground truth annotations.
list_detections : dict, optional
List of detections
score_threshold : float, optional
The confidence threshold for detection scores.
The confidence threshold for the detection scores.
Returns
-------
np.ndarray
Image(s) with bounding boxes drawn on them.
list[np.ndarray]
List of images with bounding boxes drawn on them.
"""
coco_list = COCO_INSTANCE_CATEGORY_NAMES
image_with_boxes = None
# coco_list = COCO_INSTANCE_CATEGORY_NAMES

list_images_with_boxes = []
for image, label, prediction in zip(
imgs, annotations, detections or [None] * len(imgs)
list_input_images,
list_annotations,
list_detections or [None] * len(list_input_images),
):
# prepare image
image = image.cpu().numpy().transpose(1, 2, 0)
image = (image * 255).astype("uint8")
image_with_boxes = image.copy()

# prepare annotations
target_boxes = [
[(i[0], i[1]), (i[2], i[3])]
for i in list(label["boxes"].detach().cpu().numpy())
]

# plot annotations
for i in range(len(target_boxes)):
draw_bbox(
image_with_boxes,
((target_boxes[i][0])[0], (target_boxes[i][0])[1]),
((target_boxes[i][1])[0], (target_boxes[i][1])[1]),
colour=(0, 255, 0),
)

# plot predictions
if prediction:
pred_score = list(prediction["scores"].detach().cpu().numpy())
pred_t = [pred_score.index(x) for x in pred_score][-1]

pred_class = [
coco_list[i]
COCO_INSTANCE_CATEGORY_NAMES[i]
for i in list(prediction["labels"].detach().cpu().numpy())
]

Expand Down Expand Up @@ -147,7 +155,11 @@ def draw_detection(
(0, 0, 255),
label_text,
)
return image_with_boxes

# append result to list
list_images_with_boxes.append(image_with_boxes)

return list_images_with_boxes


def save_images_with_boxes(
Expand Down Expand Up @@ -183,22 +195,36 @@ def save_images_with_boxes(

if not output_dir:
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_dir = f"results_{timestamp}"
output_dir = f"results_{timestamp}" # _score_th_{score_threshold}"
os.makedirs(output_dir, exist_ok=True)

with torch.no_grad():
imgs_id = 0
for imgs, annotations in test_dataloader:
imgs_id += 1
imgs = list(img.to(device) for img in imgs)

detections = trained_model(imgs)

image_with_boxes = draw_detection(
imgs, annotations, detections, score_threshold
for img_batch, annotation_batch in test_dataloader:
# pass images to device
img_batch_device = list(img.to(device) for img in img_batch)

# compute detections
detection_batch = trained_model(img_batch_device)
# breakpoint()

# draw detections and annotations on images
img_with_boxes_batch = draw_bboxes_on_images(
list_input_images=img_batch_device, # tuple?
list_annotations=annotation_batch, # tuple?
list_detections=detection_batch,
score_threshold=score_threshold,
)

cv2.imwrite(f"{output_dir}/imgs{imgs_id}.jpg", image_with_boxes)
# breakpoint()

# draw detections and bboxesc
for image_with_boxes, annotation in zip(
img_with_boxes_batch, annotation_batch
):
# save one image
cv2.imwrite(
f"{output_dir}/img_{annotation['image_id']}.jpg",
image_with_boxes,
)


def plot_sample(imgs: list, row_title: Optional[str] = None, **imshow_kwargs):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_unit/test_visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from crabs.detector.utils.visualization import (
draw_bbox,
draw_detection,
draw_bboxes_on_images,
save_images_with_boxes,
)

Expand Down Expand Up @@ -146,7 +146,7 @@ def test_draw_bbox_green(sample_image, top_left, bottom_right, color):
)
def test_draw_detection(annotations, detections):
imgs = [torch.rand(3, 100, 100)]
image_with_boxes = draw_detection(imgs, annotations, detections)
image_with_boxes = draw_bboxes_on_images(imgs, annotations, detections)
assert image_with_boxes is not None


Expand Down