|
from typing import List |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from . import hrnet, mobilenet, resnet, resnext |
|
from .lib.nn import SynchronizedBatchNorm2d |
|
|
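# Alias so every BN layer below uses synchronized batch norm (statistics
# aggregated across GPUs during multi-GPU training).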
|
BatchNorm2d = SynchronizedBatchNorm2d |
|
|
|
|
|
class SegmentationModuleBase(nn.Module): |
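    """Base class providing the pixel-accuracy metric shared by segmentation modules."""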
|
def __init__(self): |
|
super(SegmentationModuleBase, self).__init__() |
|
|
|
def pixel_acc(self, pred, label): |
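        """Pixel accuracy of the argmax predictions, counting only pixels with label >= 0."""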
|
_, preds = torch.max(pred, dim=1) |
|
valid = (label >= 0).long() |
|
acc_sum = torch.sum(valid * (preds == label).long()) |
|
pixel_sum = torch.sum(valid) |
|
acc = acc_sum.float() / (pixel_sum.float() + 1e-10) |
|
return acc |
|
|
|
|
|
class SegmentationModule(SegmentationModuleBase): |
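    """Wraps an encoder backbone, a decoder head, and the training criterion,
    optionally with deep supervision.

    Minimal construction sketch (a non-authoritative example; the criterion,
    fc_dim, and deep_sup_scale values are assumptions about the training setup,
    not defined in this file):

        net_encoder = ModelBuilder.build_encoder("resnet50dilated", fc_dim=2048)
        net_decoder = ModelBuilder.build_decoder("ppm_deepsup", fc_dim=2048, num_class=150)
        crit = nn.NLLLoss(ignore_index=-1)
        model = SegmentationModule(net_encoder, net_decoder, crit, deep_sup_scale=0.4)
    """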
|
def __init__(self, net_enc, net_dec, crit, deep_sup_scale=None): |
|
super(SegmentationModule, self).__init__() |
|
self.encoder = net_enc |
|
self.decoder = net_dec |
|
self.crit = crit |
|
self.deep_sup_scale = deep_sup_scale |
|
|
|
def forward(self, feed_dict, *, segSize=None): |
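        """With segSize=None (training): return (loss, pixel accuracy), adding the
        weighted deep-supervision loss when deep_sup_scale is set. With segSize
        given (inference): return the decoder's per-pixel class scores."""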
|
|
|
if segSize is None: |
|
if self.deep_sup_scale is not None: |
|
(pred, pred_deepsup) = self.decoder( |
|
self.encoder(feed_dict["img_data"], return_feature_maps=True) |
|
) |
|
else: |
|
pred = self.decoder( |
|
self.encoder(feed_dict["img_data"], return_feature_maps=True) |
|
) |
|
|
|
loss = self.crit(pred, feed_dict["seg_label"]) |
|
if self.deep_sup_scale is not None: |
|
loss_deepsup = self.crit(pred_deepsup, feed_dict["seg_label"]) |
|
loss = loss + loss_deepsup * self.deep_sup_scale |
|
|
|
acc = self.pixel_acc(pred, feed_dict["seg_label"]) |
|
return loss, acc |
|
|
|
else: |
|
pred = self.decoder( |
|
self.encoder(feed_dict["img_data"], return_feature_maps=True), |
|
segSize=segSize, |
|
) |
|
return pred |
|
|
|
|
|
class ModelBuilder: |
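    """Factory helpers for constructing encoder backbones and decoder heads."""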
|
|
|
@staticmethod |
|
def weights_init(m): |
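        """Kaiming-normal init for conv weights; BN weights set to 1, biases to 1e-4."""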
|
classname = m.__class__.__name__ |
|
if classname.find("Conv") != -1: |
|
nn.init.kaiming_normal_(m.weight.data) |
|
elif classname.find("BatchNorm") != -1: |
|
m.weight.data.fill_(1.0) |
|
m.bias.data.fill_(1e-4) |
|
|
|
|
|
|
|
@staticmethod |
|
def build_encoder(arch="resnet50dilated", fc_dim=512, weights=""): |
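        """Build the encoder backbone.

        Pretrained (ImageNet) weights are requested only when `weights` is empty;
        otherwise the checkpoint at `weights` is loaded non-strictly onto CPU
        storage. `fc_dim` is accepted for symmetry with build_decoder but is not
        used here.
        """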
|
        pretrained = len(weights) == 0
|
arch = arch.lower() |
|
if arch == "mobilenetv2dilated": |
|
orig_mobilenet = mobilenet.__dict__["mobilenetv2"](pretrained=pretrained) |
|
net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8) |
|
elif arch == "resnet18": |
|
orig_resnet = resnet.__dict__["resnet18"](pretrained=pretrained) |
|
net_encoder = Resnet(orig_resnet) |
|
elif arch == "resnet18dilated": |
|
orig_resnet = resnet.__dict__["resnet18"](pretrained=pretrained) |
|
net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) |
|
elif arch == "resnet34": |
|
raise NotImplementedError |
|
orig_resnet = resnet.__dict__["resnet34"](pretrained=pretrained) |
|
net_encoder = Resnet(orig_resnet) |
|
elif arch == "resnet34dilated": |
|
raise NotImplementedError |
|
orig_resnet = resnet.__dict__["resnet34"](pretrained=pretrained) |
|
net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) |
|
elif arch == "resnet50": |
|
orig_resnet = resnet.__dict__["resnet50"](pretrained=pretrained) |
|
net_encoder = Resnet(orig_resnet) |
|
elif arch == "resnet50dilated": |
|
orig_resnet = resnet.__dict__["resnet50"](pretrained=pretrained) |
|
net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) |
|
elif arch == "resnet101": |
|
orig_resnet = resnet.__dict__["resnet101"](pretrained=pretrained) |
|
net_encoder = Resnet(orig_resnet) |
|
elif arch == "resnet101dilated": |
|
orig_resnet = resnet.__dict__["resnet101"](pretrained=pretrained) |
|
net_encoder = ResnetDilated(orig_resnet, dilate_scale=8) |
|
elif arch == "resnext101": |
|
orig_resnext = resnext.__dict__["resnext101"](pretrained=pretrained) |
|
net_encoder = Resnet(orig_resnext) |
|
elif arch == "hrnetv2": |
|
net_encoder = hrnet.__dict__["hrnetv2"](pretrained=pretrained) |
|
else: |
|
            raise ValueError("Undefined encoder architecture: " + arch)
|
|
|
|
|
|
|
if len(weights) > 0: |
|
print("Loading weights for net_encoder") |
|
net_encoder.load_state_dict( |
|
torch.load(weights, map_location=lambda storage, loc: storage), |
|
strict=False, |
|
) |
|
return net_encoder |
|
|
|
@staticmethod |
|
def build_decoder( |
|
arch="ppm_deepsup", |
|
fc_dim=512, |
|
num_class=150, |
|
weights="", |
|
use_softmax=False, |
|
dropout=0.0, |
|
fcn_up: int = 32, |
|
): |
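        """Build the decoder head.

        `dropout` and `fcn_up` only affect the `c1` decoder; `fcn_up` selects how
        many encoder stages are fused (32 = conv5 only, 16 = conv4 + conv5,
        otherwise conv3 + conv4 + conv5). Weights are initialized with
        ModelBuilder.weights_init and, if `weights` is given, overwritten
        non-strictly from that checkpoint.
        """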
|
arch = arch.lower() |
|
if arch == "c1_deepsup": |
|
net_decoder = C1DeepSup( |
|
num_class=num_class, fc_dim=fc_dim, use_softmax=use_softmax |
|
) |
|
elif arch == "c1": |
|
net_decoder = C1( |
|
num_class=num_class, |
|
fc_dim=fc_dim, |
|
use_softmax=use_softmax, |
|
dropout=dropout, |
|
fcn_up=fcn_up, |
|
) |
|
elif arch == "ppm": |
|
net_decoder = PPM( |
|
num_class=num_class, fc_dim=fc_dim, use_softmax=use_softmax |
|
) |
|
elif arch == "ppm_deepsup": |
|
net_decoder = PPMDeepsup( |
|
num_class=num_class, fc_dim=fc_dim, use_softmax=use_softmax |
|
) |
|
elif arch == "upernet_lite": |
|
net_decoder = UPerNet( |
|
num_class=num_class, fc_dim=fc_dim, use_softmax=use_softmax, fpn_dim=256 |
|
) |
|
elif arch == "upernet": |
|
net_decoder = UPerNet( |
|
num_class=num_class, fc_dim=fc_dim, use_softmax=use_softmax, fpn_dim=512 |
|
) |
|
else: |
|
            raise ValueError("Undefined decoder architecture: " + arch)
|
|
|
net_decoder.apply(ModelBuilder.weights_init) |
|
if len(weights) > 0: |
|
print("Loading weights for net_decoder") |
|
net_decoder.load_state_dict( |
|
torch.load(weights, map_location=lambda storage, loc: storage), |
|
strict=False, |
|
) |
|
return net_decoder |
|
|
|
|
|
def conv3x3_bn_relu(in_planes, out_planes, stride=1): |
|
"3x3 convolution + BN + relu" |
|
return nn.Sequential( |
|
nn.Conv2d( |
|
in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False |
|
), |
|
BatchNorm2d(out_planes), |
|
nn.ReLU(inplace=True), |
|
) |
|
|
|
|
|
class Resnet(nn.Module): |
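    """Wrapper around the bundled ResNet (three-conv stem) that collects the
    outputs of layer1-layer4 when return_feature_maps=True; otherwise it
    returns only the final feature map."""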
|
def __init__(self, orig_resnet): |
|
super(Resnet, self).__init__() |
|
|
|
|
|
self.conv1 = orig_resnet.conv1 |
|
self.bn1 = orig_resnet.bn1 |
|
self.relu1 = orig_resnet.relu1 |
|
self.conv2 = orig_resnet.conv2 |
|
self.bn2 = orig_resnet.bn2 |
|
self.relu2 = orig_resnet.relu2 |
|
self.conv3 = orig_resnet.conv3 |
|
self.bn3 = orig_resnet.bn3 |
|
self.relu3 = orig_resnet.relu3 |
|
self.maxpool = orig_resnet.maxpool |
|
self.layer1 = orig_resnet.layer1 |
|
self.layer2 = orig_resnet.layer2 |
|
self.layer3 = orig_resnet.layer3 |
|
self.layer4 = orig_resnet.layer4 |
|
|
|
def forward(self, x, return_feature_maps=False): |
|
conv_out = [] |
|
|
|
x = self.relu1(self.bn1(self.conv1(x))) |
|
x = self.relu2(self.bn2(self.conv2(x))) |
|
x = self.relu3(self.bn3(self.conv3(x))) |
|
x = self.maxpool(x) |
|
|
|
x = self.layer1(x) |
|
conv_out.append(x) |
|
|
|
x = self.layer2(x) |
|
conv_out.append(x) |
|
|
|
x = self.layer3(x) |
|
conv_out.append(x) |
|
|
|
x = self.layer4(x) |
|
conv_out.append(x) |
|
|
|
|
|
if return_feature_maps: |
|
return conv_out |
|
return [x] |
|
|
|
|
|
class ResnetDilated(nn.Module): |
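    """ResNet backbone with strides in the last block(s) replaced by dilation,
    giving an output stride of 8 (dilate_scale=8) or 16 (dilate_scale=16)."""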
|
def __init__(self, orig_resnet, dilate_scale=8): |
|
super(ResnetDilated, self).__init__() |
|
from functools import partial |
|
|
|
if dilate_scale == 8: |
|
orig_resnet.layer3.apply(partial(self._nostride_dilate, dilate=2)) |
|
orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=4)) |
|
elif dilate_scale == 16: |
|
orig_resnet.layer4.apply(partial(self._nostride_dilate, dilate=2)) |
|
|
|
|
|
self.conv1 = orig_resnet.conv1 |
|
self.bn1 = orig_resnet.bn1 |
|
self.relu1 = orig_resnet.relu1 |
|
self.conv2 = orig_resnet.conv2 |
|
self.bn2 = orig_resnet.bn2 |
|
self.relu2 = orig_resnet.relu2 |
|
self.conv3 = orig_resnet.conv3 |
|
self.bn3 = orig_resnet.bn3 |
|
self.relu3 = orig_resnet.relu3 |
|
self.maxpool = orig_resnet.maxpool |
|
self.layer1 = orig_resnet.layer1 |
|
self.layer2 = orig_resnet.layer2 |
|
self.layer3 = orig_resnet.layer3 |
|
self.layer4 = orig_resnet.layer4 |
|
|
|
def _nostride_dilate(self, m, dilate): |
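        """Replace stride-2 convs with stride-1 and dilate 3x3 convs so the
        receptive field is preserved without further downsampling."""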
|
classname = m.__class__.__name__ |
|
if classname.find("Conv") != -1: |
|
|
|
if m.stride == (2, 2): |
|
m.stride = (1, 1) |
|
if m.kernel_size == (3, 3): |
|
m.dilation = (dilate // 2, dilate // 2) |
|
m.padding = (dilate // 2, dilate // 2) |
|
|
|
else: |
|
if m.kernel_size == (3, 3): |
|
m.dilation = (dilate, dilate) |
|
m.padding = (dilate, dilate) |
|
|
|
def forward(self, x, return_feature_maps=False): |
|
conv_out = [] |
|
|
|
x = self.relu1(self.bn1(self.conv1(x))) |
|
x = self.relu2(self.bn2(self.conv2(x))) |
|
x = self.relu3(self.bn3(self.conv3(x))) |
|
x = self.maxpool(x) |
|
|
|
x = self.layer1(x) |
|
conv_out.append(x) |
|
x = self.layer2(x) |
|
conv_out.append(x) |
|
x = self.layer3(x) |
|
conv_out.append(x) |
|
x = self.layer4(x) |
|
conv_out.append(x) |
|
|
|
if return_feature_maps: |
|
return conv_out |
|
return [x] |
|
|
|
|
|
class MobileNetV2Dilated(nn.Module): |
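    """MobileNetV2 trunk (orig_net.features minus its final block), with the
    stages after the last downsampling points dilated so the effective output
    stride is 8 or 16 instead of 32."""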
|
def __init__(self, orig_net, dilate_scale=8): |
|
super(MobileNetV2Dilated, self).__init__() |
|
from functools import partial |
|
|
|
|
|
self.features = orig_net.features[:-1] |
|
|
|
self.total_idx = len(self.features) |
|
self.down_idx = [2, 4, 7, 14] |
|
|
|
if dilate_scale == 8: |
|
for i in range(self.down_idx[-2], self.down_idx[-1]): |
|
self.features[i].apply(partial(self._nostride_dilate, dilate=2)) |
|
for i in range(self.down_idx[-1], self.total_idx): |
|
self.features[i].apply(partial(self._nostride_dilate, dilate=4)) |
|
elif dilate_scale == 16: |
|
for i in range(self.down_idx[-1], self.total_idx): |
|
self.features[i].apply(partial(self._nostride_dilate, dilate=2)) |
|
|
|
def _nostride_dilate(self, m, dilate): |
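        """Same as ResnetDilated._nostride_dilate: remove stride-2 and dilate 3x3 convs."""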
|
classname = m.__class__.__name__ |
|
if classname.find("Conv") != -1: |
|
|
|
if m.stride == (2, 2): |
|
m.stride = (1, 1) |
|
if m.kernel_size == (3, 3): |
|
m.dilation = (dilate // 2, dilate // 2) |
|
m.padding = (dilate // 2, dilate // 2) |
|
|
|
else: |
|
if m.kernel_size == (3, 3): |
|
m.dilation = (dilate, dilate) |
|
m.padding = (dilate, dilate) |
|
|
|
def forward(self, x, return_feature_maps=False): |
|
if return_feature_maps: |
|
conv_out = [] |
|
for i in range(self.total_idx): |
|
x = self.features[i](x) |
|
if i in self.down_idx: |
|
conv_out.append(x) |
|
conv_out.append(x) |
|
return conv_out |
|
|
|
else: |
|
return [self.features(x)] |
|
|
|
|
|
|
|
class C1DeepSup(nn.Module): |
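    """Single-conv (C1) head with an auxiliary deep-supervision branch on conv4.
    Training returns (log_softmax scores, deep-supervision log_softmax scores);
    inference (use_softmax=True) returns softmax scores upsampled to segSize."""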
|
def __init__(self, num_class=150, fc_dim=2048, use_softmax=False): |
|
super(C1DeepSup, self).__init__() |
|
self.use_softmax = use_softmax |
|
|
|
self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1) |
|
self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) |
|
|
|
|
|
self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) |
|
self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) |
|
|
|
def forward(self, conv_out, segSize=None): |
|
conv5 = conv_out[-1] |
|
|
|
x = self.cbr(conv5) |
|
x = self.conv_last(x) |
|
|
|
if self.use_softmax: |
|
x = nn.functional.interpolate( |
|
x, size=segSize, mode="bilinear", align_corners=False |
|
) |
|
x = nn.functional.softmax(x, dim=1) |
|
return x |
|
|
|
|
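        # auxiliary deep-supervision branch on conv4 (used only for the training loss)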
|
conv4 = conv_out[-2] |
|
_ = self.cbr_deepsup(conv4) |
|
_ = self.conv_last_deepsup(_) |
|
|
|
x = nn.functional.log_softmax(x, dim=1) |
|
_ = nn.functional.log_softmax(_, dim=1) |
|
|
|
return (x, _) |
|
|
|
|
|
|
|
class C1(nn.Module): |
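    """Single 3x3 conv-BN-ReLU head (FCN-style). `fcn_up` selects which encoder
    stages are fused before the classifier: 32 uses conv5 only, 16 fuses
    conv4 + conv5, anything else fuses conv3 + conv4 + conv5. Returns raw,
    non-upsampled class scores."""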
|
def __init__( |
|
self, |
|
num_class=150, |
|
fc_dim: int = 2048, |
|
use_softmax=False, |
|
dropout=0.0, |
|
fcn_up: int = 32, |
|
): |
|
super(C1, self).__init__() |
|
self.use_softmax = use_softmax |
|
self.fcn_up = fcn_up |
|
|
|
        # Channel count entering the head, assuming ResNet-style halving:
        # conv5 has fc_dim channels, conv4 has fc_dim // 2, conv3 has fc_dim // 4.
        if fcn_up == 32:
            in_dim = fc_dim
        elif fcn_up == 16:
            in_dim = fc_dim + fc_dim // 2
        else:
            in_dim = fc_dim + fc_dim // 2 + fc_dim // 4
|
self.cbr = conv3x3_bn_relu(in_dim, fc_dim // 4, 1) |
|
|
|
|
|
self.dropout = nn.Dropout2d(dropout) |
|
self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) |
|
|
|
def forward(self, conv_out: List, segSize=None): |
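        # segSize is accepted for interface compatibility with the other heads;
        # this head returns raw logits without upsampling or (log-)softmax.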
|
if self.fcn_up == 32: |
|
conv5 = conv_out[-1] |
|
elif self.fcn_up == 16: |
|
conv4 = conv_out[-2] |
|
tgt_shape = conv4.shape[-2:] |
|
conv5 = conv_out[-1] |
|
conv5 = nn.functional.interpolate( |
|
conv5, size=tgt_shape, mode="bilinear", align_corners=False |
|
) |
|
conv5 = torch.cat([conv4, conv5], dim=1) |
|
else: |
|
conv3 = conv_out[-3] |
|
tgt_shape = conv3.shape[-2:] |
|
conv4 = conv_out[-2] |
|
conv5 = conv_out[-1] |
|
conv4 = nn.functional.interpolate( |
|
conv4, size=tgt_shape, mode="bilinear", align_corners=False |
|
) |
|
conv5 = nn.functional.interpolate( |
|
conv5, size=tgt_shape, mode="bilinear", align_corners=False |
|
) |
|
conv5 = torch.cat([conv3, conv4, conv5], dim=1) |
|
x = self.cbr(conv5) |
|
x = self.dropout(x) |
|
x = self.conv_last(x) |
|
|
|
return x |
|
|
|
|
|
|
|
class PPM(nn.Module): |
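    """Pyramid Pooling Module head (PSPNet-style): conv5 is pooled at several
    scales, each pooled map is projected to 512 channels and upsampled, then
    concatenated with conv5 and classified. Returns raw logits, upsampled to
    segSize when it is given."""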
|
def __init__( |
|
self, num_class=150, fc_dim=4096, use_softmax=False, pool_scales=(1, 2, 3, 6) |
|
): |
|
super(PPM, self).__init__() |
|
self.use_softmax = use_softmax |
|
|
|
self.ppm = [] |
|
for scale in pool_scales: |
|
self.ppm.append( |
|
nn.Sequential( |
|
nn.AdaptiveAvgPool2d(scale), |
|
nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), |
|
BatchNorm2d(512), |
|
nn.ReLU(inplace=True), |
|
) |
|
) |
|
self.ppm = nn.ModuleList(self.ppm) |
|
|
|
self.conv_last = nn.Sequential( |
|
nn.Conv2d( |
|
fc_dim + len(pool_scales) * 512, |
|
512, |
|
kernel_size=3, |
|
padding=1, |
|
bias=False, |
|
), |
|
BatchNorm2d(512), |
|
nn.ReLU(inplace=True), |
|
nn.Dropout2d(0.1), |
|
nn.Conv2d(512, num_class, kernel_size=1), |
|
) |
|
|
|
def forward(self, conv_out, segSize=None): |
|
conv5 = conv_out[-1] |
|
|
|
input_size = conv5.size() |
|
ppm_out = [conv5] |
|
for pool_scale in self.ppm: |
|
ppm_out.append( |
|
nn.functional.interpolate( |
|
pool_scale(conv5), |
|
(input_size[2], input_size[3]), |
|
mode="bilinear", |
|
align_corners=False, |
|
) |
|
) |
|
ppm_out = torch.cat(ppm_out, 1) |
|
|
|
x = self.conv_last(ppm_out) |
|
|
|
if segSize is not None: |
|
x = nn.functional.interpolate( |
|
x, size=segSize, mode="bilinear", align_corners=False |
|
) |
|
return x |
|
|
|
|
|
|
|
class PPMDeepsup(nn.Module): |
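    """PPM head with an auxiliary deep-supervision branch on conv4. Training
    returns (log_softmax scores, deep-supervision log_softmax scores);
    inference (use_softmax=True) returns softmax scores upsampled to segSize."""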
|
def __init__( |
|
self, num_class=150, fc_dim=4096, use_softmax=False, pool_scales=(1, 2, 3, 6) |
|
): |
|
super(PPMDeepsup, self).__init__() |
|
self.use_softmax = use_softmax |
|
|
|
self.ppm = [] |
|
for scale in pool_scales: |
|
self.ppm.append( |
|
nn.Sequential( |
|
nn.AdaptiveAvgPool2d(scale), |
|
nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), |
|
BatchNorm2d(512), |
|
nn.ReLU(inplace=True), |
|
) |
|
) |
|
self.ppm = nn.ModuleList(self.ppm) |
|
self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1) |
|
|
|
self.conv_last = nn.Sequential( |
|
nn.Conv2d( |
|
fc_dim + len(pool_scales) * 512, |
|
512, |
|
kernel_size=3, |
|
padding=1, |
|
bias=False, |
|
), |
|
BatchNorm2d(512), |
|
nn.ReLU(inplace=True), |
|
nn.Dropout2d(0.1), |
|
nn.Conv2d(512, num_class, kernel_size=1), |
|
) |
|
self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0) |
|
self.dropout_deepsup = nn.Dropout2d(0.1) |
|
|
|
def forward(self, conv_out, segSize=None): |
|
conv5 = conv_out[-1] |
|
|
|
input_size = conv5.size() |
|
ppm_out = [conv5] |
|
for pool_scale in self.ppm: |
|
ppm_out.append( |
|
nn.functional.interpolate( |
|
pool_scale(conv5), |
|
(input_size[2], input_size[3]), |
|
mode="bilinear", |
|
align_corners=False, |
|
) |
|
) |
|
ppm_out = torch.cat(ppm_out, 1) |
|
|
|
x = self.conv_last(ppm_out) |
|
|
|
if self.use_softmax: |
|
x = nn.functional.interpolate( |
|
x, size=segSize, mode="bilinear", align_corners=False |
|
) |
|
x = nn.functional.softmax(x, dim=1) |
|
return x |
|
|
|
|
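        # auxiliary deep-supervision branch on conv4 (used only for the training loss)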
|
conv4 = conv_out[-2] |
|
_ = self.cbr_deepsup(conv4) |
|
_ = self.dropout_deepsup(_) |
|
_ = self.conv_last_deepsup(_) |
|
|
|
x = nn.functional.log_softmax(x, dim=1) |
|
_ = nn.functional.log_softmax(_, dim=1) |
|
|
|
return (x, _) |
|
|
|
|
|
|
|
class UPerNet(nn.Module): |
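    """UPerNet head: a Pyramid Pooling Module on conv5 feeds a top-down FPN over
    the earlier stages; all pyramid levels are upsampled to the finest level,
    concatenated, and classified. Training returns log_softmax scores; inference
    (use_softmax=True) returns softmax scores upsampled to segSize."""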
|
def __init__( |
|
self, |
|
num_class=150, |
|
fc_dim=4096, |
|
use_softmax=False, |
|
pool_scales=(1, 2, 3, 6), |
|
fpn_inplanes=(256, 512, 1024, 2048), |
|
fpn_dim=256, |
|
): |
|
super(UPerNet, self).__init__() |
|
self.use_softmax = use_softmax |
|
|
|
|
|
self.ppm_pooling = [] |
|
self.ppm_conv = [] |
|
|
|
for scale in pool_scales: |
|
self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale)) |
|
self.ppm_conv.append( |
|
nn.Sequential( |
|
nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False), |
|
BatchNorm2d(512), |
|
nn.ReLU(inplace=True), |
|
) |
|
) |
|
self.ppm_pooling = nn.ModuleList(self.ppm_pooling) |
|
self.ppm_conv = nn.ModuleList(self.ppm_conv) |
|
self.ppm_last_conv = conv3x3_bn_relu( |
|
fc_dim + len(pool_scales) * 512, fpn_dim, 1 |
|
) |
|
|
|
|
|
self.fpn_in = [] |
|
for fpn_inplane in fpn_inplanes[:-1]: |
|
self.fpn_in.append( |
|
nn.Sequential( |
|
nn.Conv2d(fpn_inplane, fpn_dim, kernel_size=1, bias=False), |
|
BatchNorm2d(fpn_dim), |
|
nn.ReLU(inplace=True), |
|
) |
|
) |
|
self.fpn_in = nn.ModuleList(self.fpn_in) |
|
|
|
self.fpn_out = [] |
|
for i in range(len(fpn_inplanes) - 1): |
|
self.fpn_out.append( |
|
nn.Sequential( |
|
conv3x3_bn_relu(fpn_dim, fpn_dim, 1), |
|
) |
|
) |
|
self.fpn_out = nn.ModuleList(self.fpn_out) |
|
|
|
self.conv_last = nn.Sequential( |
|
conv3x3_bn_relu(len(fpn_inplanes) * fpn_dim, fpn_dim, 1), |
|
nn.Conv2d(fpn_dim, num_class, kernel_size=1), |
|
) |
|
|
|
def forward(self, conv_out, segSize=None): |
|
conv5 = conv_out[-1] |
|
|
|
input_size = conv5.size() |
|
ppm_out = [conv5] |
|
for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv): |
|
ppm_out.append( |
|
pool_conv( |
|
nn.functional.interpolate( |
|
pool_scale(conv5), |
|
(input_size[2], input_size[3]), |
|
mode="bilinear", |
|
align_corners=False, |
|
) |
|
) |
|
) |
|
ppm_out = torch.cat(ppm_out, 1) |
|
f = self.ppm_last_conv(ppm_out) |
|
|
|
fpn_feature_list = [f] |
|
for i in reversed(range(len(conv_out) - 1)): |
|
conv_x = conv_out[i] |
|
conv_x = self.fpn_in[i](conv_x) |
|
|
|
f = nn.functional.interpolate( |
|
f, size=conv_x.size()[2:], mode="bilinear", align_corners=False |
|
) |
|
f = conv_x + f |
|
|
|
fpn_feature_list.append(self.fpn_out[i](f)) |
|
|
|
fpn_feature_list.reverse() |
|
output_size = fpn_feature_list[0].size()[2:] |
|
fusion_list = [fpn_feature_list[0]] |
|
for i in range(1, len(fpn_feature_list)): |
|
fusion_list.append( |
|
nn.functional.interpolate( |
|
fpn_feature_list[i], |
|
output_size, |
|
mode="bilinear", |
|
align_corners=False, |
|
) |
|
) |
|
fusion_out = torch.cat(fusion_list, 1) |
|
x = self.conv_last(fusion_out) |
|
|
|
if self.use_softmax: |
|
x = nn.functional.interpolate( |
|
x, size=segSize, mode="bilinear", align_corners=False |
|
) |
|
x = nn.functional.softmax(x, dim=1) |
|
return x |
|
|
|
x = nn.functional.log_softmax(x, dim=1) |
|
|
|
return x |
|
|