diff --git a/hubconf.py b/hubconf.py index 98416a5b8563..ba150e0d0699 100644 --- a/hubconf.py +++ b/hubconf.py @@ -10,6 +10,7 @@ import torch +from models.common import NMS from models.yolo import Model from utils.google_utils import attempt_download @@ -35,6 +36,12 @@ def create(name, pretrained, channels, classes): state_dict = torch.load(ckpt, map_location=torch.device('cpu'))['model'].float().state_dict() # to FP32 state_dict = {k: v for k, v in state_dict.items() if model.state_dict()[k].shape == v.shape} # filter model.load_state_dict(state_dict, strict=False) # load + + m = NMS() + m.f = -1 # from + m.i = model.model[-1].i + 1 # index + model.model.add_module(name='%s' % m.i, module=m) # add NMS + model.eval() return model except Exception as e: diff --git a/models/common.py b/models/common.py index e8c07f4db657..314c31f91aac 100644 --- a/models/common.py +++ b/models/common.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn +from utils.general import non_max_suppression def autopad(k, p=None): # kernel, padding @@ -98,6 +99,19 @@ def forward(self, x): return torch.cat(x, self.d) +class NMS(nn.Module): + # Non-Maximum Suppression (NMS) module + conf = 0.3 # confidence threshold + iou = 0.6 # IoU threshold + classes = None # (optional list) filter by class + + def __init__(self, dimension=1): + super(NMS, self).__init__() + + def forward(self, x): + return non_max_suppression(x[0], conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) + + class Flatten(nn.Module): # Use after nn.AdaptiveAvgPool2d(1) to remove last 2 dimensions @staticmethod