#!/usr/bin/env python3
import torch
from torch import nn
from .inference import make_seg_postprocessor
from .loss import make_seg_loss_evaluator


def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
    """3x3 convolution with padding."""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=has_bias
    )


def conv3x3_bn_relu(in_planes, out_planes, stride=1, has_bias=False):
return nn.Sequential(
conv3x3(in_planes, out_planes, stride),
nn.BatchNorm2d(out_planes),
nn.ReLU(inplace=True),
)
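# Note: with kernel_size=3, stride=1, padding=1 the helpers above preserve the
# spatial resolution; e.g. conv3x3_bn_relu(256, 64) maps (N, 256, H, W) to
# (N, 64, H, W).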


class SEGHead(nn.Module):
    """
    A simple segmentation head that fuses FPN feature maps and predicts a
    pixel-level foreground probability map.
    """

    def __init__(self, in_channels, cfg):
        """
        Arguments:
            in_channels (int): number of channels of the input feature
            cfg: the global configuration node
        """
        super(SEGHead, self).__init__()
        self.cfg = cfg
        ndim = 256
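        # Lateral branches: each reduces a 256-channel FPN level to 64 channels
        # and upsamples it to the resolution of the finest level used
        # (fpn_out2), so the four outputs can be concatenated.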
self.fpn_out5 = nn.Sequential(
conv3x3(ndim, 64), nn.Upsample(scale_factor=8, mode="nearest")
)
self.fpn_out4 = nn.Sequential(
conv3x3(ndim, 64), nn.Upsample(scale_factor=4, mode="nearest")
)
self.fpn_out3 = nn.Sequential(
conv3x3(ndim, 64), nn.Upsample(scale_factor=2, mode="nearest")
)
self.fpn_out2 = conv3x3(ndim, 64)
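        # Prediction tower: the 4 * 64 = 256-channel fused map is refined, then
        # upsampled 4x by two stride-2 transposed convs and squashed to a
        # per-pixel probability with a sigmoid.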
self.seg_out = nn.Sequential(
conv3x3_bn_relu(in_channels, 64, 1),
nn.ConvTranspose2d(64, 64, 2, 2),
nn.BatchNorm2d(64),
nn.ReLU(True),
nn.ConvTranspose2d(64, 1, 2, 2),
nn.Sigmoid(),
)
if self.cfg.MODEL.SEG.USE_PPM:
            # Pyramid Pooling Module (PPM): pool the coarsest fused level at
            # several scales, project each pooled map to 512 channels, and
            # concatenate everything before a final 3x3 reduction back to ndim.
            pool_scales = (2, 4, 8)
fc_dim = 256
self.ppm_pooling = []
self.ppm_conv = []
for scale in pool_scales:
self.ppm_pooling.append(nn.AdaptiveAvgPool2d(scale))
self.ppm_conv.append(nn.Sequential(
nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True)
))
self.ppm_pooling = nn.ModuleList(self.ppm_pooling)
self.ppm_conv = nn.ModuleList(self.ppm_conv)
self.ppm_last_conv = conv3x3_bn_relu(fc_dim + len(pool_scales)*512, ndim, 1)
self.ppm_conv.apply(self.weights_init)
self.ppm_last_conv.apply(self.weights_init)
self.fpn_out5.apply(self.weights_init)
self.fpn_out4.apply(self.weights_init)
self.fpn_out3.apply(self.weights_init)
self.fpn_out2.apply(self.weights_init)
self.seg_out.apply(self.weights_init)

    def forward(self, x):
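        # x holds the FPN feature maps, finest first; this head consumes
        # x[-5] through x[-2] (the coarsest level, if present, is unused).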
if self.cfg.MODEL.SEG.USE_PPM:
conv5 = x[-2]
input_size = conv5.size()
ppm_out = [conv5]
for pool_scale, pool_conv in zip(self.ppm_pooling, self.ppm_conv):
ppm_out.append(pool_conv(nn.functional.interpolate(
pool_scale(conv5),
(input_size[2], input_size[3]),
mode='bilinear', align_corners=False)))
ppm_out = torch.cat(ppm_out, 1)
f = self.ppm_last_conv(ppm_out)
else:
f = x[-2]
p5 = self.fpn_out5(f)
p4 = self.fpn_out4(x[-3])
p3 = self.fpn_out3(x[-4])
p2 = self.fpn_out2(x[-5])
fuse = torch.cat((p5, p4, p3, p2), 1)
out = self.seg_out(fuse)
return out, fuse

    def weights_init(self, m):
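        # Kaiming-normal init for conv (and transposed-conv) weights; BatchNorm
        # starts near identity with weight 1 and a small bias.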
classname = m.__class__.__name__
if classname.find("Conv") != -1:
nn.init.kaiming_normal_(m.weight.data)
elif classname.find("BatchNorm") != -1:
m.weight.data.fill_(1.0)
m.bias.data.fill_(1e-4)


class SEGModule(torch.nn.Module):
    """
    Module for segmentation computation. Takes feature maps from the backbone,
    predicts a pixel-level segmentation map, and derives box proposals and the
    segmentation loss from it. Works for both FPN and non-FPN backbones.
    """

    def __init__(self, cfg):
super(SEGModule, self).__init__()
self.cfg = cfg.clone()
in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
head = SEGHead(in_channels, cfg)
box_selector_train = make_seg_postprocessor(cfg, is_train=True)
box_selector_test = make_seg_postprocessor(cfg, is_train=False)
loss_evaluator = make_seg_loss_evaluator(cfg)
self.head = head
self.box_selector_train = box_selector_train
self.box_selector_test = box_selector_test
self.loss_evaluator = loss_evaluator

    def forward(self, images, features, targets=None):
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (list[Tensor]): feature maps from the FPN
            targets (Tensor): segmentation ground-truth map
        Returns:
            (boxes, results), fused_features:
            boxes (list[BoxList]): boxes derived from the predicted segmentation
                map, one BoxList per image.
            results (dict[Tensor]): the losses for the model during training;
                during testing, the raw segmentation outputs instead.
            fused_features (list[Tensor]): the fused feature map, kept for
                downstream heads.
        """
preds, fuse_feature = self.head(features)
image_shapes = images.get_sizes()
if self.training:
return self._forward_train(preds, targets, image_shapes), [fuse_feature]
else:
return self._forward_test(preds, image_shapes), [fuse_feature]

    def _forward_train(self, preds, targets, image_shapes):
        # The predicted segmentation map is converted into boxes, which serve
        # as proposals for the downstream detection stage during training.
with torch.no_grad():
boxes = self.box_selector_train(preds, image_shapes, targets)
loss_seg = self.loss_evaluator(preds, targets)
losses = {"loss_seg": loss_seg}
return boxes, losses

    def _forward_test(self, preds, image_shapes):
        boxes, rotated_boxes, polygons, scores = self.box_selector_test(
            preds, image_shapes
        )
        seg_results = {
            'rotated_boxes': rotated_boxes,
            'polygons': polygons,
            'preds': preds,
            'scores': scores,
        }
return boxes, seg_results


def build_segmentation(cfg):
    """
    Factory that constructs the segmentation module from the config.
    """
    return SEGModule(cfg)
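

if __name__ == "__main__":
    # Minimal smoke test; a hedged sketch, not part of the training pipeline.
    # It fabricates a config stub (only the MODEL.SEG.USE_PPM key this file
    # reads) and FPN-like features to check the head's output shapes. Because
    # of the relative imports above, run it as a module, e.g.
    # `python -m <package>.<this_module>`.
    from types import SimpleNamespace

    cfg_stub = SimpleNamespace(
        MODEL=SimpleNamespace(SEG=SimpleNamespace(USE_PPM=False))
    )
    head = SEGHead(in_channels=256, cfg=cfg_stub)
    head.eval()
    # Five FPN levels for a 256x256 input, finest first (strides 4..64).
    feats = [torch.randn(1, 256, 256 // s, 256 // s) for s in (4, 8, 16, 32, 64)]
    with torch.no_grad():
        out, fuse = head(feats)
    print(out.shape)   # torch.Size([1, 1, 256, 256])
    print(fuse.shape)  # torch.Size([1, 256, 64, 64])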