From 00d6490bab4504df9975d49f159ab755284a5b9a Mon Sep 17 00:00:00 2001 From: ManoleAlexandru99 Date: Tue, 11 Apr 2023 14:24:12 +0300 Subject: [PATCH] Added Dropout #0012 lead to overfitting. We attempt to address this in the added segmentation path --- models/common.py | 7 +++++++ train.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/models/common.py b/models/common.py index 5637f353480f..ecc0dc3a6156 100644 --- a/models/common.py +++ b/models/common.py @@ -856,16 +856,23 @@ def __init__(self, in_channels): self.cv2 = Conv(32, 64, k=3) self.cv3 = Conv(64, 1, act=False) self.relu = nn.ReLU() + + self.dropout_weak= nn.Dropout(0.25) + self.dropout_normal = nn.Dropout(0.5) # self.sigmoid = nn.Sigmoid() def forward(self, x): # print('----entry shape', x.shape, '---\n') + x = self.dropout_weak(x) x = self.cv1(x) x = self.upsample(x) # x = self.relu(x) # print('----upsample shape', x.shape, '---\n') + x = self.dropout_normal(x) x = self.cv2(x) x = self.upsample(x) + + x = self.dropout_normal(x) # x = self.relu(x) x = self.cv3(x) # print('----out shape', x.shape, '---\n') diff --git a/train.py b/train.py index e725543e70ec..f43f9166b86f 100644 --- a/train.py +++ b/train.py @@ -536,7 +536,7 @@ def parse_opt(known=False): def main(opt, callbacks=Callbacks()): - print('\n---------- VERSION:', '#0012', '----------\n') + print('\n---------- VERSION:', '#0013', '----------\n') # Checks if RANK in {-1, 0}: print_args(vars(opt))