import torch import re import numpy as np import torch.nn as nn import torch.nn.functional as F import logging import cv2 import math import itertools import collections from torchvision.ops import nms GlobalParams = collections.namedtuple('GlobalParams', [ 'batch_norm_momentum', 'batch_norm_epsilon', 'dropout_rate', 'num_classes', 'width_coefficient', 'depth_coefficient', 'depth_divisor', 'min_depth', 'drop_connect_rate', 'image_size']) # Parameters for an individual model block BlockArgs = collections.namedtuple('BlockArgs', [ 'kernel_size', 'num_repeat', 'input_filters', 'output_filters', 'expand_ratio', 'id_skip', 'stride', 'se_ratio']) # https://stackoverflow.com/a/18348004 # Change namedtuple defaults GlobalParams.__new__.__defaults__ = (None,) * len(GlobalParams._fields) BlockArgs.__new__.__defaults__ = (None,) * len(BlockArgs._fields) # in the old version, g_simple_padding = False, which tries to align # tensorflow's implementation, which is not required here. g_simple_padding = True class MaxPool2dStaticSamePadding(nn.Module): """ created by Zylo117 The real keras/tensorflow MaxPool2d with same padding """ def __init__(self, kernel_size, stride): super().__init__() if g_simple_padding: self.pool = nn.MaxPool2d(kernel_size, stride, padding=(kernel_size-1)//2) else: assert ValueError() self.pool = nn.MaxPool2d(kernel_size, stride) self.stride = self.pool.stride self.kernel_size = self.pool.kernel_size if isinstance(self.stride, int): self.stride = [self.stride] * 2 elif len(self.stride) == 1: self.stride = [self.stride[0]] * 2 if isinstance(self.kernel_size, int): self.kernel_size = [self.kernel_size] * 2 elif len(self.kernel_size) == 1: self.kernel_size = [self.kernel_size[0]] * 2 def forward(self, x): if g_simple_padding: return self.pool(x) else: assert ValueError() h, w = x.shape[-2:] h_step = math.ceil(w / self.stride[1]) v_step = math.ceil(h / self.stride[0]) h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1) v_cover_len = 
self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1) extra_h = h_cover_len - w extra_v = v_cover_len - h left = extra_h // 2 right = extra_h - left top = extra_v // 2 bottom = extra_v - top x = F.pad(x, [left, right, top, bottom]) x = self.pool(x) return x class Conv2dStaticSamePadding(nn.Module): """ created by Zylo117 The real keras/tensorflow conv2d with same padding """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, groups=1, dilation=1, **kwargs): super().__init__() if g_simple_padding: assert kernel_size % 2 == 1 assert dilation == 1 self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, bias=bias, groups=groups, padding=(kernel_size - 1) // 2) self.stride = self.conv.stride if isinstance(self.stride, int): self.stride = [self.stride] * 2 elif len(self.stride) == 1: self.stride = [self.stride[0]] * 2 else: self.stride = list(self.stride) else: assert ValueError() self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, bias=bias, groups=groups) self.stride = self.conv.stride self.kernel_size = self.conv.kernel_size self.dilation = self.conv.dilation if isinstance(self.stride, int): self.stride = [self.stride] * 2 elif len(self.stride) == 1: self.stride = [self.stride[0]] * 2 if isinstance(self.kernel_size, int): self.kernel_size = [self.kernel_size] * 2 elif len(self.kernel_size) == 1: self.kernel_size = [self.kernel_size[0]] * 2 def forward(self, x): if g_simple_padding: return self.conv(x) else: assert ValueError() h, w = x.shape[-2:] h_step = math.ceil(w / self.stride[1]) v_step = math.ceil(h / self.stride[0]) h_cover_len = self.stride[1] * (h_step - 1) + 1 + (self.kernel_size[1] - 1) v_cover_len = self.stride[0] * (v_step - 1) + 1 + (self.kernel_size[0] - 1) extra_h = h_cover_len - w extra_v = v_cover_len - h left = extra_h // 2 right = extra_h - left top = extra_v // 2 bottom = extra_v - top x = F.pad(x, [left, right, top, bottom]) x = self.conv(x) return x class 
SeparableConvBlock(nn.Module): """ created by Zylo117 """ def __init__(self, in_channels, out_channels=None, norm=True, activation=False, onnx_export=False): super(SeparableConvBlock, self).__init__() if out_channels is None: out_channels = in_channels # Q: whether separate conv # share bias between depthwise_conv and pointwise_conv # or just pointwise_conv apply bias. # A: Confirmed, just pointwise_conv applies bias, depthwise_conv has no bias. self.depthwise_conv = Conv2dStaticSamePadding(in_channels, in_channels, kernel_size=3, stride=1, groups=in_channels, bias=False) self.pointwise_conv = Conv2dStaticSamePadding(in_channels, out_channels, kernel_size=1, stride=1) self.norm = norm if self.norm: # Warning: pytorch momentum is different from tensorflow's, momentum_pytorch = 1 - momentum_tensorflow self.bn = nn.BatchNorm2d(num_features=out_channels, momentum=0.01, eps=1e-3) self.activation = activation if self.activation: self.swish = MemoryEfficientSwish() if not onnx_export else Swish() def forward(self, x): x = self.depthwise_conv(x) x = self.pointwise_conv(x) if self.norm: x = self.bn(x) if self.activation: x = self.swish(x) return x class BiFPN(nn.Module): """ modified by Zylo117 """ def __init__(self, num_channels, conv_channels, first_time=False, epsilon=1e-4, onnx_export=False, attention=True, adaptive_up=False): """ Args: num_channels: conv_channels: first_time: whether the input comes directly from the efficientnet, if True, downchannel it first, and downsample P5 to generate P6 then P7 epsilon: epsilon of fast weighted attention sum of BiFPN, not the BN's epsilon onnx_export: if True, use Swish instead of MemoryEfficientSwish """ super(BiFPN, self).__init__() self.epsilon = epsilon # Conv layers self.conv6_up = SeparableConvBlock(num_channels, onnx_export=onnx_export) self.conv5_up = SeparableConvBlock(num_channels, onnx_export=onnx_export) self.conv4_up = SeparableConvBlock(num_channels, onnx_export=onnx_export) self.conv3_up = 
SeparableConvBlock(num_channels, onnx_export=onnx_export) self.conv4_down = SeparableConvBlock(num_channels, onnx_export=onnx_export) self.conv5_down = SeparableConvBlock(num_channels, onnx_export=onnx_export) self.conv6_down = SeparableConvBlock(num_channels, onnx_export=onnx_export) self.conv7_down = SeparableConvBlock(num_channels, onnx_export=onnx_export) # Feature scaling layers self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest') self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest') self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest') self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest') self.adaptive_up = adaptive_up self.p4_downsample = MaxPool2dStaticSamePadding(3, 2) self.p5_downsample = MaxPool2dStaticSamePadding(3, 2) self.p6_downsample = MaxPool2dStaticSamePadding(3, 2) self.p7_downsample = MaxPool2dStaticSamePadding(3, 2) self.swish = MemoryEfficientSwish() if not onnx_export else Swish() self.first_time = first_time if self.first_time: self.p5_down_channel = nn.Sequential( Conv2dStaticSamePadding(conv_channels[2], num_channels, 1), nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), ) self.p4_down_channel = nn.Sequential( Conv2dStaticSamePadding(conv_channels[1], num_channels, 1), nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), ) self.p3_down_channel = nn.Sequential( Conv2dStaticSamePadding(conv_channels[0], num_channels, 1), nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), ) if len(conv_channels) == 3: self.p5_to_p6 = nn.Sequential( Conv2dStaticSamePadding(conv_channels[2], num_channels, 1), nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), MaxPool2dStaticSamePadding(3, 2) ) else: assert len(conv_channels) == 4 self.p6_down_channel = nn.Sequential( Conv2dStaticSamePadding(conv_channels[3], num_channels, 1), nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), ) self.p6_to_p7 = nn.Sequential( MaxPool2dStaticSamePadding(3, 2) ) self.p4_down_channel_2 = nn.Sequential( 
Conv2dStaticSamePadding(conv_channels[1], num_channels, 1), nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), ) self.p5_down_channel_2 = nn.Sequential( Conv2dStaticSamePadding(conv_channels[2], num_channels, 1), nn.BatchNorm2d(num_channels, momentum=0.01, eps=1e-3), ) # Weight self.p6_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) self.p6_w1_relu = nn.ReLU() self.p5_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) self.p5_w1_relu = nn.ReLU() self.p4_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) self.p4_w1_relu = nn.ReLU() self.p3_w1 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) self.p3_w1_relu = nn.ReLU() self.p4_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) self.p4_w2_relu = nn.ReLU() self.p5_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) self.p5_w2_relu = nn.ReLU() self.p6_w2 = nn.Parameter(torch.ones(3, dtype=torch.float32), requires_grad=True) self.p6_w2_relu = nn.ReLU() self.p7_w2 = nn.Parameter(torch.ones(2, dtype=torch.float32), requires_grad=True) self.p7_w2_relu = nn.ReLU() self.attention = attention def forward(self, inputs): """ illustration of a minimal bifpn unit P7_0 -------------------------> P7_2 --------> |-------------| ↑ ↓ | P6_0 ---------> P6_1 ---------> P6_2 --------> |-------------|--------------↑ ↑ ↓ | P5_0 ---------> P5_1 ---------> P5_2 --------> |-------------|--------------↑ ↑ ↓ | P4_0 ---------> P4_1 ---------> P4_2 --------> |-------------|--------------↑ ↑ |--------------↓ | P3_0 -------------------------> P3_2 --------> """ # downsample channels using same-padding conv2d to target phase's if not the same # judge: same phase as target, # if same, pass; # elif earlier phase, downsample to target phase's by pooling # elif later phase, upsample to target phase's by nearest interpolation if self.attention: p3_out, p4_out, p5_out, p6_out, p7_out = 
self._forward_fast_attention(inputs) else: p3_out, p4_out, p5_out, p6_out, p7_out = self._forward(inputs) return p3_out, p4_out, p5_out, p6_out, p7_out def _forward_fast_attention(self, inputs): if self.first_time: if len(inputs) == 3: p3, p4, p5 = inputs p6_in = self.p5_to_p6(p5) else: p3, p4, p5, p6 = inputs p6_in = self.p6_down_channel(p6) p7_in = self.p6_to_p7(p6_in) p3_in = self.p3_down_channel(p3) p4_in = self.p4_down_channel(p4) p5_in = self.p5_down_channel(p5) else: # P3_0, P4_0, P5_0, P6_0 and P7_0 p3_in, p4_in, p5_in, p6_in, p7_in = inputs # P7_0 to P7_2 if not self.adaptive_up: # Weights for P6_0 and P7_0 to P6_1 p6_w1 = self.p6_w1_relu(self.p6_w1) weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon) # Connections for P6_0 and P7_0 to P6_1 respectively p6_up = self.conv6_up(self.swish(weight[0] * p6_in + weight[1] * self.p6_upsample(p7_in))) # Weights for P5_0 and P6_0 to P5_1 p5_w1 = self.p5_w1_relu(self.p5_w1) weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon) # Connections for P5_0 and P6_0 to P5_1 respectively p5_up = self.conv5_up(self.swish(weight[0] * p5_in + weight[1] * self.p5_upsample(p6_up))) # Weights for P4_0 and P5_0 to P4_1 p4_w1 = self.p4_w1_relu(self.p4_w1) weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon) # Connections for P4_0 and P5_0 to P4_1 respectively p4_up = self.conv4_up(self.swish(weight[0] * p4_in + weight[1] * self.p4_upsample(p5_up))) # Weights for P3_0 and P4_1 to P3_2 p3_w1 = self.p3_w1_relu(self.p3_w1) weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon) # Connections for P3_0 and P4_1 to P3_2 respectively p3_out = self.conv3_up(self.swish(weight[0] * p3_in + weight[1] * self.p3_upsample(p4_up))) else: # Weights for P6_0 and P7_0 to P6_1 p6_w1 = self.p6_w1_relu(self.p6_w1) weight = p6_w1 / (torch.sum(p6_w1, dim=0) + self.epsilon) # Connections for P6_0 and P7_0 to P6_1 respectively p6_upsample = nn.Upsample(size=p6_in.shape[-2:]) p6_up = self.conv6_up(self.swish(weight[0] * p6_in + weight[1] * 
p6_upsample(p7_in))) # Weights for P5_0 and P6_0 to P5_1 p5_w1 = self.p5_w1_relu(self.p5_w1) weight = p5_w1 / (torch.sum(p5_w1, dim=0) + self.epsilon) # Connections for P5_0 and P6_0 to P5_1 respectively p5_upsample = nn.Upsample(size=p5_in.shape[-2:]) p5_up = self.conv5_up(self.swish(weight[0] * p5_in + weight[1] * p5_upsample(p6_up))) # Weights for P4_0 and P5_0 to P4_1 p4_w1 = self.p4_w1_relu(self.p4_w1) weight = p4_w1 / (torch.sum(p4_w1, dim=0) + self.epsilon) # Connections for P4_0 and P5_0 to P4_1 respectively p4_upsample = nn.Upsample(size=p4_in.shape[-2:]) p4_up = self.conv4_up(self.swish(weight[0] * p4_in + weight[1] * p4_upsample(p5_up))) # Weights for P3_0 and P4_1 to P3_2 p3_w1 = self.p3_w1_relu(self.p3_w1) weight = p3_w1 / (torch.sum(p3_w1, dim=0) + self.epsilon) p3_upsample = nn.Upsample(size=p3_in.shape[-2:]) # Connections for P3_0 and P4_1 to P3_2 respectively p3_out = self.conv3_up(self.swish(weight[0] * p3_in + weight[1] * p3_upsample(p4_up))) if self.first_time: p4_in = self.p4_down_channel_2(p4) p5_in = self.p5_down_channel_2(p5) # Weights for P4_0, P4_1 and P3_2 to P4_2 p4_w2 = self.p4_w2_relu(self.p4_w2) weight = p4_w2 / (torch.sum(p4_w2, dim=0) + self.epsilon) # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively p4_out = self.conv4_down( self.swish(weight[0] * p4_in + weight[1] * p4_up + weight[2] * self.p4_downsample(p3_out))) # Weights for P5_0, P5_1 and P4_2 to P5_2 p5_w2 = self.p5_w2_relu(self.p5_w2) weight = p5_w2 / (torch.sum(p5_w2, dim=0) + self.epsilon) # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively p5_out = self.conv5_down( self.swish(weight[0] * p5_in + weight[1] * p5_up + weight[2] * self.p5_downsample(p4_out))) # Weights for P6_0, P6_1 and P5_2 to P6_2 p6_w2 = self.p6_w2_relu(self.p6_w2) weight = p6_w2 / (torch.sum(p6_w2, dim=0) + self.epsilon) # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively p6_out = self.conv6_down( self.swish(weight[0] * p6_in + weight[1] * p6_up + weight[2] * 
self.p6_downsample(p5_out))) # Weights for P7_0 and P6_2 to P7_2 p7_w2 = self.p7_w2_relu(self.p7_w2) weight = p7_w2 / (torch.sum(p7_w2, dim=0) + self.epsilon) # Connections for P7_0 and P6_2 to P7_2 p7_out = self.conv7_down(self.swish(weight[0] * p7_in + weight[1] * self.p7_downsample(p6_out))) return p3_out, p4_out, p5_out, p6_out, p7_out def _forward(self, inputs): if self.first_time: p3, p4, p5 = inputs p6_in = self.p5_to_p6(p5) p7_in = self.p6_to_p7(p6_in) p3_in = self.p3_down_channel(p3) p4_in = self.p4_down_channel(p4) p5_in = self.p5_down_channel(p5) else: # P3_0, P4_0, P5_0, P6_0 and P7_0 p3_in, p4_in, p5_in, p6_in, p7_in = inputs # P7_0 to P7_2 # Connections for P6_0 and P7_0 to P6_1 respectively p6_up = self.conv6_up(self.swish(p6_in + self.p6_upsample(p7_in))) # Connections for P5_0 and P6_0 to P5_1 respectively p5_up = self.conv5_up(self.swish(p5_in + self.p5_upsample(p6_up))) # Connections for P4_0 and P5_0 to P4_1 respectively p4_up = self.conv4_up(self.swish(p4_in + self.p4_upsample(p5_up))) # Connections for P3_0 and P4_1 to P3_2 respectively p3_out = self.conv3_up(self.swish(p3_in + self.p3_upsample(p4_up))) if self.first_time: p4_in = self.p4_down_channel_2(p4) p5_in = self.p5_down_channel_2(p5) # Connections for P4_0, P4_1 and P3_2 to P4_2 respectively p4_out = self.conv4_down( self.swish(p4_in + p4_up + self.p4_downsample(p3_out))) # Connections for P5_0, P5_1 and P4_2 to P5_2 respectively p5_out = self.conv5_down( self.swish(p5_in + p5_up + self.p5_downsample(p4_out))) # Connections for P6_0, P6_1 and P5_2 to P6_2 respectively p6_out = self.conv6_down( self.swish(p6_in + p6_up + self.p6_downsample(p5_out))) # Connections for P7_0 and P6_2 to P7_2 p7_out = self.conv7_down(self.swish(p7_in + self.p7_downsample(p6_out))) return p3_out, p4_out, p5_out, p6_out, p7_out class Regressor(nn.Module): """ modified by Zylo117 """ def __init__(self, in_channels, num_anchors, num_layers, onnx_export=False): super(Regressor, self).__init__() self.num_layers = 
num_layers self.num_layers = num_layers self.conv_list = nn.ModuleList( [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)]) self.bn_list = nn.ModuleList( [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in range(5)]) self.header = SeparableConvBlock(in_channels, num_anchors * 4, norm=False, activation=False) self.swish = MemoryEfficientSwish() if not onnx_export else Swish() def forward(self, inputs): feats = [] for feat, bn_list in zip(inputs, self.bn_list): for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list): feat = conv(feat) feat = bn(feat) feat = self.swish(feat) feat = self.header(feat) feat = feat.permute(0, 2, 3, 1) feat = feat.contiguous().view(feat.shape[0], -1, 4) feats.append(feat) feats = torch.cat(feats, dim=1) return feats class SwishImplementation(torch.autograd.Function): @staticmethod def forward(ctx, i): result = i * torch.sigmoid(i) ctx.save_for_backward(i) return result @staticmethod def backward(ctx, grad_output): i = ctx.saved_variables[0] sigmoid_i = torch.sigmoid(i) return grad_output * (sigmoid_i * (1 + i * (1 - sigmoid_i))) class MemoryEfficientSwish(nn.Module): def forward(self, x): if torch._C._get_tracing_state(): return x * torch.sigmoid(x) return SwishImplementation.apply(x) class Swish(nn.Module): def forward(self, x): return x * torch.sigmoid(x) class Classifier(nn.Module): """ modified by Zylo117 """ def __init__(self, in_channels, num_anchors, num_classes, num_layers, onnx_export=False, prior_prob=0.01): super(Classifier, self).__init__() self.num_anchors = num_anchors self.num_classes = num_classes self.num_layers = num_layers self.conv_list = nn.ModuleList( [SeparableConvBlock(in_channels, in_channels, norm=False, activation=False) for i in range(num_layers)]) self.bn_list = nn.ModuleList( [nn.ModuleList([nn.BatchNorm2d(in_channels, momentum=0.01, eps=1e-3) for i in range(num_layers)]) for j in 
range(5)]) self.header = SeparableConvBlock(in_channels, num_anchors * num_classes, norm=False, activation=False) prior_prob = prior_prob bias_value = -math.log((1 - prior_prob) / prior_prob) torch.nn.init.normal_(self.header.pointwise_conv.conv.weight, std=0.01) torch.nn.init.constant_(self.header.pointwise_conv.conv.bias, bias_value) self.swish = MemoryEfficientSwish() if not onnx_export else Swish() def forward(self, inputs): feats = [] for feat, bn_list in zip(inputs, self.bn_list): for i, bn, conv in zip(range(self.num_layers), bn_list, self.conv_list): feat = conv(feat) feat = bn(feat) feat = self.swish(feat) feat = self.header(feat) feat = feat.permute(0, 2, 3, 1) feat = feat.contiguous().view(feat.shape[0], feat.shape[1], feat.shape[2], self.num_anchors, self.num_classes) feat = feat.contiguous().view(feat.shape[0], -1, self.num_classes) feats.append(feat) feats = torch.cat(feats, dim=1) #feats = feats.sigmoid() return feats class Conv2dDynamicSamePadding(nn.Conv2d): """ 2D Convolutions like TensorFlow, for a dynamic image size """ def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, groups=1, bias=True): super().__init__(in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) raise ValueError('tend to be deprecated') self.stride = self.stride if len(self.stride) == 2 else [self.stride[0]] * 2 def forward(self, x): ih, iw = x.size()[-2:] kh, kw = self.weight.size()[-2:] sh, sw = self.stride oh, ow = math.ceil(ih / sh), math.ceil(iw / sw) pad_h = max((oh - 1) * self.stride[0] + (kh - 1) * self.dilation[0] + 1 - ih, 0) pad_w = max((ow - 1) * self.stride[1] + (kw - 1) * self.dilation[1] + 1 - iw, 0) if pad_h > 0 or pad_w > 0: x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) return F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) #TODO: it seems like the standard conv layer is good enough with proper padding # parameters. 
def get_same_padding_conv2d(image_size=None): """ Chooses static padding if you have specified an image size, and dynamic padding otherwise. Static padding is necessary for ONNX exporting of models. """ if image_size is None: raise ValueError('not validated') return Conv2dDynamicSamePadding else: from functools import partial return partial(Conv2dStaticSamePadding, image_size=image_size) def round_filters(filters, global_params): """ Calculate and round number of filters based on depth multiplier. """ multiplier = global_params.width_coefficient if not multiplier: return filters divisor = global_params.depth_divisor min_depth = global_params.min_depth filters *= multiplier min_depth = min_depth or divisor new_filters = max(min_depth, int(filters + divisor / 2) // divisor * divisor) if new_filters < 0.9 * filters: # prevent rounding by more than 10% new_filters += divisor return int(new_filters) def round_repeats(repeats, global_params): """ Round number of filters based on depth multiplier. """ multiplier = global_params.depth_coefficient if not multiplier: return repeats return int(math.ceil(multiplier * repeats)) def drop_connect(inputs, p, training): """ Drop connect. """ if not training: return inputs batch_size = inputs.shape[0] keep_prob = 1 - p random_tensor = keep_prob random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device) binary_tensor = torch.floor(random_tensor) output = inputs / keep_prob * binary_tensor return output class MBConvBlock(nn.Module): """ Mobile Inverted Residual Bottleneck Block Args: block_args (namedtuple): BlockArgs, see above global_params (namedtuple): GlobalParam, see above Attributes: has_se (bool): Whether the block contains a Squeeze and Excitation layer. 
""" def __init__(self, block_args, global_params): super().__init__() self._block_args = block_args self._bn_mom = 1 - global_params.batch_norm_momentum self._bn_eps = global_params.batch_norm_epsilon self.has_se = (self._block_args.se_ratio is not None) and (0 < self._block_args.se_ratio <= 1) self.id_skip = block_args.id_skip # skip connection and drop connect # Get static or dynamic convolution depending on image size Conv2d = get_same_padding_conv2d(image_size=global_params.image_size) # Expansion phase inp = self._block_args.input_filters # number of input channels oup = self._block_args.input_filters * self._block_args.expand_ratio # number of output channels if self._block_args.expand_ratio != 1: self._expand_conv = Conv2d(in_channels=inp, out_channels=oup, kernel_size=1, bias=False) self._bn0 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) # Depthwise convolution phase k = self._block_args.kernel_size s = self._block_args.stride if isinstance(s, (tuple, list)) and all([s0 == s[0] for s0 in s]): s = s[0] self._depthwise_conv = Conv2d( in_channels=oup, out_channels=oup, groups=oup, # groups makes it depthwise kernel_size=k, stride=s, bias=False) self._bn1 = nn.BatchNorm2d(num_features=oup, momentum=self._bn_mom, eps=self._bn_eps) # Squeeze and Excitation layer, if desired if self.has_se: num_squeezed_channels = max(1, int(self._block_args.input_filters * self._block_args.se_ratio)) self._se_reduce = Conv2d(in_channels=oup, out_channels=num_squeezed_channels, kernel_size=1) self._se_expand = Conv2d(in_channels=num_squeezed_channels, out_channels=oup, kernel_size=1) # Output phase final_oup = self._block_args.output_filters self._project_conv = Conv2d(in_channels=oup, out_channels=final_oup, kernel_size=1, bias=False) self._bn2 = nn.BatchNorm2d(num_features=final_oup, momentum=self._bn_mom, eps=self._bn_eps) self._swish = MemoryEfficientSwish() def forward(self, inputs, drop_connect_rate=None): """ :param inputs: input tensor :param 
drop_connect_rate: drop connect rate (float, between 0 and 1) :return: output of block """ # Expansion and Depthwise Convolution x = inputs if self._block_args.expand_ratio != 1: x = self._expand_conv(inputs) x = self._bn0(x) x = self._swish(x) x = self._depthwise_conv(x) x = self._bn1(x) x = self._swish(x) # Squeeze and Excitation if self.has_se: x_squeezed = F.adaptive_avg_pool2d(x, 1) x_squeezed = self._se_reduce(x_squeezed) x_squeezed = self._swish(x_squeezed) x_squeezed = self._se_expand(x_squeezed) x = torch.sigmoid(x_squeezed) * x x = self._project_conv(x) x = self._bn2(x) # Skip connection and drop connect input_filters, output_filters = self._block_args.input_filters, self._block_args.output_filters if self.id_skip and self._block_args.stride == 1 and input_filters == output_filters: if drop_connect_rate: x = drop_connect(x, p=drop_connect_rate, training=self.training) x = x + inputs # skip connection return x def set_swish(self, memory_efficient=True): """Sets swish function as memory efficient (for training) or standard (for export)""" self._swish = MemoryEfficientSwish() if memory_efficient else Swish() class BlockDecoder(object): """ Block Decoder for readability, straight from the official TensorFlow repository """ @staticmethod def _decode_block_string(block_string): """ Gets a block through a string notation of arguments. 
""" assert isinstance(block_string, str) ops = block_string.split('_') options = {} for op in ops: splits = re.split(r'(\d.*)', op) if len(splits) >= 2: key, value = splits[:2] options[key] = value # Check stride assert (('s' in options and len(options['s']) == 1) or (len(options['s']) == 2 and options['s'][0] == options['s'][1])) return BlockArgs( kernel_size=int(options['k']), num_repeat=int(options['r']), input_filters=int(options['i']), output_filters=int(options['o']), expand_ratio=int(options['e']), id_skip=('noskip' not in block_string), se_ratio=float(options['se']) if 'se' in options else None, stride=[int(options['s'][0])]) @staticmethod def _encode_block_string(block): """Encodes a block to a string.""" args = [ 'r%d' % block.num_repeat, 'k%d' % block.kernel_size, 's%d%d' % (block.strides[0], block.strides[1]), 'e%s' % block.expand_ratio, 'i%d' % block.input_filters, 'o%d' % block.output_filters ] if 0 < block.se_ratio <= 1: args.append('se%s' % block.se_ratio) if block.id_skip is False: args.append('noskip') return '_'.join(args) @staticmethod def decode(string_list): """ Decodes a list of string notations to specify blocks inside the network. :param string_list: a list of strings, each string is a notation of block :return: a list of BlockArgs namedtuples of block args """ assert isinstance(string_list, list) blocks_args = [] for block_string in string_list: blocks_args.append(BlockDecoder._decode_block_string(block_string)) return blocks_args @staticmethod def encode(blocks_args): """ Encodes a list of BlockArgs to a list of strings. 
:param blocks_args: a list of BlockArgs namedtuples of block args :return: a list of strings, each string is a notation of block """ block_strings = [] for block in blocks_args: block_strings.append(BlockDecoder._encode_block_string(block)) return block_strings def efficientnet(width_coefficient=None, depth_coefficient=None, dropout_rate=0.2, drop_connect_rate=0.2, image_size=None, num_classes=1000): """ Creates a efficientnet model. """ blocks_args = [ 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', 'r1_k3_s11_e6_i192_o320_se0.25', ] blocks_args = BlockDecoder.decode(blocks_args) global_params = GlobalParams( batch_norm_momentum=0.99, batch_norm_epsilon=1e-3, dropout_rate=dropout_rate, drop_connect_rate=drop_connect_rate, # data_format='channels_last', # removed, this is always true in PyTorch num_classes=num_classes, width_coefficient=width_coefficient, depth_coefficient=depth_coefficient, depth_divisor=8, min_depth=None, image_size=image_size, ) return blocks_args, global_params def efficientnet_params(model_name): """ Map EfficientNet model name to parameter coefficients. 
""" params_dict = { # Coefficients: width,depth,res,dropout 'efficientnet-b0': (1.0, 1.0, 224, 0.2), 'efficientnet-b1': (1.0, 1.1, 240, 0.2), 'efficientnet-b2': (1.1, 1.2, 260, 0.3), 'efficientnet-b3': (1.2, 1.4, 300, 0.3), 'efficientnet-b4': (1.4, 1.8, 380, 0.4), 'efficientnet-b5': (1.6, 2.2, 456, 0.4), 'efficientnet-b6': (1.8, 2.6, 528, 0.5), 'efficientnet-b7': (2.0, 3.1, 600, 0.5), 'efficientnet-b8': (2.2, 3.6, 672, 0.5), 'efficientnet-l2': (4.3, 5.3, 800, 0.5), } return params_dict[model_name] def get_model_params(model_name, override_params): """ Get the block args and global params for a given model """ if model_name.startswith('efficientnet'): w, d, s, p = efficientnet_params(model_name) # note: all models have drop connect rate = 0.2 blocks_args, global_params = efficientnet( width_coefficient=w, depth_coefficient=d, dropout_rate=p, image_size=s) else: raise NotImplementedError('model name is not pre-defined: %s' % model_name) if override_params: # ValueError will be raised here if override_params has fields not included in global_params. 
global_params = global_params._replace(**override_params) return blocks_args, global_params url_map = { 'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b0-355c32eb.pth', 'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b1-f1951068.pth', 'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b2-8bb594d6.pth', 'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b3-5fb5a3c3.pth', 'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b4-6ed6700e.pth', 'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b5-b6417697.pth', 'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b6-c76e70fd.pth', 'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/aa/efficientnet-b7-dcc49843.pth', } url_map_advprop = { 'efficientnet-b0': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b0-b64d5a18.pth', 'efficientnet-b1': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b1-0f3ce85a.pth', 'efficientnet-b2': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b2-6e9d97e5.pth', 'efficientnet-b3': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b3-cdd7c0f4.pth', 'efficientnet-b4': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b4-44fb3a87.pth', 'efficientnet-b5': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b5-86493f6b.pth', 'efficientnet-b6': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b6-ac80338e.pth', 'efficientnet-b7': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b7-4652b6dd.pth', 'efficientnet-b8': 'https://publicmodels.blob.core.windows.net/container/advprop/efficientnet-b8-22a8fe65.pth', } def 
def load_pretrained_weights(model, model_name, load_fc=True, advprop=False):
    """Load pretrained EfficientNet weights into ``model``, downloading them on first use.

    Args:
        model: module whose state dict is filled (non-strictly) from the checkpoint.
        model_name: key into ``url_map`` / ``url_map_advprop`` (e.g. ``'efficientnet-b0'``).
        load_fc: when False, the checkpoint's ``_fc.weight`` / ``_fc.bias`` are dropped so the
            model keeps its own classifier layer.
        advprop: use the AdvProp checkpoints instead of the AutoAugment ones (they use
            different preprocessing).
    """
    # AutoAugment or Advprop (different preprocessing)
    url_map_ = url_map_advprop if advprop else url_map
    from torch.utils import model_zoo
    state_dict = model_zoo.load_url(url_map_[model_name], map_location=torch.device('cpu'))
    if load_fc:
        ret = model.load_state_dict(state_dict, strict=False)
        print(ret)
    else:
        state_dict.pop('_fc.weight')
        state_dict.pop('_fc.bias')
        res = model.load_state_dict(state_dict, strict=False)
        assert set(res.missing_keys) == set(['_fc.weight', '_fc.bias']), 'issue loading pretrained weights'
    print('Loaded pretrained weights for {}'.format(model_name))


class EfficientNet(nn.Module):
    """
    An EfficientNet model. Most easily loaded with the .from_name or .from_pretrained methods

    Args:
        blocks_args (list): A list of BlockArgs to construct blocks
        global_params (namedtuple): A set of GlobalParams shared between blocks

    Example:
        model = EfficientNet.from_pretrained('efficientnet-b0')
    """

    def __init__(self, blocks_args=None, global_params=None):
        super().__init__()
        assert isinstance(blocks_args, list), 'blocks_args should be a list'
        assert len(blocks_args) > 0, 'block args must be greater than 0'
        self._global_params = global_params
        self._blocks_args = blocks_args

        # Get static or dynamic convolution depending on image size
        Conv2d = get_same_padding_conv2d(image_size=global_params.image_size)

        # Batch norm parameters; the stored momentum is converted to its complement (1 - value).
        bn_mom = 1 - self._global_params.batch_norm_momentum
        bn_eps = self._global_params.batch_norm_epsilon

        # Stem
        in_channels = 3  # rgb
        out_channels = round_filters(32, self._global_params)  # number of output channels
        self._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        self._bn0 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Build blocks
        self._blocks = nn.ModuleList([])
        for block_args in self._blocks_args:
            # Update block input and output filters based on depth multiplier.
            block_args = block_args._replace(
                input_filters=round_filters(block_args.input_filters, self._global_params),
                output_filters=round_filters(block_args.output_filters, self._global_params),
                num_repeat=round_repeats(block_args.num_repeat, self._global_params)
            )

            # The first block needs to take care of stride and filter size increase.
            self._blocks.append(MBConvBlock(block_args, self._global_params))
            if block_args.num_repeat > 1:
                block_args = block_args._replace(input_filters=block_args.output_filters, stride=1)
            for _ in range(block_args.num_repeat - 1):
                self._blocks.append(MBConvBlock(block_args, self._global_params))

        # Head
        in_channels = block_args.output_filters  # output of final block
        out_channels = round_filters(1280, self._global_params)
        self._conv_head = Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self._bn1 = nn.BatchNorm2d(num_features=out_channels, momentum=bn_mom, eps=bn_eps)

        # Final linear layer
        self._avg_pooling = nn.AdaptiveAvgPool2d(1)
        self._dropout = nn.Dropout(self._global_params.dropout_rate)
        self._fc = nn.Linear(out_channels, self._global_params.num_classes)
        self._swish = MemoryEfficientSwish()

    def set_swish(self, memory_efficient=True):
        """Sets swish function as memory efficient (for training) or standard (for export)"""
        self._swish = MemoryEfficientSwish() if memory_efficient else Swish()
        for block in self._blocks:
            block.set_swish(memory_efficient)

    def extract_features(self, inputs):
        """Return the output of the final convolution layer (stem -> blocks -> head conv)."""
        # Stem
        x = self._swish(self._bn0(self._conv_stem(inputs)))

        # Blocks; the drop-connect rate grows linearly with block depth.
        for idx, block in enumerate(self._blocks):
            drop_connect_rate = self._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / len(self._blocks)
            x = block(x, drop_connect_rate=drop_connect_rate)

        # Head
        x = self._swish(self._bn1(self._conv_head(x)))
        return x

    def forward(self, inputs):
        """Calls extract_features to extract features, applies final linear layer, and returns logits."""
        bs = inputs.size(0)
        # Convolution layers
        x = self.extract_features(inputs)

        # Pooling and final linear layer
        x = self._avg_pooling(x)
        x = x.view(bs, -1)
        x = self._dropout(x)
        x = self._fc(x)
        return x

    @classmethod
    def from_name(cls, model_name, override_params=None):
        """Build an untrained model for ``model_name``, optionally overriding global params."""
        cls._check_model_name_is_valid(model_name)
        blocks_args, global_params = get_model_params(model_name, override_params)
        return cls(blocks_args, global_params)

    @classmethod
    def from_pretrained(cls, model_name, load_weights=True, advprop=True, num_classes=1000, in_channels=3):
        """Build ``model_name`` and optionally load pretrained weights.

        The classifier weights are loaded only when ``num_classes == 1000``. When
        ``in_channels != 3`` the stem convolution is replaced by a new (untrained) one.
        """
        model = cls.from_name(model_name, override_params={'num_classes': num_classes})
        if load_weights:
            load_pretrained_weights(model, model_name, load_fc=(num_classes == 1000), advprop=advprop)
        if in_channels != 3:
            Conv2d = get_same_padding_conv2d(image_size=model._global_params.image_size)
            out_channels = round_filters(32, model._global_params)
            model._conv_stem = Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=False)
        return model

    @classmethod
    def get_image_size(cls, model_name):
        """Return the input resolution associated with ``model_name``."""
        cls._check_model_name_is_valid(model_name)
        _, _, res, _ = efficientnet_params(model_name)
        return res

    @classmethod
    def _check_model_name_is_valid(cls, model_name):
        """Validates model name; raises ValueError unless it is efficientnet-b0 .. b8."""
        valid_models = ['efficientnet-b' + str(i) for i in range(9)]
        if model_name not in valid_models:
            raise ValueError('model_name should be one of: ' + ', '.join(valid_models))


class EfficientNetD(nn.Module):
    """
    modified by Zylo117

    EfficientNet used as a feature extractor: the classification head (``_conv_head``,
    ``_bn1``, pooling, dropout and ``_fc``) is deleted and ``forward`` returns
    intermediate feature maps.
    """

    def __init__(self, compound_coef, load_weights=False):
        super().__init__()
        model = EfficientNet.from_pretrained(f'efficientnet-b{compound_coef}', load_weights)
        del model._conv_head
        del model._bn1
        del model._avg_pooling
        del model._dropout
        del model._fc
        self.model = model

    def forward(self, x):
        """Return the feature maps taken just before every stride-2 block, plus the output of the
        final block, with the first collected entry dropped."""
        x = self.model._conv_stem(x)
        x = self.model._bn0(x)
        x = self.model._swish(x)

        feature_maps = []

        # TODO: temporarily storing extra tensor last_x and del it later might not be a good idea,
        #  try recording stride changing when creating efficientnet,
        #  and then apply it here.
        last_x = None
        num_blocks = len(self.model._blocks)
        for idx, block in enumerate(self.model._blocks):
            drop_connect_rate = self.model._global_params.drop_connect_rate
            if drop_connect_rate:
                drop_connect_rate *= float(idx) / num_blocks
            x = block(x, drop_connect_rate=drop_connect_rate)

            if tuple(block._depthwise_conv.stride) == (2, 2):
                # The previous output is the last map at the old resolution.
                feature_maps.append(last_x)
            elif idx == num_blocks - 1:
                feature_maps.append(x)
            last_x = x

        return feature_maps[1:]


class Anchors(nn.Module):
    """
    adapted and modified from https://github.com/google/automl/blob/master/efficientdet/anchors.py by Zylo117

    Generates (y1, x1, y2, x2) anchor boxes for each stride, for every combination of
    ``scales`` x ``ratios``, centred on a regular grid whose spacing equals the stride.

    Args:
        anchor_scale: base anchor size as a multiple of the stride.
        pyramid_levels: pyramid levels, default [3, 4, 5, 6, 7]; strides default to 2 ** level.
        **kwargs: optional ``strides``, ``scales`` and ``ratios`` overrides; other keys are ignored.
    """

    def __init__(self, anchor_scale=4., pyramid_levels=None, **kwargs):
        super().__init__()
        from qd.qd_common import print_frame_info
        print_frame_info()
        self.anchor_scale = anchor_scale

        # Always assign pyramid_levels: it was previously only set when the argument was None,
        # so passing explicit levels raised AttributeError on the next line.
        self.pyramid_levels = [3, 4, 5, 6, 7] if pyramid_levels is None else pyramid_levels

        self.strides = kwargs.get('strides', [2 ** x for x in self.pyramid_levels])
        self.scales = np.array(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
        self.ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])

        # Generated anchors, keyed by image shape, dtype and device (see forward).
        self.buffer = {}

    @torch.no_grad()
    def forward(self, image, dtype=torch.float32, features=None):
        """Generates multiscale anchor boxes.

        Args:
          image: input batch; only its spatial shape ``image.shape[2:]`` (H, W) and its
            device are used.
          dtype: ``torch.float16`` produces float16 anchors, anything else float32.
          features: optional per-stride feature maps; when given, the grid of stride ``i``
            has ``features[i].shape[-2:]`` cells instead of being derived from the image size.

        Returns:
          dict with
            'anchor': tensor of shape (1, N, 4) in (y1, x1, y2, x2) order, all strides stacked;
            'stride_idx': int64 tensor of shape (N,), the index into ``self.strides`` of each anchor.
        """
        image_shape = image.shape[2:]
        # dtype and device are part of the key so a cached tensor is never returned with the
        # wrong precision or on the wrong device.
        anchor_key = self.get_key('anchor', image_shape, dtype, image.device)
        stride_idx_key = self.get_key('anchor_stride_index', image_shape, dtype, image.device)

        if anchor_key in self.buffer:
            return {'stride_idx': self.buffer[stride_idx_key].detach(),
                    'anchor': self.buffer[anchor_key].detach()}

        np_dtype = np.float16 if dtype == torch.float16 else np.float32
        img_h, img_w = image_shape[0], image_shape[1]

        boxes_all = []
        all_idx_strides = []
        for idx_stride, stride in enumerate(self.strides):
            # Grid of anchor centres; it does not depend on scale/ratio, so build it once.
            if features is not None:
                f_h, f_w = features[idx_stride].shape[-2:]
                x = np.arange(stride / 2, stride * f_w, stride)
                y = np.arange(stride / 2, stride * f_h, stride)
            else:
                # Round EACH side up to a multiple of the stride (previously only the width was
                # checked, so a divisible width with a non-divisible height lost a row).
                x_max = stride * ((img_w + stride - 1) // stride)
                y_max = stride * ((img_h + stride - 1) // stride)
                x = np.arange(stride / 2, x_max, stride)
                y = np.arange(stride / 2, y_max, stride)
            xv, yv = np.meshgrid(x, y)
            xv = xv.reshape(-1)
            yv = yv.reshape(-1)

            boxes_level = []
            for scale, ratio in itertools.product(self.scales, self.ratios):
                base_anchor_size = self.anchor_scale * stride * scale
                anchor_size_x_2 = base_anchor_size * ratio[0] / 2.0
                anchor_size_y_2 = base_anchor_size * ratio[1] / 2.0
                # y1,x1,y2,x2
                boxes = np.vstack((yv - anchor_size_y_2, xv - anchor_size_x_2,
                                   yv + anchor_size_y_2, xv + anchor_size_x_2))
                boxes = np.swapaxes(boxes, 0, 1)
                boxes_level.append(np.expand_dims(boxes, axis=1))

            # concat anchors on the same level to the reshape NxAx4
            boxes_level = np.concatenate(boxes_level, axis=1).reshape([-1, 4])
            all_idx_strides.append(torch.full((len(boxes_level),), idx_stride, dtype=torch.long))
            boxes_all.append(boxes_level)

        anchor_stride_indices = torch.cat(all_idx_strides).to(image.device)
        anchor_boxes = torch.from_numpy(np.vstack(boxes_all).astype(np_dtype)).to(image.device)
        anchor_boxes = anchor_boxes.unsqueeze(0)

        # save it for later use to reduce overhead
        self.buffer[stride_idx_key] = anchor_stride_indices
        self.buffer[anchor_key] = anchor_boxes
        return {'stride_idx': anchor_stride_indices, 'anchor': anchor_boxes}

    def get_key(self, hint, image_shape, *extra):
        """Build a cache key ``'<hint>_<d0>_<d1>...'``, followed by any ``extra`` qualifiers."""
        key = '{}_{}'.format(hint, '_'.join(map(str, image_shape)))
        if extra:
            key += '_' + '_'.join(map(str, extra))
        return key


class EffNetFPN(nn.Module):
    """EfficientNet backbone followed by a stack of BiFPN cells.

    Args:
        compound_coef: index into the ``EfficientDetBackbone`` configuration tables.
        start_from: lowest level fed to the BiFPN: 3 uses (P3, P4, P5), 2 uses (P2, P3, P4, P5).
    """

    def __init__(self, compound_coef=0, start_from=3):
        super().__init__()
        # Validate before building anything.
        assert start_from in [2, 3]
        self.backbone_net = EfficientNetD(EfficientDetBackbone.backbone_compound_coef[compound_coef],
                                          load_weights=False)
        if start_from == 3:
            conv_channel_coef = EfficientDetBackbone.conv_channel_coef[compound_coef]
        else:
            conv_channel_coef = EfficientDetBackbone.conv_channel_coef2345[compound_coef]
        self.bifpn = nn.Sequential(
            *[BiFPN(EfficientDetBackbone.fpn_num_filters[compound_coef],
                    conv_channel_coef,
                    i == 0,
                    attention=compound_coef < 6,
                    adaptive_up=True)
              for i in range(EfficientDetBackbone.fpn_cell_repeats[compound_coef])])
        self.out_channels = EfficientDetBackbone.fpn_num_filters[compound_coef]
        self.start_from = start_from

    def forward(self, inputs):
        """Return the BiFPN outputs for the backbone's feature maps."""
        p2, p3, p4, p5 = self.backbone_net(inputs)
        features = (p3, p4, p5) if self.start_from == 3 else (p2, p3, p4, p5)
        return self.bifpn(features)


class EfficientDetBackbone(nn.Module):
    """EfficientDet: EfficientNet backbone + BiFPN + box regressor / classifier heads + anchors.

    The class-level tables are indexed by ``compound_coef`` (0-7).
    """
    backbone_compound_coef = [0, 1, 2, 3, 4, 5, 6, 6]
    fpn_num_filters = [64, 88, 112, 160, 224, 288, 384, 384]
    conv_channel_coef = {
        # the channels of P3/P4/P5.
        0: [40, 112, 320],
        1: [40, 112, 320],
        2: [48, 120, 352],
        3: [48, 136, 384],
        4: [56, 160, 448],
        5: [64, 176, 512],
        6: [72, 200, 576],
        7: [72, 200, 576],
    }
    conv_channel_coef2345 = {
        # the channels of P2/P3/P4/P5.
        0: [24, 40, 112, 320],
        # to be determined for the following
        1: [24, 40, 112, 320],
        2: [24, 48, 120, 352],
        3: [32, 48, 136, 384],
        4: [32, 56, 160, 448],
        5: [40, 64, 176, 512],
        6: [72, 200],
        7: [72, 200],
    }
    fpn_cell_repeats = [3, 4, 5, 6, 7, 7, 8, 8]

    def __init__(self, num_classes=80, compound_coef=0, load_weights=False, prior_prob=0.01, **kwargs):
        super(EfficientDetBackbone, self).__init__()
        self.compound_coef = compound_coef

        self.input_sizes = [512, 640, 768, 896, 1024, 1280, 1280, 1536]
        self.box_class_repeats = [3, 3, 3, 4, 4, 4, 5, 5]
        self.anchor_scale = [4., 4., 4., 4., 4., 4., 4., 5.]
        self.aspect_ratios = kwargs.get('ratios', [(1.0, 1.0), (1.4, 0.7), (0.7, 1.4)])
        self.num_scales = len(kwargs.get('scales', [2 ** 0, 2 ** (1.0 / 3.0), 2 ** (2.0 / 3.0)]))
        num_anchors = len(self.aspect_ratios) * self.num_scales
        num_filters = self.fpn_num_filters[compound_coef]
        num_layers = self.box_class_repeats[compound_coef]

        self.bifpn = nn.Sequential(
            *[BiFPN(num_filters,
                    self.conv_channel_coef[compound_coef],
                    i == 0,
                    attention=compound_coef < 6,
                    adaptive_up=kwargs.get('adaptive_up'))
              for i in range(self.fpn_cell_repeats[compound_coef])])

        self.num_classes = num_classes
        self.regressor = Regressor(in_channels=num_filters, num_anchors=num_anchors,
                                   num_layers=num_layers)
        self.classifier = Classifier(in_channels=num_filters, num_anchors=num_anchors,
                                     num_classes=num_classes,
                                     num_layers=num_layers,
                                     prior_prob=prior_prob)

        # A truthy ``anchor_scale`` kwarg overrides the per-coefficient default. The key is always
        # removed so it is not passed to Anchors a second time through **kwargs.
        anchor_scale = kwargs.pop('anchor_scale', None) or self.anchor_scale[compound_coef]
        self.anchors = Anchors(anchor_scale=anchor_scale, **kwargs)

        self.backbone_net = EfficientNetD(self.backbone_compound_coef[compound_coef], load_weights)

    def freeze_bn(self):
        """Put every BatchNorm2d layer in eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def forward(self, inputs):
        """Return ``(features, regression, classification, anchors)`` for a batch of images."""
        _, p3, p4, p5 = self.backbone_net(inputs)

        features = (p3, p4, p5)
        features = self.bifpn(features)

        regression = self.regressor(features)
        classification = self.classifier(features)
        anchors = self.anchors(inputs, inputs.dtype, features=features)

        return features, regression, classification, anchors

    def init_backbone(self, path):
        """Load a checkpoint from ``path`` non-strictly.

        The checkpoint is mapped to CPU first so a GPU-saved file also loads on CPU-only
        machines. A RuntimeError raised by ``load_state_dict`` is printed and ignored.
        """
        state_dict = torch.load(path, map_location='cpu')
        try:
            ret = self.load_state_dict(state_dict, strict=False)
            print(ret)
        except RuntimeError as e:
            print('Ignoring "' + str(e) + '"')


def init_weights(model):
    """Kaiming-uniform init for every Conv2d weight in ``model``; conv biases are zeroed."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            nn.init.kaiming_uniform_(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()


def calc_iou(a, b):
    """Pairwise IoU matrix of shape (N, M).

    Args:
        a: anchors (N, 4) in (y1, x1, y2, x2) order.
        b: ground-truth boxes (M, 4), coco-style (x1, y1, x2, y2).
    """
    area = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    iw = torch.min(torch.unsqueeze(a[:, 3], dim=1), b[:, 2]) - torch.max(torch.unsqueeze(a[:, 1], 1), b[:, 0])
    ih = torch.min(torch.unsqueeze(a[:, 2], dim=1), b[:, 3]) - torch.max(torch.unsqueeze(a[:, 0], 1), b[:, 1])
    iw = torch.clamp(iw, min=0)
    ih = torch.clamp(ih, min=0)
    ua = torch.unsqueeze((a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1]), dim=1) + area - iw * ih
    ua = torch.clamp(ua, min=1e-8)
    intersection = iw * ih
    IoU = intersection / ua
    return IoU


class BBoxTransform(nn.Module):
    def forward(self, anchors, regression):
        """
        decode_box_outputs adapted from https://github.com/google/automl/blob/master/efficientdet/anchors.py

        Args:
            anchors: [..., boxes, (y1, x1, y2, x2)]
            regression: [..., boxes, (dy, dx, dh, dw)]

        Returns:
            decoded boxes [..., boxes, (x1, y1, x2, y2)]; note the switch to x-first order.
        """
        y_centers_a = (anchors[..., 0] + anchors[..., 2]) / 2
        x_centers_a = (anchors[..., 1] + anchors[..., 3]) / 2
        ha = anchors[..., 2] - anchors[..., 0]
        wa = anchors[..., 3] - anchors[..., 1]

        w = regression[..., 3].exp() * wa
        h = regression[..., 2].exp() * ha

        y_centers = regression[..., 0] * ha + y_centers_a
        x_centers = regression[..., 1] * wa + x_centers_a

        ymin = y_centers - h / 2.
        xmin = x_centers - w / 2.
        ymax = y_centers + h / 2.
        xmax = x_centers + w / 2.

        # Stack along the last axis: identical to the former dim=2 (3-D) / dim=1 (2-D) choice.
        return torch.stack([xmin, ymin, xmax, ymax], dim=-1)


class ClipBoxes(nn.Module):
    """Clamp (x1, y1, x2, y2) boxes of shape (B, N, 4) to the spatial bounds of ``img`` in place."""

    def __init__(self):
        super(ClipBoxes, self).__init__()

    def forward(self, boxes, img):
        batch_size, num_channels, height, width = img.shape

        boxes[:, :, 0] = torch.clamp(boxes[:, :, 0], min=0)
        boxes[:, :, 1] = torch.clamp(boxes[:, :, 1], min=0)

        boxes[:, :, 2] = torch.clamp(boxes[:, :, 2], max=width - 1)
        boxes[:, :, 3] = torch.clamp(boxes[:, :, 3], max=height - 1)

        return boxes


def postprocess2(x, anchors, regression, classification, transformed_anchors,
                 threshold, iou_threshold, max_box):
    """Functional form of :class:`PostProcess`: per-class thresholding, NMS and top-``max_box``.

    ``anchors`` and ``regression`` are accepted for interface compatibility but not used.
    """
    return PostProcess(iou_threshold)(x, anchors, regression, classification,
                                      transformed_anchors, threshold, max_box)


def postprocess(x, anchors, regression, classification, regressBoxes, clipBoxes, threshold, iou_threshold):
    """Class-agnostic post-processing.

    Each anchor is scored by its best class. Anchors above ``threshold`` go through one NMS
    per image, and each kept box gets the argmax class among its scores.

    Returns:
        list with one dict per image holding ``'rois'``, ``'class_ids'`` and ``'scores'``
        (empty lists when nothing survives).
    """
    anchors = anchors['anchor']
    transformed_anchors = regressBoxes(anchors, regression)
    transformed_anchors = clipBoxes(transformed_anchors, x)
    scores = torch.max(classification, dim=2, keepdim=True)[0]
    scores_over_thresh = (scores > threshold)[:, :, 0]
    out = []
    for i in range(x.shape[0]):
        keep = scores_over_thresh[i, :]
        # Check this image only; the batch-wide sum used before let empty images fall through.
        if keep.sum() == 0:
            out.append({
                'rois': [],
                'class_ids': [],
                'scores': [],
            })
            continue

        classification_per = classification[i, keep, ...].permute(1, 0)
        transformed_anchors_per = transformed_anchors[i, keep, ...]
        scores_per = scores[i, keep, ...]
        anchors_nms_idx = nms(transformed_anchors_per, scores_per[:, 0], iou_threshold=iou_threshold)

        if anchors_nms_idx.shape[0] != 0:
            scores_, classes_ = classification_per[:, anchors_nms_idx].max(dim=0)
            boxes_ = transformed_anchors_per[anchors_nms_idx, :]

            out.append({
                'rois': boxes_,
                'class_ids': classes_,
                'scores': scores_,
            })
        else:
            out.append({
                'rois': [],
                'class_ids': [],
                'scores': [],
            })

    return out


def display(preds, imgs, obj_list, imshow=True, imwrite=False):
    """Draw the predicted boxes onto ``imgs`` in place and optionally show each image.

    ``obj_list`` and ``imwrite`` are currently unused.
    """
    for i in range(len(imgs)):
        rois = preds[i]['rois']
        if len(rois) == 0:
            continue

        for j in range(len(rois)):
            # builtin int instead of np.int, which was removed in NumPy 1.24; tolist() gives
            # plain Python ints for cv2.
            (x1, y1, x2, y2) = rois[j].detach().cpu().numpy().astype(int).tolist()
            logging.info((x1, y1, x2, y2))
            cv2.rectangle(imgs[i], (x1, y1), (x2, y2), (255, 255, 0), 2)

        if imshow:
            cv2.imshow('image', imgs[i])
            cv2.waitKey(0)


def calculate_focal_loss2(classification, target_list, alpha, gamma):
    """Sigmoid focal loss via the maskrcnn_benchmark CUDA kernel."""
    from maskrcnn_benchmark.layers.sigmoid_focal_loss import sigmoid_focal_loss_cuda
    cls_loss = sigmoid_focal_loss_cuda(classification, target_list.int(), gamma, alpha)
    return cls_loss


def calculate_focal_loss(classification, targets, alpha, gamma):
    """Sigmoid focal loss averaged over every entry.

    Entries whose target is -1 contribute 0 but are still counted in the mean.
    """
    classification = classification.sigmoid()
    device = classification.device
    alpha_factor = torch.ones_like(targets) * alpha
    alpha_factor = alpha_factor.to(device)

    alpha_factor = torch.where(torch.eq(targets, 1.), alpha_factor, 1. - alpha_factor)
    focal_weight = torch.where(torch.eq(targets, 1.), 1. - classification, classification)
    focal_weight = alpha_factor * torch.pow(focal_weight, gamma)

    bce = -(targets * torch.log(classification) + (1.0 - targets) * torch.log(1.0 - classification))

    cls_loss = focal_weight * bce

    zeros = torch.zeros_like(cls_loss)
    zeros = zeros.to(device)
    cls_loss = torch.where(torch.ne(targets, -1.0), cls_loss, zeros)

    return cls_loss.mean()


def calculate_giou(pred, gt):
    """Generalized IoU between matching rows of ``pred`` and ``gt``.

    Both are (N, >=4) with the first four columns in (x1, y1, x2, y2) order. Returns a
    tensor of shape (N,).
    """
    ax1, ay1, ax2, ay2 = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
    bx1, by1, bx2, by2 = gt[:, 0], gt[:, 1], gt[:, 2], gt[:, 3]
    a = (ax2 - ax1) * (ay2 - ay1)
    b = (bx2 - bx1) * (by2 - by1)

    # Intersection; a non-positive width or height means no overlap.
    inter_w = (torch.min(ax2, bx2) - torch.max(ax1, bx1)).clamp(min=0)
    inter_h = (torch.min(ay2, by2) - torch.max(ay1, by1)).clamp(min=0)
    inter = inter_w * inter_h

    # Smallest enclosing box.
    cover = (torch.max(ax2, bx2) - torch.min(ax1, bx1)) * (torch.max(ay2, by2) - torch.min(ay1, by1))

    union = a + b - inter
    iou = inter / (union + 1e-5)
    giou = iou - (cover - union) / (cover + 1e-5)
    return giou


class FocalLoss(nn.Module):
    """Detection loss: a classification loss plus a box-regression loss.

    Anchors are matched to ground truth by IoU. IoU < ``neg_iou_th`` makes an anchor negative,
    IoU >= ``pos_iou_th`` makes it positive, and anchors in between get target -1.

    Args:
        alpha, gamma: focal parameters used by the 'FL' and 'SmoothFL' classification losses.
        cls_loss_type: 'FL', 'BCE', 'SmoothBCE' or 'SmoothFL'.
        smooth_bce_pos, smooth_bce_neg: smoothing targets for 'SmoothBCE' / 'SmoothFL'.
        reg_loss_type: 'L1' (smooth-L1 on encoded offsets) or 'GIOU'.
        at_least_1_assgin: assign every ground-truth box to its best-IoU anchor even if that
            IoU is below ``pos_iou_th``.
        neg_iou_th, pos_iou_th: the IoU thresholds described above.
        cls_weight, reg_weight: multipliers applied to the returned losses.
    """

    def __init__(self, alpha=0.25, gamma=2., cls_loss_type='FL', smooth_bce_pos=0.99,
                 smooth_bce_neg=0.01,
                 reg_loss_type='L1',
                 at_least_1_assgin=False,
                 neg_iou_th=0.4,
                 pos_iou_th=0.5,
                 cls_weight=1.,
                 reg_weight=1.,
                 ):
        super(FocalLoss, self).__init__()
        from qd.qd_common import print_frame_info
        print_frame_info()
        self.iter = 0
        self.reg_loss_type = reg_loss_type
        self.regressBoxes = BBoxTransform()
        if cls_loss_type == 'FL':
            from qd.layers.loss import FocalLossWithLogitsNegLoss
            self.cls_loss = FocalLossWithLogitsNegLoss(alpha, gamma)
        elif cls_loss_type == 'BCE':
            from qd.qd_pytorch import BCEWithLogitsNegLoss
            self.cls_loss = BCEWithLogitsNegLoss(reduction='sum')
        elif cls_loss_type == 'SmoothBCE':
            from qd.layers.loss import SmoothBCEWithLogitsNegLoss
            self.cls_loss = SmoothBCEWithLogitsNegLoss(
                pos=smooth_bce_pos, neg=smooth_bce_neg)
        elif cls_loss_type == 'SmoothFL':
            from qd.layers.loss import FocalSmoothBCEWithLogitsNegLoss
            # Forward the caller's gamma (it used to be hard-coded to 2.).
            self.cls_loss = FocalSmoothBCEWithLogitsNegLoss(
                alpha=alpha, gamma=gamma,
                pos=smooth_bce_pos, neg=smooth_bce_neg)
        else:
            raise NotImplementedError(cls_loss_type)
        self.at_least_1_assgin = at_least_1_assgin

        # Statistics for the at_least_1_assgin debug log.
        self.gt_total = 0
        self.gt_saved_by_at_least = 0

        self.neg_iou_th = neg_iou_th
        self.pos_iou_th = pos_iou_th

        self.cls_weight = cls_weight
        self.reg_weight = reg_weight

        # Cumulative debug counters.
        self.buf = {}

    def forward(self, classifications, regressions, anchor_info, annotations, **kwargs):
        """Compute batch-averaged classification and regression losses.

        Args:
            classifications: classification outputs indexed ``[image, anchor, class]``; passed
                unchanged to ``self.cls_loss``.
            regressions: box offsets indexed ``[image, anchor, (dy, dx, dh, dw)]``.
            anchor_info: dict with ``'anchor'`` (1, N, 4) in (y1, x1, y2, x2) order and
                ``'stride_idx'`` (N,).
            annotations: per image, rows of (x1, y1, x2, y2, class); rows with class -1 are
                padding and are dropped.

        Returns:
            ``(cls_loss, reg_loss)``, each of shape (1,), already scaled by ``cls_weight`` and
            ``reg_weight``.
        """
        debug = (self.iter % 100) == 0
        self.iter += 1
        if debug:
            from collections import defaultdict
            debug_info = defaultdict(list)

        batch_size = classifications.shape[0]
        classification_losses = []
        regression_losses = []

        anchors = anchor_info['anchor']
        anchor = anchors[0, :, :]  # assuming all image sizes are the same, which it is
        dtype = anchors.dtype
        device = classifications.device

        # anchors are (y1, x1, y2, x2)
        anchor_widths = anchor[:, 3] - anchor[:, 1]
        anchor_heights = anchor[:, 2] - anchor[:, 0]
        anchor_ctr_x = anchor[:, 1] + 0.5 * anchor_widths
        anchor_ctr_y = anchor[:, 0] + 0.5 * anchor_heights

        def zero_loss():
            # Always on the classification device so torch.stack below never mixes devices.
            return torch.tensor(0).to(dtype).to(device)

        for j in range(batch_size):
            classification = classifications[j, :, :]
            regression = regressions[j, :, :]

            bbox_annotation = annotations[j]
            bbox_annotation = bbox_annotation[bbox_annotation[:, 4] != -1]

            if bbox_annotation.shape[0] == 0:
                # No ground truth: this image contributes zero to both losses.
                classification_losses.append(zero_loss())
                regression_losses.append(zero_loss())
                continue

            IoU = calc_iou(anchor[:, :], bbox_annotation[:, :4])
            IoU_max, IoU_argmax = torch.max(IoU, dim=1)

            if self.at_least_1_assgin:
                iou_max_gt, iou_argmax_gt = torch.max(IoU, dim=0)
                curr_saved = (iou_max_gt < self.pos_iou_th).sum()
                self.gt_saved_by_at_least += curr_saved
                self.gt_total += len(iou_argmax_gt)
                IoU_max[iou_argmax_gt] = 1.
                IoU_argmax[iou_argmax_gt] = torch.arange(len(iou_argmax_gt)).to(device)

            # compute the loss for classification; -1 marks ignored anchors
            targets = torch.full_like(classification, -1)
            targets[torch.lt(IoU_max, self.neg_iou_th), :] = 0

            positive_indices = torch.ge(IoU_max, self.pos_iou_th)
            num_positive_anchors = positive_indices.sum()
            assigned_annotations = bbox_annotation[IoU_argmax, :]

            targets[positive_indices, :] = 0
            targets[positive_indices, assigned_annotations[positive_indices, 4].long()] = 1

            if debug:
                if num_positive_anchors > 0:
                    debug_info['pos_conf'].append(classification[
                        positive_indices,
                        assigned_annotations[positive_indices, 4].long()].mean().item())
                debug_info['neg_conf'].append(classification[targets == 0].mean().item())
                stride_idx = anchor_info['stride_idx']
                positive_stride_idx = stride_idx[positive_indices]
                pos_count_each_stride = torch.tensor(
                    [int((positive_stride_idx == i).sum()) for i in range(5)])
                if 'cum_pos_count_each_stride' not in self.buf:
                    self.buf['cum_pos_count_each_stride'] = pos_count_each_stride
                else:
                    self.buf['cum_pos_count_each_stride'] += pos_count_each_stride

            cls_loss = self.cls_loss(classification, targets)
            cls_loss = cls_loss.sum() / torch.clamp(num_positive_anchors.to(dtype), min=1.0)
            assert cls_loss == cls_loss  # NaN check
            classification_losses.append(cls_loss)

            if num_positive_anchors > 0:
                assigned_annotations = assigned_annotations[positive_indices, :]
                if self.reg_loss_type == 'L1':
                    anchor_widths_pi = anchor_widths[positive_indices]
                    anchor_heights_pi = anchor_heights[positive_indices]
                    anchor_ctr_x_pi = anchor_ctr_x[positive_indices]
                    anchor_ctr_y_pi = anchor_ctr_y[positive_indices]

                    # ground truth is (x1, y1, x2, y2)
                    gt_widths = assigned_annotations[:, 2] - assigned_annotations[:, 0]
                    gt_heights = assigned_annotations[:, 3] - assigned_annotations[:, 1]
                    gt_ctr_x = assigned_annotations[:, 0] + 0.5 * gt_widths
                    gt_ctr_y = assigned_annotations[:, 1] + 0.5 * gt_heights

                    # efficientdet style: sizes are clamped after the centres are computed
                    gt_widths = torch.clamp(gt_widths, min=1)
                    gt_heights = torch.clamp(gt_heights, min=1)

                    targets_dx = (gt_ctr_x - anchor_ctr_x_pi) / anchor_widths_pi
                    targets_dy = (gt_ctr_y - anchor_ctr_y_pi) / anchor_heights_pi
                    targets_dw = torch.log(gt_widths / anchor_widths_pi)
                    targets_dh = torch.log(gt_heights / anchor_heights_pi)

                    reg_targets = torch.stack((targets_dy, targets_dx, targets_dh, targets_dw)).t()

                    regression_diff = torch.abs(reg_targets - regression[positive_indices, :])

                    # smooth-L1 with beta = 1/9
                    regression_loss = torch.where(
                        torch.le(regression_diff, 1.0 / 9.0),
                        0.5 * 9.0 * torch.pow(regression_diff, 2),
                        regression_diff - 0.5 / 9.0
                    ).mean()
                elif self.reg_loss_type == 'GIOU':
                    curr_regression = regression[positive_indices, :]
                    curr_anchors = anchor[positive_indices]
                    curr_pred_xyxy = self.regressBoxes(curr_anchors, curr_regression)
                    regression_loss = 1. - calculate_giou(curr_pred_xyxy, assigned_annotations)
                    regression_loss = regression_loss.mean()
                    assert regression_loss == regression_loss  # NaN check
                else:
                    raise NotImplementedError
                regression_losses.append(regression_loss)
            else:
                regression_losses.append(zero_loss())

        if debug and len(debug_info) > 0:
            # gt_total stays 0 unless at_least_1_assgin is set; avoid 0/0 in the ratio.
            saved_ratio = 1. * self.gt_saved_by_at_least / max(self.gt_total, 1)
            logging.info('pos = {}; neg = {}, saved_ratio = {}/{}={:.1f}, '
                         'stride_info = {}'
                         .format(
                             torch.tensor(debug_info['pos_conf']).mean(),
                             torch.tensor(debug_info['neg_conf']).mean(),
                             self.gt_saved_by_at_least,
                             self.gt_total,
                             saved_ratio,
                             self.buf['cum_pos_count_each_stride'],
                         ))

        return self.cls_weight * torch.stack(classification_losses).mean(dim=0, keepdim=True), \
            self.reg_weight * torch.stack(regression_losses).mean(dim=0, keepdim=True)


class ModelWithLoss(nn.Module):
    """Wrap a detector and a criterion so one forward call returns the loss dict."""

    def __init__(self, model, criterion):
        super().__init__()
        self.criterion = criterion
        self.module = model

    def forward(self, *args):
        """Accept ``(imgs, annotations)`` or a single sequence whose first two items are those.

        Returns:
            dict with ``'cls_loss'`` and ``'reg_loss'``.

        Raises:
            TypeError: if called with a number of positional arguments other than 1 or 2.
        """
        if len(args) == 2:
            imgs, annotations = args
        elif len(args) == 1:
            imgs, annotations = args[0][:2]
        else:
            raise TypeError('ModelWithLoss.forward expects 1 or 2 positional arguments, '
                            'got {}'.format(len(args)))
        _, regression, classification, anchors = self.module(imgs)
        cls_loss, reg_loss = self.criterion(classification, regression, anchors, annotations)
        return {'cls_loss': cls_loss, 'reg_loss': reg_loss}


class TorchVisionNMS(nn.Module):
    """Module wrapper around ``torchvision.ops.nms`` with a fixed IoU threshold."""

    def __init__(self, iou_threshold):
        super().__init__()
        self.iou_threshold = iou_threshold

    def forward(self, box, prob):
        nms_idx = nms(box, prob, iou_threshold=self.iou_threshold)
        return nms_idx


class PostProcess(nn.Module):
    """Per-class thresholding + NMS, then keep the top ``max_box`` detections per image."""

    def __init__(self, iou_threshold):
        super().__init__()
        self.nms = TorchVisionNMS(iou_threshold)

    def forward(self, x, anchors, regression, classification, transformed_anchors, threshold, max_box):
        """
        Args:
            x: input batch; only ``x.shape[0]`` (number of images) is used.
            anchors, regression: unused; kept for interface compatibility.
            classification: scores indexed ``[image, anchor, class]``, compared to ``threshold``.
            transformed_anchors: decoded boxes indexed ``[image, anchor, :]``.
            threshold: minimum score for a box to enter NMS.
            max_box: maximum number of detections kept per image (highest scores first).

        Returns:
            list with one dict per image holding ``'rois'``, ``'class_ids'`` (int64, on CPU)
            and ``'scores'``; every value is an empty list when the image has no detections.
        """
        all_above_th = classification > threshold
        num_image = x.shape[0]
        num_class = classification.shape[-1]
        # Cap on candidates per class that enter NMS.
        max_box_pre_nms = 1000

        out = []
        for i in range(num_image):
            all_rois = []
            all_class_ids = []
            all_scores = []
            for c in range(num_class):
                above_th = all_above_th[i, :, c].nonzero()
                if len(above_th) == 0:
                    continue
                above_prob = classification[i, above_th, c].squeeze(1)
                if len(above_th) > max_box_pre_nms:
                    _, idx = above_prob.topk(max_box_pre_nms)
                    above_th = above_th[idx]
                    above_prob = above_prob[idx]
                transformed_anchors_per = transformed_anchors[i, above_th, :].squeeze(dim=1)
                nms_idx = self.nms(transformed_anchors_per, above_prob)
                if len(nms_idx) > 0:
                    all_rois.append(transformed_anchors_per[nms_idx])
                    all_class_ids.append(torch.full((len(nms_idx),), c, dtype=torch.long))
                    all_scores.append(above_prob[nms_idx])

            if len(all_rois) > 0:
                rois = torch.cat(all_rois)
                class_ids = torch.cat(all_class_ids)
                scores = torch.cat(all_scores)
                if len(scores) > max_box:
                    _, idx = torch.topk(scores, max_box)
                    rois = rois[idx, :]
                    class_ids = class_ids[idx]
                    scores = scores[idx]
                out.append({
                    'rois': rois,
                    'class_ids': class_ids,
                    'scores': scores,
                })
            else:
                out.append({
                    'rois': [],
                    'class_ids': [],
                    'scores': [],
                })

        return out


class InferenceModel(nn.Module):
    """Inference wrapper: run the detector, decode and clip boxes, then per-class NMS."""

    def __init__(self, model):
        super().__init__()
        self.module = model

        self.regressBoxes = BBoxTransform()
        self.clipBoxes = ClipBoxes()
        self.threshold = 0.01
        self.nms_threshold = 0.5
        self.max_box = 100
        self.debug = False
        self.post_process = PostProcess(self.nms_threshold)

    def forward(self, sample):
        """
        Args:
            sample: dict with ``'image'`` (the batched network input) and ``'scale'`` (one
                factor per image).

        Returns:
            list of per-image prediction dicts; non-empty ``'rois'`` are divided by that
            image's scale.
        """
        features, regression, classification, anchor_info = self.module(sample['image'])
        anchors = anchor_info['anchor']
        classification = classification.sigmoid()
        transformed_anchors = self.regressBoxes(anchors, regression)
        transformed_anchors = self.clipBoxes(transformed_anchors, sample['image'])

        preds = self.post_process(sample['image'], anchors, regression,
                                  classification, transformed_anchors, self.threshold,
                                  self.max_box)

        if self.debug:
            logging.info('debugging')
            imgs = sample['image']
            imgs = imgs.permute(0, 2, 3, 1).cpu().numpy()
            # presumably undoes ImageNet mean/std normalisation of the input
            imgs = ((imgs * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255).astype(np.uint8)
            imgs = [cv2.cvtColor(img, cv2.COLOR_RGB2BGR) for img in imgs]
            display(preds, imgs, list(map(str, range(80))))

        for p, s in zip(preds, sample['scale']):
            if len(p['rois']) > 0:
                p['rois'] /= s
        return preds