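# NOTE(review): the text "Spaces: Build error / Build error" preceded the
# source here; it looks like hosting-page status residue rather than module
# content -- confirm and drop.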
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
import torch.nn.functional as F
from torch import nn

from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from .loss import make_rpn_loss_evaluator
from .anchor_generator import make_anchor_generator
from .inference import make_rpn_postprocessor
class mRPNHead(nn.Module):
    """
    Minimal RPN head with classification and regression branches.

    Each input feature map is passed through a ReLU and then fed to two
    independent 1x1 convolutions: one producing ``num_anchors`` objectness
    channels and one producing ``num_anchors * 4`` box-regression channels.
    Unlike ``RPNHead`` there is no shared 3x3 convolution before the ReLU.

    NOTE(review): no ``registry.RPN_HEADS`` registration is visible on this
    class in this file view -- confirm it is registered if it is meant to be
    selectable through ``cfg.MODEL.RPN.RPN_HEAD``.
    """

    def __init__(self, cfg, in_channels, num_anchors):
        """
        Arguments:
            cfg              : config (accepted for interface compatibility;
                               not read by this head)
            in_channels (int): number of channels of the input feature
            num_anchors (int): number of anchors to be predicted
        """
        super().__init__()
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(
            in_channels, num_anchors * 4, kernel_size=1, stride=1
        )

        # Small-variance Gaussian weights and zero biases for both branches.
        for layer in (self.cls_logits, self.bbox_pred):
            torch.nn.init.normal_(layer.weight, std=0.01)
            torch.nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """
        Arguments:
            x (iterable[Tensor]): feature maps, one per feature level

        Returns:
            logits (list[Tensor]): ``cls_logits`` output for each level
            bbox_reg (list[Tensor]): ``bbox_pred`` output for each level
        """
        logits = []
        bbox_reg = []
        for feature in x:
            activated = F.relu(feature)
            logits.append(self.cls_logits(activated))
            bbox_reg.append(self.bbox_pred(activated))
        return logits, bbox_reg
class RPNHead(nn.Module):
    """
    Simple RPN head with classification and regression branches.

    Each input feature map goes through a shared 3x3 convolution (same
    number of channels in and out, padding 1) followed by a ReLU, and the
    result is fed to two 1x1 convolutions: one producing ``num_anchors``
    objectness channels and one producing ``num_anchors * 4`` box-regression
    channels.

    NOTE(review): no ``registry.RPN_HEADS`` registration is visible on this
    class in this file view -- confirm it is registered if it is meant to be
    selectable through ``cfg.MODEL.RPN.RPN_HEAD``.
    """

    def __init__(self, cfg, in_channels, num_anchors):
        """
        Arguments:
            cfg              : config (accepted for interface compatibility;
                               not read by this head)
            in_channels (int): number of channels of the input feature
            num_anchors (int): number of anchors to be predicted
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(
            in_channels, num_anchors * 4, kernel_size=1, stride=1
        )

        # Small-variance Gaussian weights and zero biases for all three convs.
        for layer in (self.conv, self.cls_logits, self.bbox_pred):
            torch.nn.init.normal_(layer.weight, std=0.01)
            torch.nn.init.constant_(layer.bias, 0)

    def forward(self, x):
        """
        Arguments:
            x (iterable[Tensor]): feature maps, one per feature level

        Returns:
            logits (list[Tensor]): ``cls_logits`` output for each level
            bbox_reg (list[Tensor]): ``bbox_pred`` output for each level
        """
        logits = []
        bbox_reg = []
        for feature in x:
            hidden = F.relu(self.conv(feature))
            logits.append(self.cls_logits(hidden))
            bbox_reg.append(self.bbox_pred(hidden))
        return logits, bbox_reg
class RPNModule(torch.nn.Module):
    """
    Module for RPN computation. Takes feature maps from the backbone and
    outputs RPN proposals and losses. Works for both FPN and non-FPN.
    """

    def __init__(self, cfg):
        """
        Arguments:
            cfg: config; a clone is stored on the module. The head class is
                looked up in ``registry.RPN_HEADS`` under
                ``cfg.MODEL.RPN.RPN_HEAD`` and built with
                ``cfg.MODEL.BACKBONE.OUT_CHANNELS`` input channels.
        """
        super().__init__()

        self.cfg = cfg.clone()

        anchor_generator = make_anchor_generator(cfg)

        in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
        rpn_head = registry.RPN_HEADS[cfg.MODEL.RPN.RPN_HEAD]
        # Only the first entry of num_anchors_per_location() is used to size
        # the head.
        head = rpn_head(
            cfg, in_channels, anchor_generator.num_anchors_per_location()[0]
        )

        # One box coder is shared by both post-processors and the loss.
        rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        box_selector_train = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=True)
        box_selector_test = make_rpn_postprocessor(cfg, rpn_box_coder, is_train=False)

        loss_evaluator = make_rpn_loss_evaluator(cfg, rpn_box_coder)

        self.anchor_generator = anchor_generator
        self.head = head
        self.box_selector_train = box_selector_train
        self.box_selector_test = box_selector_test
        self.loss_evaluator = loss_evaluator

    def forward(self, images, features, targets=None):
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (list[Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the list
                corresponds to a different feature level
            targets (list[BoxList]): ground-truth boxes present in the image (optional)

        Returns:
            boxes (list[BoxList]): the predicted boxes from the RPN, one BoxList per
                image.
            losses (dict[Tensor]): the losses for the model during training. During
                testing, it is an empty dict.
        """
        objectness, rpn_box_regression = self.head(features)
        anchors = self.anchor_generator(images, features)

        if self.training:
            return self._forward_train(anchors, objectness, rpn_box_regression, targets)
        return self._forward_test(anchors, objectness, rpn_box_regression)

    def _forward_train(self, anchors, objectness, rpn_box_regression, targets):
        """
        Training path: select boxes (unless RPN-only) and compute RPN losses.

        Returns:
            boxes: the anchors themselves when ``cfg.MODEL.RPN_ONLY`` is set,
                otherwise the output of ``box_selector_train``.
            losses (dict): ``"loss_objectness"`` and ``"loss_rpn_box_reg"``.
        """
        if self.cfg.MODEL.RPN_ONLY:
            # When training an RPN-only model, the loss is determined by the
            # predicted objectness and rpn_box_regression values and there is
            # no need to transform the anchors into predicted boxes; this is an
            # optimization that avoids the unnecessary transformation.
            boxes = anchors
        else:
            # For end-to-end models, anchors must be transformed into boxes and
            # sampled into a training batch. Box selection runs without
            # gradient tracking.
            with torch.no_grad():
                boxes = self.box_selector_train(
                    anchors, objectness, rpn_box_regression, targets
                )
        loss_objectness, loss_rpn_box_reg = self.loss_evaluator(
            anchors, objectness, rpn_box_regression, targets
        )
        losses = {
            "loss_objectness": loss_objectness,
            "loss_rpn_box_reg": loss_rpn_box_reg,
        }
        return boxes, losses

    def _forward_test(self, anchors, objectness, rpn_box_regression):
        """
        Inference path: run ``box_selector_test`` and return an empty loss dict.

        For end-to-end models the proposals are an intermediate state, so they
        are returned as produced. For RPN-only models the proposals are the
        final output and are reordered by decreasing "objectness" score.
        """
        boxes = self.box_selector_test(anchors, objectness, rpn_box_regression)
        if self.cfg.MODEL.RPN_ONLY:
            # RPN-only: return proposals in high-to-low confidence order.
            inds = [
                box.get_field("objectness").sort(descending=True)[1] for box in boxes
            ]
            boxes = [box[ind] for box, ind in zip(boxes, inds)]
        return boxes, {}