Spaces:
Build error
Build error
File size: 8,718 Bytes
708dec4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 |
import math
import torch
import torch.nn.functional as F
from torch import nn
from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.layers import Scale, DFConv2d
from .loss import make_fcos_loss_evaluator
from .anchor_generator import make_center_anchor_generator
from .inference import make_fcos_postprocessor
@registry.RPN_HEADS.register("FCOSHead")
class FCOSHead(torch.nn.Module):
    """Convolutional prediction head applied to every feature level.

    Two parallel conv towers (``cls_tower`` and ``bbox_tower``) feed three
    3x3 output convolutions: class logits, a 4-channel box regression map
    and a 1-channel centerness map.
    """

    # The head has always created five per-level ``Scale`` modules.  That
    # number is kept as a floor so the parameter layout (and therefore the
    # state-dict keys) of existing checkpoints does not change.
    _MIN_NUM_SCALES = 5

    def __init__(self, cfg):
        """Build the towers and output layers from ``cfg``.

        Arguments:
            cfg: config node; reads ``MODEL.BACKBONE.OUT_CHANNELS`` and the
                ``MODEL.FCOS`` entries ``NUM_CLASSES``, ``USE_GN``,
                ``USE_BN``, ``USE_DFCONV``, ``FPN_STRIDES``,
                ``NORM_REG_TARGETS``, ``CENTERNESS_ON_REG``, ``NUM_CONVS``
                and ``PRIOR_PROB``.
        """
        super(FCOSHead, self).__init__()
        # TODO: Implement the sigmoid version first.
        # NOTE(review): the "- 1" presumably drops a background class --
        # confirm against the dataset/config convention.
        num_classes = cfg.MODEL.FCOS.NUM_CLASSES - 1
        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        num_convs = cfg.MODEL.FCOS.NUM_CONVS
        use_gn = cfg.MODEL.FCOS.USE_GN
        use_bn = cfg.MODEL.FCOS.USE_BN
        use_dcn_in_tower = cfg.MODEL.FCOS.USE_DFCONV
        self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
        self.norm_reg_targets = cfg.MODEL.FCOS.NORM_REG_TARGETS
        self.centerness_on_reg = cfg.MODEL.FCOS.CENTERNESS_ON_REG

        cls_tower = []
        bbox_tower = []
        for i in range(num_convs):
            # Only the last stage of each tower may use a deformable conv.
            if use_dcn_in_tower and i == num_convs - 1:
                conv_func = DFConv2d
            else:
                conv_func = nn.Conv2d
            # The cls and bbox stages are created alternately, which keeps
            # the layer construction order identical for both towers.
            cls_tower.extend(
                self._tower_stage(conv_func, in_channels, use_gn, use_bn)
            )
            bbox_tower.extend(
                self._tower_stage(conv_func, in_channels, use_gn, use_bn)
            )

        self.add_module('cls_tower', nn.Sequential(*cls_tower))
        self.add_module('bbox_tower', nn.Sequential(*bbox_tower))
        self.cls_logits = nn.Conv2d(
            in_channels, num_classes, kernel_size=3, stride=1,
            padding=1
        )
        self.bbox_pred = nn.Conv2d(
            in_channels, 4, kernel_size=3, stride=1,
            padding=1
        )
        self.centerness = nn.Conv2d(
            in_channels, 1, kernel_size=3, stride=1,
            padding=1
        )

        # initialization
        for block in [self.cls_tower, self.bbox_tower,
                      self.cls_logits, self.bbox_pred,
                      self.centerness]:
            for layer in block.modules():
                if isinstance(layer, nn.Conv2d):
                    torch.nn.init.normal_(layer.weight, std=0.01)
                    # Guard against convolutions built without a bias.
                    if layer.bias is not None:
                        torch.nn.init.constant_(layer.bias, 0)

        # initialize the bias for focal loss
        prior_prob = cfg.MODEL.FCOS.PRIOR_PROB
        bias_value = -math.log((1 - prior_prob) / prior_prob)
        torch.nn.init.constant_(self.cls_logits.bias, bias_value)

        # forward() indexes self.scales by feature level, so there must be
        # at least one entry per configured stride; a fixed count of 5 would
        # raise IndexError for configurations with more levels.
        num_scales = max(self._MIN_NUM_SCALES, len(self.fpn_strides))
        self.scales = nn.ModuleList(
            [Scale(init_value=1.0) for _ in range(num_scales)]
        )

    @staticmethod
    def _tower_stage(conv_func, channels, use_gn, use_bn):
        """Return the layers of one tower stage: conv, optional norms, ReLU."""
        stage = [
            conv_func(
                channels,
                channels,
                kernel_size=3,
                stride=1,
                padding=1,
                bias=True
            )
        ]
        if use_gn:
            stage.append(nn.GroupNorm(32, channels))
        if use_bn:
            stage.append(nn.BatchNorm2d(channels))
        stage.append(nn.ReLU())
        return stage

    def forward(self, x):
        """Run the head on every feature level.

        Arguments:
            x: iterable of feature maps, one per level.  Levels are matched
                by position to ``self.scales`` and ``self.fpn_strides``.

        Returns:
            tuple ``(logits, bbox_reg, centerness)``; each element is a list
            with one tensor per input level.
        """
        logits = []
        bbox_reg = []
        centerness = []
        for level, feature in enumerate(x):
            cls_tower = self.cls_tower(feature)
            box_tower = self.bbox_tower(feature)

            logits.append(self.cls_logits(cls_tower))

            # Centerness is predicted from either tower, depending on config.
            if self.centerness_on_reg:
                centerness.append(self.centerness(box_tower))
            else:
                centerness.append(self.centerness(cls_tower))

            bbox_pred = self.scales[level](self.bbox_pred(box_tower))
            if self.norm_reg_targets:
                bbox_pred = F.relu(bbox_pred)
                # In eval mode the prediction is multiplied by the level's
                # stride; in training mode it is returned unscaled.
                if not self.training:
                    bbox_pred = bbox_pred * self.fpn_strides[level]
                bbox_reg.append(bbox_pred)
            else:
                bbox_reg.append(torch.exp(bbox_pred))
        return logits, bbox_reg, centerness
class FCOSModule(torch.nn.Module):
    """
    Ties the FCOS head to its loss evaluator and box post-processors.

    Given backbone feature maps it produces the FCOS predictions and, in
    training, the losses. Only tested on FPN for now.
    """

    def __init__(self, cfg):
        super(FCOSModule, self).__init__()
        self.cfg = cfg
        self.head = FCOSHead(cfg)
        self.box_selector_train = make_fcos_postprocessor(cfg, is_train=True)
        self.box_selector_test = make_fcos_postprocessor(cfg, is_train=False)
        self.loss_evaluator = make_fcos_loss_evaluator(cfg)
        self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
        # The anchor generator is only created (and only used) when the
        # model is not running in RPN-only mode.
        if not cfg.MODEL.RPN_ONLY:
            self.anchor_generator = make_center_anchor_generator(cfg)

    def forward(self, images, features, targets=None):
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (list[Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the list
                corresponds to a different feature level
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per
                image (``None`` when training in RPN-only mode).
            losses (dict[Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        box_cls, box_regression, centerness = self.head(features)
        locations = self.compute_locations(features)

        # The training path needs both training mode and ground truth;
        # everything else goes through the inference path.
        if not (self.training and targets is not None):
            return self._forward_test(
                locations, box_cls, box_regression,
                centerness, images.image_sizes
            )
        return self._forward_train(
            locations, box_cls, box_regression,
            centerness, targets, images.image_sizes
        )

    def _forward_train(self, locations, box_cls, box_regression, centerness, targets, image_sizes=None):
        """Compute the three FCOS losses and, unless RPN-only, the proposals."""
        cls_loss, reg_loss, ctr_loss = self.loss_evaluator(
            locations, box_cls, box_regression, centerness, targets
        )
        losses = dict(
            loss_cls=cls_loss,
            loss_reg=reg_loss,
            loss_centerness=ctr_loss,
        )
        if self.cfg.MODEL.RPN_ONLY:
            return None, losses

        selected = self.box_selector_train(
            locations, box_cls, box_regression,
            centerness, image_sizes
        )
        proposals = self.anchor_generator(selected, image_sizes, centerness)
        return proposals, losses

    def _forward_test(self, locations, box_cls, box_regression, centerness, image_sizes):
        """Select boxes for inference; losses are always an empty dict."""
        boxes = self.box_selector_test(
            locations, box_cls, box_regression,
            centerness, image_sizes
        )
        if self.cfg.MODEL.RPN_ONLY:
            return boxes, {}
        return self.anchor_generator(boxes, image_sizes, centerness), {}

    def compute_locations(self, features):
        """Return one (h * w, 2) location tensor per feature level."""
        return [
            self.compute_locations_per_level(
                feature.size()[-2],
                feature.size()[-1],
                self.fpn_strides[level],
                feature.device
            )
            for level, feature in enumerate(features)
        ]

    def compute_locations_per_level(self, h, w, stride, device):
        """Return the (x, y) coordinates of every cell of an h-by-w grid.

        Cells are ``stride`` apart and offset by ``stride // 2``; rows are
        enumerated first, so x varies fastest in the flattened result.
        """
        xs = torch.arange(
            0, w * stride, step=stride,
            dtype=torch.float32, device=device
        )
        ys = torch.arange(
            0, h * stride, step=stride,
            dtype=torch.float32, device=device
        )
        grid_y, grid_x = torch.meshgrid(ys, xs)
        flat_xy = (grid_x.reshape(-1), grid_y.reshape(-1))
        return torch.stack(flat_xy, dim=1) + stride // 2
|