WSCL / models /models.py
yhzhai's picture
release code
482ab8a
raw
history blame
23.2 kB
from typing import List
import torch
import torch.nn as nn
from . import hrnet, mobilenet, resnet, resnext
from .lib.nn import SynchronizedBatchNorm2d
# Module-wide alias: every block below that instantiates `BatchNorm2d` gets the
# project's SynchronizedBatchNorm2d (imported from .lib.nn above).
BatchNorm2d = SynchronizedBatchNorm2d
class SegmentationModuleBase(nn.Module):
    """Base class for segmentation wrappers; supplies the pixel-accuracy metric."""

    def __init__(self):
        super().__init__()

    def pixel_acc(self, pred, label):
        """Return the share of valid pixels whose top-scoring class equals the label.

        The predicted class is the arg-max of ``pred`` along dimension 1.
        A pixel counts as valid when its label is >= 0; negative labels are
        excluded from both the numerator and the denominator.

        Returns:
            A scalar float tensor. A small epsilon in the denominator keeps the
            result finite (0) when no pixel is valid.
        """
        predicted = torch.max(pred, dim=1)[1]
        is_valid = label >= 0
        # Count hits only where the label is valid.
        n_correct = ((predicted == label) & is_valid).long().sum()
        n_valid = is_valid.long().sum()
        return n_correct.float() / (n_valid.float() + 1e-10)
class SegmentationModule(SegmentationModuleBase):
    """Bundle an encoder, a decoder and a loss criterion into a single module.

    Args:
        net_enc: encoder; called as ``net_enc(img, return_feature_maps=True)``.
        net_dec: decoder; called on the encoder output.
        crit: loss callable taking ``(prediction, target)``.
        deep_sup_scale: weight of the auxiliary (deep-supervision) loss. When
            it is not ``None`` the decoder must return a
            ``(pred, pred_deepsup)`` pair during training.
    """

    def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None):
        super().__init__()
        self.encoder = net_enc
        self.decoder = net_dec
        self.crit = crit
        self.deep_sup_scale = deep_sup_scale

    def forward(self, feed_dict, *, segSize=None):
        """Run a training step (``segSize is None``) or inference.

        Args:
            feed_dict: mapping with key ``"img_data"``; training additionally
                reads ``"seg_label"``.
            segSize: keyword-only; when given it is forwarded to the decoder
                and the decoder's prediction is returned as is.

        Returns:
            ``(loss, acc)`` in training mode, otherwise the decoder output.
        """
        features = self.encoder(feed_dict["img_data"], return_feature_maps=True)

        # Inference path: no loss, just the decoder prediction.
        if segSize is not None:
            return self.decoder(features, segSize=segSize)

        # Training path.
        with_deepsup = self.deep_sup_scale is not None
        if with_deepsup:
            pred, pred_deepsup = self.decoder(features)
        else:
            pred = self.decoder(features)

        target = feed_dict["seg_label"]
        loss = self.crit(pred, target)
        if with_deepsup:
            aux_loss = self.crit(pred_deepsup, target)
            loss = loss + aux_loss * self.deep_sup_scale

        return loss, self.pixel_acc(pred, target)
class ModelBuilder:
    """Static factory helpers that build the encoder and decoder networks."""

    @staticmethod
    def weights_init(m):
        """Initialise one module in place; intended for ``nn.Module.apply``.

        Modules whose class name contains ``"Conv"`` receive Kaiming-normal
        weights. Modules whose class name contains ``"BatchNorm"`` get
        weight 1.0 and bias 1e-4. Every other module is left untouched.
        """
        classname = m.__class__.__name__
        if "Conv" in classname:
            nn.init.kaiming_normal_(m.weight.data)
        elif "BatchNorm" in classname:
            m.weight.data.fill_(1.0)
            m.bias.data.fill_(1e-4)

    @staticmethod
    def _load_weights(net, weights, tag):
        """Load the checkpoint at path ``weights`` into ``net`` (non-strict)."""
        print("Loading weights for {}".format(tag))
        # SECURITY: torch.load unpickles the file, so only point `weights` at
        # checkpoints from a trusted source.
        # The identity map_location keeps tensors on CPU storage while loading.
        state = torch.load(weights, map_location=lambda storage, loc: storage)
        # strict=False: keys missing from / unexpected in the checkpoint are
        # ignored rather than raising.
        net.load_state_dict(state, strict=False)

    @staticmethod
    def build_encoder(arch="resnet50dilated", fc_dim=512, weights=""):
        """Build an encoder network.

        Args:
            arch: architecture name, matched case-insensitively.
            fc_dim: not used by this method; kept so existing callers that
                pass it keep working.
            weights: checkpoint path. When empty, the backbone constructor is
                called with ``pretrained=True``; otherwise with
                ``pretrained=False`` and the checkpoint is loaded afterwards.

        Returns:
            The encoder module.

        Raises:
            NotImplementedError: for ``resnet34`` / ``resnet34dilated``.
            ValueError: for any other unknown architecture name.
        """
        pretrained = len(weights) == 0
        arch = arch.lower()

        if arch in ("resnet34", "resnet34dilated"):
            raise NotImplementedError(
                "Encoder architecture '{}' is not implemented".format(arch)
            )
        elif arch == "mobilenetv2dilated":
            orig_mobilenet = mobilenet.__dict__["mobilenetv2"](pretrained=pretrained)
            net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8)
        elif arch in ("resnet18", "resnet50", "resnet101"):
            orig_resnet = resnet.__dict__[arch](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        elif arch in ("resnet18dilated", "resnet50dilated", "resnet101dilated"):
            # Strip the "dilated" suffix to get the underlying backbone name.
            backbone = arch[: -len("dilated")]
            orig_resnet = resnet.__dict__[backbone](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == "resnext101":
            orig_resnext = resnext.__dict__["resnext101"](pretrained=pretrained)
            net_encoder = Resnet(orig_resnext)  # we can still use class Resnet
        elif arch == "hrnetv2":
            net_encoder = hrnet.__dict__["hrnetv2"](pretrained=pretrained)
        else:
            raise ValueError("Architecture undefined!")

        # Encoders are usually pretrained, so weights_init is deliberately
        # not applied here.
        if len(weights) > 0:
            ModelBuilder._load_weights(net_encoder, weights, "net_encoder")
        return net_encoder

    @staticmethod
    def build_decoder(
        arch="ppm_deepsup",
        fc_dim=512,
        num_class=150,
        weights="",
        use_softmax=False,
        dropout=0.0,
        fcn_up: int = 32,
    ):
        """Build a decoder network, initialise it and optionally load weights.

        Args:
            arch: decoder name, matched case-insensitively.
            fc_dim: passed to the decoder constructor as ``fc_dim``.
            num_class: number of output classes.
            weights: checkpoint path; empty means no checkpoint is loaded.
            use_softmax: forwarded to the decoder constructor.
            dropout: forwarded to the ``c1`` decoder only.
            fcn_up: forwarded to the ``c1`` decoder only.

        Returns:
            The decoder module, after ``weights_init`` has been applied.

        Raises:
            ValueError: for an unknown architecture name.
        """
        arch = arch.lower()
        common = dict(num_class=num_class, fc_dim=fc_dim, use_softmax=use_softmax)

        if arch == "c1_deepsup":
            net_decoder = C1DeepSup(**common)
        elif arch == "c1":
            # C1 is the only decoder that consumes `dropout` and `fcn_up`.
            net_decoder = C1(dropout=dropout, fcn_up=fcn_up, **common)
        elif arch == "ppm":
            net_decoder = PPM(**common)
        elif arch == "ppm_deepsup":
            net_decoder = PPMDeepsup(**common)
        elif arch == "upernet_lite":
            net_decoder = UPerNet(fpn_dim=256, **common)
        elif arch == "upernet":
            net_decoder = UPerNet(fpn_dim=512, **common)
        else:
            raise ValueError("Architecture undefined!")

        net_decoder.apply(ModelBuilder.weights_init)
        if len(weights) > 0:
            ModelBuilder._load_weights(net_decoder, weights, "net_decoder")
        return net_decoder
def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution followed by BatchNorm2d and in-place ReLU.

    Args:
        in_planes: input channel count of the convolution.
        out_planes: output channel count (also the batch-norm width).
        stride: convolution stride; padding is fixed at 1.
    """
    conv = nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
    )
    norm = BatchNorm2d(out_planes)
    activation = nn.ReLU(inplace=True)
    return nn.Sequential(conv, norm, activation)
class Resnet(nn.Module):
    """Encoder that reuses the stem and the four stages of an existing ResNet.

    The sub-modules are taken by reference from ``orig_resnet``; nothing
    other than the attributes listed in ``_PARTS`` is copied over.
    """

    # Attributes adopted from the wrapped network, in registration order.
    _PARTS = (
        "conv1", "bn1", "relu1",
        "conv2", "bn2", "relu2",
        "conv3", "bn3", "relu3",
        "maxpool",
        "layer1", "layer2", "layer3", "layer4",
    )

    def __init__(self, orig_resnet):
        super().__init__()
        # take pretrained resnet, except AvgPool and FC
        for part in self._PARTS:
            setattr(self, part, getattr(orig_resnet, part))

    def forward(self, x, return_feature_maps=False):
        """Run the stem, then the four stages in order.

        Returns:
            A list with the output of layer1..layer4 when
            ``return_feature_maps`` is true, otherwise a one-element list
            holding only the layer4 output.
        """
        # Three conv-bn-relu triples followed by max-pooling.
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        # NOTE(review): earlier inline notes gave per-stage shapes; they were
        # identical for every stage and cannot be verified from this file.
        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            stage_outputs.append(x)

        return stage_outputs if return_feature_maps else [x]
class ResnetDilated(nn.Module):
    """ResNet encoder whose late stages trade stride for dilation.

    With ``dilate_scale == 8`` layer3 and layer4 of ``orig_resnet`` are
    rewritten in place (dilation 2 and 4); with ``dilate_scale == 16`` only
    layer4 is rewritten (dilation 2). Any other value leaves the stages
    unmodified.
    """

    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()
        from functools import partial

        if dilate_scale == 8:
            plan = (("layer3", 2), ("layer4", 4))
        elif dilate_scale == 16:
            plan = (("layer4", 2),)
        else:
            plan = ()
        for stage_name, rate in plan:
            stage = getattr(orig_resnet, stage_name)
            stage.apply(partial(self._nostride_dilate, dilate=rate))

        # take pretrained resnet, except AvgPool and FC
        for part in (
            "conv1", "bn1", "relu1",
            "conv2", "bn2", "relu2",
            "conv3", "bn3", "relu3",
            "maxpool",
            "layer1", "layer2", "layer3", "layer4",
        ):
            setattr(self, part, getattr(orig_resnet, part))

    def _nostride_dilate(self, m, dilate):
        """Rewrite one module in place; modules without "Conv" in their class name are skipped.

        A stride-(2, 2) convolution gets stride (1, 1) and, if it is 3x3,
        dilation and padding ``dilate // 2``. Any other 3x3 convolution gets
        dilation and padding ``dilate``.
        """
        if "Conv" not in type(m).__name__:
            return
        if m.stride == (2, 2):
            # the convolution with stride
            m.stride = (1, 1)
            rate = dilate // 2
        else:
            # other convolutions
            rate = dilate
        if m.kernel_size == (3, 3):
            m.dilation = (rate, rate)
            m.padding = (rate, rate)

    def forward(self, x, return_feature_maps=False):
        """Run the stem and the four stages.

        Returns:
            The list of layer1..layer4 outputs when ``return_feature_maps``
            is true, otherwise a one-element list with the layer4 output.
        """
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
            stage_outputs.append(x)

        return stage_outputs if return_feature_maps else [x]
class MobileNetV2Dilated(nn.Module):
    """Encoder built from ``orig_net.features`` minus its last entry, with dilation.

    ``down_idx`` lists the feature indices whose outputs are collected by
    ``forward``. With ``dilate_scale == 8`` the blocks from ``down_idx[-2]``
    up to ``down_idx[-1]`` are rewritten with dilation 2 and the remaining
    blocks with dilation 4; with ``dilate_scale == 16`` only the blocks from
    ``down_idx[-1]`` onward are rewritten (dilation 2).
    """

    def __init__(self, orig_net, dilate_scale=8):
        super().__init__()
        from functools import partial

        # take pretrained mobilenet features
        self.features = orig_net.features[:-1]
        self.total_idx = len(self.features)
        self.down_idx = [2, 4, 7, 14]

        penultimate, final = self.down_idx[-2], self.down_idx[-1]
        if dilate_scale == 8:
            schedule = (
                (range(penultimate, final), 2),
                (range(final, self.total_idx), 4),
            )
        elif dilate_scale == 16:
            schedule = ((range(final, self.total_idx), 2),)
        else:
            schedule = ()

        for indices, rate in schedule:
            for i in indices:
                self.features[i].apply(partial(self._nostride_dilate, dilate=rate))

    def _nostride_dilate(self, m, dilate):
        """Rewrite one module in place; modules without "Conv" in their class name are skipped.

        A stride-(2, 2) convolution gets stride (1, 1) and, if it is 3x3,
        dilation and padding ``dilate // 2``. Any other 3x3 convolution gets
        dilation and padding ``dilate``.
        """
        if "Conv" not in type(m).__name__:
            return
        if m.stride == (2, 2):
            # the convolution with stride
            m.stride = (1, 1)
            rate = dilate // 2
        else:
            # other convolutions
            rate = dilate
        if m.kernel_size == (3, 3):
            m.dilation = (rate, rate)
            m.padding = (rate, rate)

    def forward(self, x, return_feature_maps=False):
        """Run the feature stack.

        Returns:
            Without ``return_feature_maps``: a one-element list with the final
            output. With it: the outputs at every index in ``down_idx``
            followed by the final output once more.
        """
        if not return_feature_maps:
            return [self.features(x)]

        collected = []
        for idx, layer in zip(range(self.total_idx), self.features):
            x = layer(x)
            if idx in self.down_idx:
                collected.append(x)
        collected.append(x)
        return collected
# last conv, deep supervision
class C1DeepSup(nn.Module):
    """Single conv head on the last feature map, plus an auxiliary head on the one before it.

    Args:
        num_class: number of output channels of both heads.
        fc_dim: channel count expected on the last feature map; the auxiliary
            head expects ``fc_dim // 2`` channels.
        use_softmax: when true, ``forward`` takes the inference path.
    """

    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
        super().__init__()
        self.use_softmax = use_softmax

        hidden = fc_dim // 4
        self.cbr = conv3x3_bn_relu(fc_dim, hidden, 1)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, hidden, 1)

        # last conv
        self.conv_last = nn.Conv2d(hidden, num_class, 1, 1, 0)
        self.conv_last_deepsup = nn.Conv2d(hidden, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        """Compute the prediction(s) from the list of encoder feature maps.

        Returns:
            With ``use_softmax``: softmax scores bilinearly resized to
            ``segSize``. Otherwise a ``(main, aux)`` pair of log-softmax maps.
        """
        main = self.conv_last(self.cbr(conv_out[-1]))

        if self.use_softmax:  # is True during inference
            main = nn.functional.interpolate(
                main, size=segSize, mode="bilinear", align_corners=False
            )
            return nn.functional.softmax(main, dim=1)

        # deep supervision branch on the second-to-last feature map
        aux = self.conv_last_deepsup(self.cbr_deepsup(conv_out[-2]))

        main = nn.functional.log_softmax(main, dim=1)
        aux = nn.functional.log_softmax(aux, dim=1)
        return (main, aux)
# last conv
class C1(nn.Module):
    """Conv head over the last one, two or three encoder feature maps.

    ``fcn_up`` selects how many maps are fused: 32 uses only the last one,
    16 concatenates the last two, and any other value (8 is the intended
    one) concatenates the last three. ``use_softmax`` is stored but not read
    by ``forward``.
    """

    def __init__(
        self,
        num_class=150,
        fc_dim: int = 2048,
        use_softmax=False,
        dropout=0.0,
        fcn_up: int = 32,
    ):
        super().__init__()
        self.use_softmax = use_softmax
        self.fcn_up = fcn_up

        # Input width of the fusing conv depends on how many maps are concatenated.
        if fcn_up == 32:
            in_dim = fc_dim
        else:
            in_dim = fc_dim / 2 * 3
            if fcn_up != 16:  # 8
                in_dim = in_dim + fc_dim / 4
            in_dim = int(in_dim)

        self.cbr = conv3x3_bn_relu(in_dim, fc_dim // 4, 1)

        # last conv
        self.dropout = nn.Dropout2d(dropout)
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out: List, segSize=None):
        """Fuse the selected feature maps and return the ``conv_last`` output.

        Maps other than the shallowest selected one are bilinearly resized
        to its spatial size before concatenation along dim 1. ``segSize`` is
        accepted but not used.
        """
        if self.fcn_up == 32:
            fused = conv_out[-1]
        else:
            depth = 2 if self.fcn_up == 16 else 3
            # Shallowest map first, deepest (conv_out[-1]) last.
            levels = [conv_out[-k] for k in range(depth, 0, -1)]
            tgt_shape = levels[0].shape[-2:]
            aligned = [levels[0]]
            for feat in levels[1:]:
                aligned.append(
                    nn.functional.interpolate(
                        feat, size=tgt_shape, mode="bilinear", align_corners=False
                    )
                )
            fused = torch.cat(aligned, dim=1)

        out = self.cbr(fused)
        out = self.dropout(out)
        return self.conv_last(out)
# pyramid pooling
class PPM(nn.Module):
    """Pyramid-pooling head applied to the last encoder feature map.

    One branch is built per entry of ``pool_scales`` (adaptive average pool,
    1x1 conv to 512 channels, batch norm, ReLU). ``use_softmax`` is stored
    but not read by ``forward``.
    """

    def __init__(
        self, num_class=150, fc_dim=4096, use_softmax=False, pool_scales=(1, 2, 3, 6)
    ):
        super().__init__()
        self.use_softmax = use_softmax

        self.ppm = nn.ModuleList(
            [
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(scale),
                    nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                    BatchNorm2d(512),
                    nn.ReLU(inplace=True),
                )
                for scale in pool_scales
            ]
        )

        # The classifier sees the input map plus one 512-channel map per branch.
        fused_channels = fc_dim + len(pool_scales) * 512
        self.conv_last = nn.Sequential(
            nn.Conv2d(fused_channels, 512, kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1),
        )

    def forward(self, conv_out, segSize=None):
        """Return the ``conv_last`` output, resized to ``segSize`` when it is given."""
        conv5 = conv_out[-1]
        size = conv5.size()
        target_hw = (size[2], size[3])

        # Each pooled branch is resized back to the input resolution.
        branches = [conv5]
        branches.extend(
            nn.functional.interpolate(
                branch(conv5), target_hw, mode="bilinear", align_corners=False
            )
            for branch in self.ppm
        )

        out = self.conv_last(torch.cat(branches, 1))

        if segSize is None:
            return out
        # for inference
        return nn.functional.interpolate(
            out, size=segSize, mode="bilinear", align_corners=False
        )
# pyramid pooling, deep supervision
class PPMDeepsup(nn.Module):
    """Pyramid-pooling head with an auxiliary head on the second-to-last feature map.

    The main head works on the last feature map (``fc_dim`` channels); the
    auxiliary head expects ``fc_dim // 2`` channels. ``use_softmax`` selects
    the inference path of ``forward``.
    """

    def __init__(
        self, num_class=150, fc_dim=4096, use_softmax=False, pool_scales=(1, 2, 3, 6)
    ):
        super().__init__()
        self.use_softmax = use_softmax

        self.ppm = nn.ModuleList(
            [
                nn.Sequential(
                    nn.AdaptiveAvgPool2d(scale),
                    nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                    BatchNorm2d(512),
                    nn.ReLU(inplace=True),
                )
                for scale in pool_scales
            ]
        )
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        # The classifier sees the input map plus one 512-channel map per branch.
        fused_channels = fc_dim + len(pool_scales) * 512
        self.conv_last = nn.Sequential(
            nn.Conv2d(fused_channels, 512, kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1),
        )
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.dropout_deepsup = nn.Dropout2d(0.1)

    def forward(self, conv_out, segSize=None):
        """Compute the prediction(s) from the list of encoder feature maps.

        Returns:
            With ``use_softmax``: softmax scores bilinearly resized to
            ``segSize``. Otherwise a ``(main, aux)`` pair of log-softmax maps.
        """
        conv5 = conv_out[-1]
        size = conv5.size()
        target_hw = (size[2], size[3])

        # Each pooled branch is resized back to the input resolution.
        branches = [conv5]
        branches.extend(
            nn.functional.interpolate(
                branch(conv5), target_hw, mode="bilinear", align_corners=False
            )
            for branch in self.ppm
        )
        main = self.conv_last(torch.cat(branches, 1))

        if self.use_softmax:  # is True during inference
            main = nn.functional.interpolate(
                main, size=segSize, mode="bilinear", align_corners=False
            )
            return nn.functional.softmax(main, dim=1)

        # deep supervision branch
        aux = self.cbr_deepsup(conv_out[-2])
        aux = self.dropout_deepsup(aux)
        aux = self.conv_last_deepsup(aux)

        main = nn.functional.log_softmax(main, dim=1)
        aux = nn.functional.log_softmax(aux, dim=1)
        return (main, aux)
# upernet
class UPerNet(nn.Module):
    """Decoder combining a pyramid-pooling module with a top-down feature pyramid.

    Args:
        num_class: number of output channels.
        fc_dim: channel count expected on the last encoder feature map.
        use_softmax: when true, ``forward`` takes the inference path.
        pool_scales: output sizes of the adaptive average pools.
        fpn_inplanes: channel counts of the encoder feature maps; a lateral
            1x1 conv is built for every entry except the last.
        fpn_dim: channel width used inside the pyramid.
    """

    def __init__(
        self,
        num_class=150,
        fc_dim=4096,
        use_softmax=False,
        pool_scales=(1, 2, 3, 6),
        fpn_inplanes=(256, 512, 1024, 2048),
        fpn_dim=256,
    ):
        super().__init__()
        self.use_softmax = use_softmax

        # PPM Module
        self.ppm_pooling = nn.ModuleList(
            [nn.AdaptiveAvgPool2d(scale) for scale in pool_scales]
        )
        self.ppm_conv = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                    BatchNorm2d(512),
                    nn.ReLU(inplace=True),
                )
                for _ in pool_scales
            ]
        )
        self.ppm_last_conv = conv3x3_bn_relu(
            fc_dim + len(pool_scales) * 512, fpn_dim, 1
        )

        # FPN Module: lateral 1x1 convs and 3x3 output convs, skipping the top layer
        self.fpn_in = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(planes, fpn_dim, kernel_size=1, bias=False),
                    BatchNorm2d(fpn_dim),
                    nn.ReLU(inplace=True),
                )
                for planes in fpn_inplanes[:-1]
            ]
        )
        self.fpn_out = nn.ModuleList(
            [
                nn.Sequential(conv3x3_bn_relu(fpn_dim, fpn_dim, 1))
                for _ in range(len(fpn_inplanes) - 1)
            ]
        )

        self.conv_last = nn.Sequential(
            conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1),
            nn.Conv2d(fpn_dim, num_class, kernel_size=1),
        )

    def forward(self, conv_out, segSize=None):
        """Compute the prediction from the list of encoder feature maps.

        Returns:
            With ``use_softmax``: softmax scores bilinearly resized to
            ``segSize``. Otherwise a log-softmax map at the resolution of the
            first (index 0) pyramid level.
        """
        conv5 = conv_out[-1]
        size = conv5.size()
        target_hw = (size[2], size[3])

        # Pyramid pooling on the last map: pool, resize back, then 1x1 conv.
        pooled = [conv5]
        for pool, reduce in zip(self.ppm_pooling, self.ppm_conv):
            resized = nn.functional.interpolate(
                pool(conv5), target_hw, mode="bilinear", align_corners=False
            )
            pooled.append(reduce(resized))
        f = self.ppm_last_conv(torch.cat(pooled, 1))

        # Top-down pass from the second-to-last map down to index 0.
        top_down = [f]
        for level in range(len(conv_out) - 2, -1, -1):
            lateral = self.fpn_in[level](conv_out[level])  # lateral branch
            upsampled = nn.functional.interpolate(
                f, size=lateral.size()[2:], mode="bilinear", align_corners=False
            )  # top-down branch
            f = lateral + upsampled
            top_down.append(self.fpn_out[level](f))

        pyramid = top_down[::-1]  # index 0 first, PPM output last

        # Bring every level to the resolution of level 0 and concatenate.
        output_size = pyramid[0].size()[2:]
        fusion = [pyramid[0]]
        fusion.extend(
            nn.functional.interpolate(
                feat, output_size, mode="bilinear", align_corners=False
            )
            for feat in pyramid[1:]
        )
        x = self.conv_last(torch.cat(fusion, 1))

        if self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode="bilinear", align_corners=False
            )
            return nn.functional.softmax(x, dim=1)

        return nn.functional.log_softmax(x, dim=1)