Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
import torch | |
from torch import nn | |
from .inference import make_seg_postprocessor | |
from .loss import make_seg_loss_evaluator | |
import time | |
def conv3x3(in_planes, out_planes, stride=1, has_bias=False):
    """Build a 3x3 ``nn.Conv2d`` that uses a padding of 1.

    Arguments:
        in_planes (int): number of input channels
        out_planes (int): number of output channels
        stride (int): stride of the convolution
        has_bias (bool): whether the layer gets a learnable bias term

    Returns:
        nn.Conv2d: the configured convolution layer
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=has_bias)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
def conv3x3_bn_relu(in_planes, out_planes, stride=1, has_bias=False):
    """Build a 3x3 convolution followed by batch norm and an in-place ReLU.

    Arguments:
        in_planes (int): number of input channels
        out_planes (int): number of output channels (also the BatchNorm width)
        stride (int): stride of the convolution
        has_bias (bool): whether the convolution gets a learnable bias term

    Returns:
        nn.Sequential: ``conv3x3 -> BatchNorm2d -> ReLU(inplace=True)``
    """
    return nn.Sequential(
        # Forward has_bias to the convolution. It used to be accepted here but
        # silently dropped, so has_bias=True still produced a bias-free layer.
        # The default (False) builds exactly the same module as before.
        conv3x3(in_planes, out_planes, stride, has_bias),
        nn.BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )
class SEGHead(nn.Module):
    """
    Pixel-level segmentation head over a sequence of feature maps.

    Four feature maps (``x[-2]`` .. ``x[-5]``) are each reduced to 64 channels
    by a 3x3 convolution. The first three are then upsampled with nearest
    interpolation by factors 8, 4 and 2; ``x[-5]`` is not upsampled. The four
    results are concatenated along the channel axis and decoded into a
    one-channel sigmoid map.
    """

    def __init__(self, in_channels, cfg):
        """
        Arguments:
            in_channels (int): number of input channels of the first
                convolution in ``seg_out``, which consumes the fused feature
            cfg: configuration node; ``cfg.MODEL.SEG.USE_PPM`` switches the
                pooling module on ``x[-2]`` on or off
        """
        super().__init__()
        self.cfg = cfg
        ndim = 256

        def _upsampled_branch(factor):
            # 3x3 conv down to 64 channels, then nearest-neighbour upsampling.
            return nn.Sequential(
                conv3x3(ndim, 64),
                nn.Upsample(scale_factor=factor, mode="nearest"),
            )

        self.fpn_out5 = _upsampled_branch(8)
        self.fpn_out4 = _upsampled_branch(4)
        self.fpn_out3 = _upsampled_branch(2)
        self.fpn_out2 = conv3x3(ndim, 64)

        # Decoder: two kernel-2 / stride-2 transposed convolutions, ending in
        # a single-channel sigmoid output.
        self.seg_out = nn.Sequential(
            conv3x3_bn_relu(in_channels, 64, 1),
            nn.ConvTranspose2d(64, 64, 2, 2),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 2, 2),
            nn.Sigmoid(),
        )

        if self.cfg.MODEL.SEG.USE_PPM:
            # Pooling module: one adaptive average pool per scale, each paired
            # with its own 1x1 conv + BN + ReLU.
            pool_scales = (2, 4, 8)
            fc_dim = 256
            self.ppm_pooling = nn.ModuleList(
                [nn.AdaptiveAvgPool2d(scale) for scale in pool_scales]
            )
            self.ppm_conv = nn.ModuleList(
                [
                    nn.Sequential(
                        nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                        nn.BatchNorm2d(512),
                        nn.ReLU(inplace=True),
                    )
                    for _ in pool_scales
                ]
            )
            # Input is the raw map plus one 512-channel map per pooling scale.
            self.ppm_last_conv = conv3x3_bn_relu(
                fc_dim + len(pool_scales) * 512, ndim, 1
            )
            self.ppm_conv.apply(self.weights_init)
            self.ppm_last_conv.apply(self.weights_init)

        for branch in (
            self.fpn_out5,
            self.fpn_out4,
            self.fpn_out3,
            self.fpn_out2,
            self.seg_out,
        ):
            branch.apply(self.weights_init)

    def forward(self, x):
        """
        Arguments:
            x (sequence of Tensor): feature maps; entries ``-2`` .. ``-5``
                are read

        Returns:
            tuple: ``(out, fuse)`` where ``fuse`` is the channel-wise
            concatenation of the four branches and ``out`` is
            ``seg_out(fuse)``
        """
        top = x[-2]
        if self.cfg.MODEL.SEG.USE_PPM:
            target_hw = (top.size(2), top.size(3))
            pyramid = [top]
            for pool, reduce_conv in zip(self.ppm_pooling, self.ppm_conv):
                # Pool, resize back to the input resolution, then convolve.
                resized = nn.functional.interpolate(
                    pool(top), target_hw, mode="bilinear", align_corners=False
                )
                pyramid.append(reduce_conv(resized))
            top = self.ppm_last_conv(torch.cat(pyramid, 1))

        p5 = self.fpn_out5(top)
        p4 = self.fpn_out4(x[-3])
        p3 = self.fpn_out3(x[-4])
        p2 = self.fpn_out2(x[-5])
        fuse = torch.cat((p5, p4, p3, p2), 1)
        return self.seg_out(fuse), fuse

    def weights_init(self, m):
        """Initialise one module in place; intended for ``Module.apply``.

        Modules whose class name contains ``Conv`` get Kaiming-normal weights.
        Modules whose class name contains ``BatchNorm`` get weight 1.0 and
        bias 1e-4. Any other module is left untouched.
        """
        layer_type = type(m).__name__
        if "Conv" in layer_type:
            nn.init.kaiming_normal_(m.weight.data)
        elif "BatchNorm" in layer_type:
            m.weight.data.fill_(1.0)
            m.bias.data.fill_(1e-4)
class SEGModule(torch.nn.Module):
    """
    Segmentation stage: runs ``SEGHead`` on the features and hands its
    prediction map to a post-processor that produces boxes. In training mode
    a segmentation loss is computed as well.
    """

    def __init__(self, cfg):
        """
        Arguments:
            cfg: configuration node; ``cfg.MODEL.BACKBONE.OUT_CHANNELS`` is
                passed to the head as its ``in_channels``
        """
        super().__init__()
        self.cfg = cfg.clone()
        self.head = SEGHead(cfg.MODEL.BACKBONE.OUT_CHANNELS, cfg)
        self.box_selector_train = make_seg_postprocessor(cfg, is_train=True)
        self.box_selector_test = make_seg_postprocessor(cfg, is_train=False)
        self.loss_evaluator = make_seg_loss_evaluator(cfg)

    def forward(self, images, features, targets=None):
        """
        Arguments:
            images: object exposing ``get_sizes()``; the result is forwarded
                to the post-processor as the image shapes
            features: feature maps passed to the head
            targets: ground truth passed to the training post-processor and
                the loss evaluator; unused in eval mode

        Returns:
            tuple: ``(result, [fuse_feature])``. In training mode ``result``
            is ``(boxes, losses)`` with ``losses == {"loss_seg": ...}``. In
            eval mode it is ``(boxes, seg_results)``, where ``seg_results``
            is a dict with keys ``rotated_boxes``, ``polygons``, ``preds``
            and ``scores``.
        """
        preds, fuse_feature = self.head(features)
        image_shapes = images.get_sizes()
        if self.training:
            result = self._forward_train(preds, targets, image_shapes)
        else:
            result = self._forward_test(preds, image_shapes)
        return result, [fuse_feature]

    def _forward_train(self, preds, targets, image_shapes):
        """Return ``(boxes, losses)`` for one training step."""
        # Box extraction runs without gradient tracking; only the loss
        # evaluator call below is made with gradients enabled.
        with torch.no_grad():
            boxes = self.box_selector_train(preds, image_shapes, targets)
        return boxes, {"loss_seg": self.loss_evaluator(preds, targets)}

    def _forward_test(self, preds, image_shapes):
        """Return ``(boxes, seg_results)`` for inference."""
        boxes, rotated_boxes, polygons, scores = self.box_selector_test(
            preds, image_shapes
        )
        seg_results = dict(
            rotated_boxes=rotated_boxes,
            polygons=polygons,
            preds=preds,
            scores=scores,
        )
        return boxes, seg_results
def build_segmentation(cfg):
    """Factory for the segmentation stage.

    Arguments:
        cfg: configuration node handed unchanged to ``SEGModule``

    Returns:
        SEGModule: a newly constructed module
    """
    seg_module = SEGModule(cfg)
    return seg_module