# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""
This file contains specific functions for computing losses on the RPN.
"""
import torch
from torch import nn
from torch.nn import functional as F

from ..balanced_positive_negative_sampler import BalancedPositiveNegativeSampler
from ..utils import cat, concat_box_prediction_layers

from maskrcnn_benchmark.layers import smooth_l1_loss
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.boxlist_ops import cat_boxlist
from maskrcnn_benchmark.layers import SigmoidFocalLoss, IOULoss, TokenSigmoidFocalLoss
from maskrcnn_benchmark.utils.comm import get_world_size, reduce_sum
from maskrcnn_benchmark.utils.amp import custom_fwd, custom_bwd
# The wildcard import below is expected to provide the shallow-contrastive helpers used
# further down (gather_tensors, pad_tensor_given_dim_length,
# pad_random_negative_tensor_given_length, convert_to_roi_format, normalized_positive_map)
# as well as the `dist` module used for the distributed checks.
from maskrcnn_benchmark.utils.shallow_contrastive_loss_helper import *

from transformers import AutoTokenizer

INF = 1e8


class RPNLossComputation(object):
    """
    This class computes the RPN loss.
    """

    def __init__(self, proposal_matcher, fg_bg_sampler, box_coder):
        """
        Arguments:
            proposal_matcher (Matcher)
            fg_bg_sampler (BalancedPositiveNegativeSampler)
            box_coder (BoxCoder)
        """
        # self.target_preparator = target_preparator
        self.proposal_matcher = proposal_matcher
        self.fg_bg_sampler = fg_bg_sampler
        self.box_coder = box_coder

    def match_targets_to_anchors(self, anchor, target):
        match_quality_matrix = boxlist_iou(target, anchor)
        matched_idxs = self.proposal_matcher(match_quality_matrix)
        # RPN doesn't need any fields from target
        # for creating the labels, so clear them all
        target = target.copy_with_fields([])
        # get the corresponding GT for each anchor
        # NB: need to clamp the indices because we can have a single
        # GT in the image, and matched_idxs can be -2, which goes
        # out of bounds
        if len(target):
            matched_targets = target[matched_idxs.clamp(min=0)]
        else:
            matched_targets = target
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets

    def prepare_targets(self, anchors, targets):
        labels = []
        regression_targets = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            matched_targets = self.match_targets_to_anchors(
                anchors_per_image, targets_per_image
            )

            matched_idxs = matched_targets.get_field("matched_idxs")
            labels_per_image = matched_idxs >= 0
            labels_per_image = labels_per_image.to(dtype=torch.float32)
            # discard anchors that go out of the boundaries of the image
            labels_per_image[~anchors_per_image.get_field("visibility")] = -1
            # discard indices that are between thresholds
            inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS
            labels_per_image[inds_to_discard] = -1

            # compute regression targets
            if not matched_targets.bbox.shape[0]:
                zeros = torch.zeros_like(labels_per_image)
                regression_targets_per_image = torch.stack((zeros, zeros, zeros, zeros), dim=1)
            else:
                regression_targets_per_image = self.box_coder.encode(
                    matched_targets.bbox, anchors_per_image.bbox
                )

            labels.append(labels_per_image)
            regression_targets.append(regression_targets_per_image)

        return labels, regression_targets

    def __call__(self, anchors, objectness, box_regression, targets):
        """
        Arguments:
            anchors (list[BoxList])
            objectness (list[Tensor])
            box_regression (list[Tensor])
            targets (list[BoxList])

        Returns:
            objectness_loss (Tensor)
            box_loss (Tensor)
        """
        anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
        labels, regression_targets = self.prepare_targets(anchors, targets)
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
        sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness_flattened = []
        box_regression_flattened = []
        # for each feature level, permute the outputs to make them be in the
        # same format as the labels. Note that the labels are computed for
        # all feature levels concatenated, so we keep the same representation
        # for the objectness and the box_regression
        for objectness_per_level, box_regression_per_level in zip(
            objectness, box_regression
        ):
            N, A, H, W = objectness_per_level.shape
            objectness_per_level = objectness_per_level.permute(0, 2, 3, 1).reshape(N, -1)
            box_regression_per_level = box_regression_per_level.view(N, -1, 4, H, W)
            box_regression_per_level = box_regression_per_level.permute(0, 3, 4, 1, 2)
            box_regression_per_level = box_regression_per_level.reshape(N, -1, 4)
            objectness_flattened.append(objectness_per_level)
            box_regression_flattened.append(box_regression_per_level)
        # concatenate on the first dimension (representing the feature levels), to
        # take into account the way the labels were generated (with all feature maps
        # being concatenated as well)
        objectness = cat(objectness_flattened, dim=1).reshape(-1)
        box_regression = cat(box_regression_flattened, dim=1).reshape(-1, 4)

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)
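
        # Both losses are averaged over the sampled anchors: the box loss is a smooth L1
        # over the sampled positives only, while the objectness loss is a binary
        # cross-entropy over all sampled (positive and negative) anchors.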
        box_loss = smooth_l1_loss(
            box_regression[sampled_pos_inds],
            regression_targets[sampled_pos_inds],
            beta=1.0 / 9,
            size_average=False,
        ) / (sampled_inds.numel())

        objectness_loss = F.binary_cross_entropy_with_logits(
            objectness[sampled_inds], labels[sampled_inds]
        )

        return objectness_loss, box_loss


class FocalLossComputation(object):
    """
    This class computes the RetinaNet loss.
    """

    def __init__(self, proposal_matcher, box_coder,
                 generate_labels_func,
                 sigmoid_focal_loss,
                 bbox_reg_beta=0.11,
                 regress_norm=1.0):
        """
        Arguments:
            proposal_matcher (Matcher)
            box_coder (BoxCoder)
        """
        self.proposal_matcher = proposal_matcher
        self.box_coder = box_coder
        self.box_cls_loss_func = sigmoid_focal_loss
        self.bbox_reg_beta = bbox_reg_beta
        self.copied_fields = ['labels']
        self.generate_labels_func = generate_labels_func
        self.discard_cases = ['between_thresholds']
        self.regress_norm = regress_norm

    def match_targets_to_anchors(self, anchor, target, copied_fields=[]):
        match_quality_matrix = boxlist_iou(target, anchor)
        matched_idxs = self.proposal_matcher(match_quality_matrix)
        # only copy the fields from target that are needed for creating the labels
        target = target.copy_with_fields(copied_fields)
        # get the corresponding GT for each anchor
        # NB: need to clamp the indices because we can have a single
        # GT in the image, and matched_idxs can be -2, which goes
        # out of bounds
        matched_targets = target[matched_idxs.clamp(min=0)]
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets

    def prepare_targets(self, anchors, targets):
        labels = []
        regression_targets = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            matched_targets = self.match_targets_to_anchors(
                anchors_per_image, targets_per_image, self.copied_fields
            )

            matched_idxs = matched_targets.get_field("matched_idxs")
            labels_per_image = self.generate_labels_func(matched_targets)
            labels_per_image = labels_per_image.to(dtype=torch.float32)

            # Background (negative examples)
            bg_indices = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
            labels_per_image[bg_indices] = 0

            # discard anchors that go out of the boundaries of the image
            if "not_visibility" in self.discard_cases:
                labels_per_image[~anchors_per_image.get_field("visibility")] = -1

            # discard indices that are between thresholds
            if "between_thresholds" in self.discard_cases:
                inds_to_discard = matched_idxs == Matcher.BETWEEN_THRESHOLDS
                labels_per_image[inds_to_discard] = -1

            # compute regression targets
            regression_targets_per_image = self.box_coder.encode(
                matched_targets.bbox, anchors_per_image.bbox
            )

            labels.append(labels_per_image)
            regression_targets.append(regression_targets_per_image)

        return labels, regression_targets

    def __call__(self, anchors, box_cls, box_regression, targets):
        """
        Arguments:
            anchors (list[BoxList])
            box_cls (list[Tensor])
            box_regression (list[Tensor])
            targets (list[BoxList])

        Returns:
            retinanet_cls_loss (Tensor)
            retinanet_regression_loss (Tensor)
        """
        anchors = [cat_boxlist(anchors_per_image) for anchors_per_image in anchors]
        labels, regression_targets = self.prepare_targets(anchors, targets)

        N = len(labels)
        box_cls, box_regression = \
            concat_box_prediction_layers(box_cls, box_regression)

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)
        pos_inds = torch.nonzero(labels > 0).squeeze(1)

        retinanet_regression_loss = smooth_l1_loss(
            box_regression[pos_inds],
            regression_targets[pos_inds],
            beta=self.bbox_reg_beta,
            size_average=False,
        ) / (max(1, pos_inds.numel() * self.regress_norm))
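
        # The focal classification loss below is normalized by the number of positive
        # anchors plus the number of images N, which also keeps the denominator non-zero
        # for batches that contain no positive anchors.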
        labels = labels.int()

        retinanet_cls_loss = self.box_cls_loss_func(
            box_cls,
            labels
        ) / (pos_inds.numel() + N)

        return retinanet_cls_loss, retinanet_regression_loss


class FCOSLossComputation(object):
    """
    This class computes the FCOS losses.
    """

    def __init__(self, cfg):
        self.cls_loss_func = SigmoidFocalLoss(
            cfg.MODEL.FOCAL.LOSS_GAMMA,
            cfg.MODEL.FOCAL.LOSS_ALPHA
        )
        self.fpn_strides = cfg.MODEL.FCOS.FPN_STRIDES
        self.center_sampling_radius = cfg.MODEL.FCOS.CENTER_SAMPLING_RADIUS
        self.iou_loss_type = cfg.MODEL.FCOS.IOU_LOSS_TYPE
        self.norm_reg_targets = cfg.MODEL.FCOS.NORM_REG_TARGETS
        self.use_gt_center = cfg.MODEL.FCOS.USE_GT_CENTER

        # we make use of IOU Loss for bounding boxes regression,
        # but we found that L1 in log scale can yield a similar performance
        self.box_reg_loss_func = IOULoss(self.iou_loss_type)
        self.centerness_loss_func = torch.nn.BCEWithLogitsLoss(reduction="sum")

    def get_sample_region(self, gt, strides, num_points_per, gt_xs, gt_ys, radius=1.0):
        '''
        This code is from
        https://github.com/yqyao/FCOS_PLUS/blob/0d20ba34ccc316650d8c30febb2eb40cb6eaae37/
        maskrcnn_benchmark/modeling/rpn/fcos/loss.py#L42
        '''
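        # Center sampling: for each FPN level, a location only counts as a positive
        # candidate if it falls within `radius * stride` of a ground-truth box center,
        # with that square region clipped to the box itself. The returned boolean mask
        # has shape (num_locations, num_gts).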
        num_gts = gt.shape[0]
        K = len(gt_xs)
        gt = gt[None].expand(K, num_gts, 4)
        center_x = (gt[..., 0] + gt[..., 2]) / 2
        center_y = (gt[..., 1] + gt[..., 3]) / 2
        center_gt = gt.new_zeros(gt.shape)
        # no gt
        if center_x[..., 0].sum() == 0:
            return gt_xs.new_zeros(gt_xs.shape, dtype=torch.uint8)
        beg = 0
        for level, n_p in enumerate(num_points_per):
            end = beg + n_p
            stride = strides[level] * radius
            xmin = center_x[beg:end] - stride
            ymin = center_y[beg:end] - stride
            xmax = center_x[beg:end] + stride
            ymax = center_y[beg:end] + stride
            # limit sample region in gt
            center_gt[beg:end, :, 0] = torch.where(
                xmin > gt[beg:end, :, 0], xmin, gt[beg:end, :, 0]
            )
            center_gt[beg:end, :, 1] = torch.where(
                ymin > gt[beg:end, :, 1], ymin, gt[beg:end, :, 1]
            )
            center_gt[beg:end, :, 2] = torch.where(
                xmax > gt[beg:end, :, 2],
                gt[beg:end, :, 2], xmax
            )
            center_gt[beg:end, :, 3] = torch.where(
                ymax > gt[beg:end, :, 3],
                gt[beg:end, :, 3], ymax
            )
            beg = end
        left = gt_xs[:, None] - center_gt[..., 0]
        right = center_gt[..., 2] - gt_xs[:, None]
        top = gt_ys[:, None] - center_gt[..., 1]
        bottom = center_gt[..., 3] - gt_ys[:, None]
        center_bbox = torch.stack((left, top, right, bottom), -1)
        inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
        return inside_gt_bbox_mask

    def prepare_targets(self, points, targets):
        object_sizes_of_interest = [
            [-1, 64],
            [64, 128],
            [128, 256],
            [256, 512],
            [512, INF],
        ]
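        # Each FPN level is only responsible for objects whose largest regression target
        # (the max of l, t, r, b) falls inside that level's size interval above; for the
        # other levels the location is treated as background.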
        expanded_object_sizes_of_interest = []
        for l, points_per_level in enumerate(points):
            object_sizes_of_interest_per_level = \
                points_per_level.new_tensor(object_sizes_of_interest[l])
            expanded_object_sizes_of_interest.append(
                object_sizes_of_interest_per_level[None].expand(len(points_per_level), -1)
            )

        expanded_object_sizes_of_interest = torch.cat(expanded_object_sizes_of_interest, dim=0)
        num_points_per_level = [len(points_per_level) for points_per_level in points]
        self.num_points_per_level = num_points_per_level
        points_all_level = torch.cat(points, dim=0)
        labels, reg_targets = self.compute_targets_for_locations(
            points_all_level, targets, expanded_object_sizes_of_interest
        )

        for i in range(len(labels)):
            labels[i] = torch.split(labels[i], num_points_per_level, dim=0)
            reg_targets[i] = torch.split(reg_targets[i], num_points_per_level, dim=0)

        labels_level_first = []
        reg_targets_level_first = []
        for level in range(len(points)):
            labels_level_first.append(
                torch.cat([labels_per_im[level] for labels_per_im in labels], dim=0)
            )

            reg_targets_per_level = torch.cat([
                reg_targets_per_im[level]
                for reg_targets_per_im in reg_targets
            ], dim=0)

            if self.norm_reg_targets:
                reg_targets_per_level = reg_targets_per_level / self.fpn_strides[level]
            reg_targets_level_first.append(reg_targets_per_level)

        return labels_level_first, reg_targets_level_first

    def compute_targets_for_locations(self, locations, targets, object_sizes_of_interest):
        labels = []
        reg_targets = []
        xs, ys = locations[:, 0], locations[:, 1]

        for im_i in range(len(targets)):
            targets_per_im = targets[im_i]
            assert targets_per_im.mode == "xyxy"
            if self.use_gt_center:
                center = targets_per_im.get_field("cbox")
                bboxes = center.bbox
                area = center.area()
            else:
                bboxes = targets_per_im.bbox
                area = targets_per_im.area()
            labels_per_im = targets_per_im.get_field("labels")

            l = xs[:, None] - bboxes[:, 0][None]
            t = ys[:, None] - bboxes[:, 1][None]
            r = bboxes[:, 2][None] - xs[:, None]
            b = bboxes[:, 3][None] - ys[:, None]
            reg_targets_per_im = torch.stack([l, t, r, b], dim=2)
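            # (l, t, r, b) are the distances from each location to the four sides of each
            # GT box; a location lies inside a box iff all four distances are positive.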
            if self.center_sampling_radius > 0:
                is_in_boxes = self.get_sample_region(
                    bboxes,
                    self.fpn_strides,
                    self.num_points_per_level,
                    xs, ys,
                    radius=self.center_sampling_radius
                )
            else:
                # no center sampling, use all the locations within a ground-truth box
                is_in_boxes = reg_targets_per_im.min(dim=2)[0] > 0

            max_reg_targets_per_im = reg_targets_per_im.max(dim=2)[0]
            # limit the regression range for each location
            is_cared_in_the_level = \
                (max_reg_targets_per_im >= object_sizes_of_interest[:, [0]]) & \
                (max_reg_targets_per_im <= object_sizes_of_interest[:, [1]])

            locations_to_gt_area = area[None].repeat(len(locations), 1)
            locations_to_gt_area[is_in_boxes == 0] = INF
            locations_to_gt_area[is_cared_in_the_level == 0] = INF

            # if there is still more than one object for a location,
            # we choose the one with minimal area
            locations_to_min_area, locations_to_gt_inds = locations_to_gt_area.min(dim=1)

            reg_targets_per_im = reg_targets_per_im[range(len(locations)), locations_to_gt_inds]
            labels_per_im = labels_per_im[locations_to_gt_inds]
            labels_per_im[locations_to_min_area == INF] = 0

            labels.append(labels_per_im)
            reg_targets.append(reg_targets_per_im)

        return labels, reg_targets

    def compute_centerness_targets(self, reg_targets):
        left_right = reg_targets[:, [0, 2]]
        top_bottom = reg_targets[:, [1, 3]]
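        # centerness = sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b))):
        # it equals 1 at the exact center of a box and decays towards 0 near the borders,
        # down-weighting predictions from off-center locations.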
        centerness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
                     (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
        return torch.sqrt(centerness)

    def __call__(self, locations, box_cls, box_regression, centerness, targets):
        """
        Arguments:
            locations (list[Tensor])
            box_cls (list[Tensor])
            box_regression (list[Tensor])
            centerness (list[Tensor])
            targets (list[BoxList])

        Returns:
            cls_loss (Tensor)
            reg_loss (Tensor)
            centerness_loss (Tensor)
        """
        N = box_cls[0].size(0)
        num_classes = box_cls[0].size(1)
        labels, reg_targets = self.prepare_targets(locations, targets)

        box_cls_flatten = []
        box_regression_flatten = []
        centerness_flatten = []
        labels_flatten = []
        reg_targets_flatten = []
        for l in range(len(labels)):
            box_cls_flatten.append(box_cls[l].permute(0, 2, 3, 1).reshape(-1, num_classes))
            box_regression_flatten.append(box_regression[l].permute(0, 2, 3, 1).reshape(-1, 4))
            labels_flatten.append(labels[l].reshape(-1))
            reg_targets_flatten.append(reg_targets[l].reshape(-1, 4))
            centerness_flatten.append(centerness[l].reshape(-1))

        box_cls_flatten = torch.cat(box_cls_flatten, dim=0)
        box_regression_flatten = torch.cat(box_regression_flatten, dim=0)
        centerness_flatten = torch.cat(centerness_flatten, dim=0)
        labels_flatten = torch.cat(labels_flatten, dim=0)
        reg_targets_flatten = torch.cat(reg_targets_flatten, dim=0)

        pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)
        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]
        centerness_flatten = centerness_flatten[pos_inds]

        cls_loss = self.cls_loss_func(
            box_cls_flatten,
            labels_flatten.int()
        ) / max(pos_inds.numel(), 1.0)

        if pos_inds.numel() > 0:
            centerness_targets = self.compute_centerness_targets(reg_targets_flatten)
            reg_loss = self.box_reg_loss_func(
                box_regression_flatten,
                reg_targets_flatten,
                centerness_targets
            ) / centerness_targets.sum()
            centerness_loss = self.centerness_loss_func(
                centerness_flatten,
                centerness_targets
            ) / max(pos_inds.numel(), 1.0)
        else:
            reg_loss = box_regression_flatten.sum()
            centerness_loss = centerness_flatten.sum()

        return cls_loss, reg_loss, centerness_loss


# class ATSSLossComputation(object):
class ATSSLossComputation(torch.nn.Module):

    def __init__(self, cfg, box_coder):
        super(ATSSLossComputation, self).__init__()

        self.cfg = cfg
        self.cls_loss_func = SigmoidFocalLoss(cfg.MODEL.FOCAL.LOSS_GAMMA, cfg.MODEL.FOCAL.LOSS_ALPHA)
        self.centerness_loss_func = torch.nn.BCEWithLogitsLoss(reduction="sum")
        self.matcher = Matcher(cfg.MODEL.FOCAL.FG_IOU_THRESHOLD, cfg.MODEL.FOCAL.BG_IOU_THRESHOLD, True)
        self.box_coder = box_coder

        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS or self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
            self.token_loss_func = TokenSigmoidFocalLoss(cfg.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_ALPHA,
                                                         cfg.MODEL.DYHEAD.FUSE_CONFIG.TOKEN_GAMMA)

        self.lang = cfg.MODEL.LANGUAGE_BACKBONE.MODEL_TYPE
        # self.tokenizer = AutoTokenizer.from_pretrained(self.lang)
        if self.cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
            from transformers import CLIPTokenizerFast
            # self.tokenizer = build_tokenizer(self.cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE)
            if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
                print("Reuse token 'ðŁĴij</w>' (token_id = 49404) for mask token!")
                self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
                                                                   from_slow=True, mask_token='ðŁĴij</w>')
            else:
                self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
                                                                   from_slow=True)
        else:
            self.tokenizer = AutoTokenizer.from_pretrained(self.lang)

        # if using the shallow contrastive loss
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS \
                or self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
            if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS:
                assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS == False
                channels = cfg.MODEL.DYHEAD.CHANNELS
                num_anchors = len(cfg.MODEL.RPN.ASPECT_RATIOS) * cfg.MODEL.RPN.SCALES_PER_OCTAVE
                shallow_input_dim = channels * num_anchors
            elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
                assert self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS == False
                shallow_input_dim = cfg.MODEL.SWINT.OUT_CHANNELS[-2]

            shallow_log_scale = self.cfg.MODEL.DYHEAD.SHALLOW_LOG_SCALE
            shallow_contrastive_hdim = cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_CONTRASTIVE_HIDDEN_DIM
            # self.shallow_contrastive_projection_image = nn.Conv2d(channels, num_anchors * shallow_contrastive_hdim,
            #                                                       kernel_size=1)
            self.shallow_contrastive_projection_image = nn.Linear(shallow_input_dim, shallow_contrastive_hdim,
                                                                  bias=True)
            self.shallow_contrastive_projection_text = nn.Linear(self.cfg.MODEL.LANGUAGE_BACKBONE.LANG_DIM,
                                                                 shallow_contrastive_hdim, bias=True)
            self.shallow_log_scale = nn.Parameter(torch.Tensor([shallow_log_scale]), requires_grad=True)

        # (initialization) if using the shallow contrastive loss
        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS:
            for modules in [self.shallow_contrastive_projection_image, self.shallow_contrastive_projection_text]:
                for l in modules.modules():
                    if isinstance(l, nn.Conv2d):
                        torch.nn.init.normal_(l.weight, std=0.01)
                        torch.nn.init.constant_(l.bias, 0)
                    if isinstance(l, nn.Linear):
                        torch.nn.init.xavier_uniform_(l.weight)
                        l.bias.data.fill_(0)

    def NllSoftMaxLoss(self, logits, target):
        # only the positives with a positive target similarity contribute to the loss
        loss_ce = -target * logits.log_softmax(-1)
        return loss_ce

    def ContrastiveAlignLoss(self, logits, positive_map):
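        # Symmetric alignment loss over box-token similarity logits: for every box with at
        # least one positive token, the (averaged) positive logits are contrasted against a
        # logsumexp over all tokens, and the same is done in the token-to-box direction; the
        # two directions are then averaged. This follows the structure of MDETR-style
        # contrastive alignment losses.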
        positive_logits = -logits.masked_fill(~positive_map, 0)
        negative_logits = logits  # .masked_fill(positive_map, -1000000)

        boxes_with_pos = positive_map.any(2)
        pos_term = positive_logits.sum(2)
        neg_term = negative_logits.logsumexp(2)

        nb_pos = positive_map.sum(2) + 1e-6

        box_to_token_loss = ((pos_term / nb_pos + neg_term)).masked_fill(~boxes_with_pos, 0).sum()

        tokens_with_pos = positive_map.any(1)
        pos_term = positive_logits.sum(1)
        neg_term = negative_logits.logsumexp(1)

        nb_pos = positive_map.sum(1) + 1e-6

        tokens_to_boxes_loss = ((pos_term / nb_pos + neg_term)).masked_fill(~tokens_with_pos, 0).sum()
        tot_loss = (box_to_token_loss + tokens_to_boxes_loss) / 2

        return tot_loss

    def GIoULoss(self, pred, target, anchor, weight=None):
        pred_boxes = self.box_coder.decode(pred.view(-1, 4), anchor.view(-1, 4))
        pred_x1 = pred_boxes[:, 0]
        pred_y1 = pred_boxes[:, 1]
        pred_x2 = pred_boxes[:, 2]
        pred_y2 = pred_boxes[:, 3]
        pred_x2 = torch.max(pred_x1, pred_x2)
        pred_y2 = torch.max(pred_y1, pred_y2)
        pred_area = (pred_x2 - pred_x1) * (pred_y2 - pred_y1)

        gt_boxes = self.box_coder.decode(target.view(-1, 4), anchor.view(-1, 4))
        target_x1 = gt_boxes[:, 0]
        target_y1 = gt_boxes[:, 1]
        target_x2 = gt_boxes[:, 2]
        target_y2 = gt_boxes[:, 3]
        target_area = (target_x2 - target_x1) * (target_y2 - target_y1)

        x1_intersect = torch.max(pred_x1, target_x1)
        y1_intersect = torch.max(pred_y1, target_y1)
        x2_intersect = torch.min(pred_x2, target_x2)
        y2_intersect = torch.min(pred_y2, target_y2)
        area_intersect = torch.zeros(pred_x1.size()).to(pred)
        mask = (y2_intersect > y1_intersect) * (x2_intersect > x1_intersect)
        area_intersect[mask] = (x2_intersect[mask] - x1_intersect[mask]) * (y2_intersect[mask] - y1_intersect[mask])

        x1_enclosing = torch.min(pred_x1, target_x1)
        y1_enclosing = torch.min(pred_y1, target_y1)
        x2_enclosing = torch.max(pred_x2, target_x2)
        y2_enclosing = torch.max(pred_y2, target_y2)
        area_enclosing = (x2_enclosing - x1_enclosing) * (y2_enclosing - y1_enclosing) + 1e-7
        area_union = pred_area + target_area - area_intersect + 1e-7
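        # Generalized IoU: GIoU = IoU - (area_enclosing - area_union) / area_enclosing, so
        # the loss 1 - GIoU still provides a useful signal when the boxes do not overlap.
        # When centerness weights are supplied, the per-box losses are weighted by them.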
        ious = area_intersect / area_union
        gious = ious - (area_enclosing - area_union) / area_enclosing
        losses = 1 - gious

        if weight is not None and weight.sum() > 0:
            return (losses * weight).sum()
        else:
            assert losses.numel() != 0
            return losses.sum()

    def prepare_targets(self, targets, anchors, tokenized=None, positive_map=None, proj_tokens=None):
        cls_labels = []
        reg_targets = []
        token_labels = []
        map_labels = []

        gold_box_od_labels = []
        od_label_of_tokens_labels = []
        positive_indices = []

        offset = 0
        for im_i in range(len(targets)):
            targets_per_im = targets[im_i]
            assert targets_per_im.mode == "xyxy"
            # bboxes_per_im = targets_per_im.get_field("boxes")
            bboxes_per_im = targets_per_im.bbox
            labels_per_im = targets_per_im.get_field("labels")
            num_gt = len(bboxes_per_im)

            if positive_map is not None:
                token_per_im = positive_map[offset:offset + num_gt, :]
                offset += num_gt
                # Recheck if the label matches with the positive map
                # print(labels_per_im)
                # print(token_per_im.nonzero())

            # shallow contrastive
            if "original_od_label" in targets_per_im.fields():
                gold_box_od_label = targets_per_im.get_field("original_od_label")
            if "positive_map_for_od_labels" in targets_per_im.fields():
                od_label_of_token_per_im = targets_per_im.get_field("positive_map_for_od_labels")
                # print(gold_box_od_label)
                # print(od_label_of_token_per_im)

            if positive_map is not None and proj_tokens is not None:
                if "tokens_positive" in targets_per_im.fields():
                    cur_tokens = targets_per_im.get_field("tokens_positive")
                else:
                    cur_tokens = targets_per_im.get_field("tokens")
                map = torch.zeros((len(cur_tokens), proj_tokens.shape[1]), dtype=torch.bool)
                for j, tok_list in enumerate(cur_tokens):
                    for (beg, end) in tok_list:
                        beg_pos = tokenized.char_to_token(im_i, beg)
                        end_pos = tokenized.char_to_token(im_i, end - 1)
                        if beg_pos is None:
                            try:
                                beg_pos = tokenized.char_to_token(im_i, beg + 1)
                                if beg_pos is None:
                                    beg_pos = tokenized.char_to_token(im_i, beg + 2)
                            except:
                                beg_pos = None
                        if end_pos is None:
                            try:
                                end_pos = tokenized.char_to_token(im_i, end - 2)
                                if end_pos is None:
                                    end_pos = tokenized.char_to_token(im_i, end - 3)
                            except:
                                end_pos = None
                        if beg_pos is None or end_pos is None:
                            continue
                        assert beg_pos is not None and end_pos is not None
                        map[j, beg_pos: end_pos + 1].fill_(True)

            anchors_per_im = cat_boxlist(anchors[im_i])

            num_anchors_per_loc = len(self.cfg.MODEL.RPN.ASPECT_RATIOS) * self.cfg.MODEL.RPN.SCALES_PER_OCTAVE
            num_anchors_per_level = [len(anchors_per_level.bbox) for anchors_per_level in anchors[im_i]]

            ious = boxlist_iou(anchors_per_im, targets_per_im)

            gt_cx = (bboxes_per_im[:, 2] + bboxes_per_im[:, 0]) / 2.0
            gt_cy = (bboxes_per_im[:, 3] + bboxes_per_im[:, 1]) / 2.0
            gt_points = torch.stack((gt_cx, gt_cy), dim=1)

            anchors_cx_per_im = (anchors_per_im.bbox[:, 2] + anchors_per_im.bbox[:, 0]) / 2.0
            anchors_cy_per_im = (anchors_per_im.bbox[:, 3] + anchors_per_im.bbox[:, 1]) / 2.0
            anchor_points = torch.stack((anchors_cx_per_im, anchors_cy_per_im), dim=1)

            distances = (anchor_points[:, None, :] - gt_points[None, :, :]).pow(2).sum(-1).sqrt()

            # Selecting candidates based on the center distance between anchor box and object
            candidate_idxs = []
            star_idx = 0
            for level, anchors_per_level in enumerate(anchors[im_i]):
                end_idx = star_idx + num_anchors_per_level[level]
                distances_per_level = distances[star_idx:end_idx, :]
                topk = min(self.cfg.MODEL.ATSS.TOPK * num_anchors_per_loc, num_anchors_per_level[level])
                _, topk_idxs_per_level = distances_per_level.topk(topk, dim=0, largest=False)
                candidate_idxs.append(topk_idxs_per_level + star_idx)
                star_idx = end_idx
            candidate_idxs = torch.cat(candidate_idxs, dim=0)

            # Using the sum of mean and standard deviation as the IoU threshold to select final positive samples
            candidate_ious = ious[candidate_idxs, torch.arange(num_gt)]
            iou_mean_per_gt = candidate_ious.mean(0)
            iou_std_per_gt = candidate_ious.std(0)
            iou_thresh_per_gt = iou_mean_per_gt + iou_std_per_gt
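            # ATSS assignment: for each GT, the top-k closest anchors per FPN level are the
            # candidates, and the positive threshold is the mean plus the standard deviation
            # of their IoUs, so every object gets its own adaptive threshold instead of a
            # single global one.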
            is_pos = candidate_ious >= iou_thresh_per_gt[None, :]

            # Limiting the final positive samples' centers to lie inside the object
            anchor_num = anchors_cx_per_im.shape[0]
            for ng in range(num_gt):
                candidate_idxs[:, ng] += ng * anchor_num
            e_anchors_cx = anchors_cx_per_im.view(1, -1).expand(num_gt, anchor_num).contiguous().view(-1)
            e_anchors_cy = anchors_cy_per_im.view(1, -1).expand(num_gt, anchor_num).contiguous().view(-1)
            candidate_idxs = candidate_idxs.view(-1)
            l = e_anchors_cx[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 0]
            t = e_anchors_cy[candidate_idxs].view(-1, num_gt) - bboxes_per_im[:, 1]
            r = bboxes_per_im[:, 2] - e_anchors_cx[candidate_idxs].view(-1, num_gt)
            b = bboxes_per_im[:, 3] - e_anchors_cy[candidate_idxs].view(-1, num_gt)
            is_in_gts = torch.stack([l, t, r, b], dim=1).min(dim=1)[0] > 0.01
            is_pos = is_pos & is_in_gts

            # if an anchor box is assigned to multiple gts, the one with the highest IoU will be selected.
            ious_inf = torch.full_like(ious, -INF).t().contiguous().view(-1)
            index = candidate_idxs.view(-1)[is_pos.view(-1)]
            ious_inf[index] = ious.t().contiguous().view(-1)[index]
            ious_inf = ious_inf.view(num_gt, -1).t()

            anchors_to_gt_values, anchors_to_gt_indexs = ious_inf.max(dim=1)

            # get positive anchors index from ATSS
            positive_index = [i[0].item() for i in torch.nonzero(anchors_to_gt_indexs)]

            cls_labels_per_im = labels_per_im[anchors_to_gt_indexs]
            cls_labels_per_im[anchors_to_gt_values == -INF] = 0

            if positive_map is not None:
                token_labels_per_im = token_per_im[anchors_to_gt_indexs]
                unmatched_labels = torch.zeros(token_labels_per_im.shape[1], device=token_labels_per_im.device)
                # TODO: temporarily disable the [NoObj] token logic and only restrict to binary loss
                unmatched_labels[-1] = 1  # anchors with no object are mapped to the last token position
                token_labels_per_im[anchors_to_gt_values == -INF] = unmatched_labels
                # move from cpu to gpu
                token_labels_per_im = token_labels_per_im.to(cls_labels_per_im.device)
                # print(token_labels_per_im[anchors_to_gt_values == -INF].shape)
                # print(cls_labels_per_im[anchors_to_gt_values != -INF][0])
                # print(token_labels_per_im[anchors_to_gt_values != -INF][0].nonzero())

            if positive_map is not None and proj_tokens is not None:
                map_labels_per_im = map[anchors_to_gt_indexs]
                unmatched_labels = torch.zeros(map_labels_per_im.shape[1], dtype=torch.bool,
                                               device=map_labels_per_im.device)  # unmatched anchors map to no token
                map_labels_per_im[anchors_to_gt_values == -INF] = unmatched_labels
                # move from cpu to gpu
                map_labels_per_im = map_labels_per_im.to(cls_labels_per_im.device)
                # print(map_labels_per_im[anchors_to_gt_values == -INF].shape)
                # print(map_labels_per_im[anchors_to_gt_values != -INF][0])

            if positive_map is not None and proj_tokens is not None:
                gold_box_od_label_per_im = gold_box_od_label[anchors_to_gt_indexs]
                gold_box_od_label_per_im[anchors_to_gt_values == -INF] = -100
                # move from cpu to gpu
                gold_box_od_label_per_im = gold_box_od_label_per_im.to(cls_labels_per_im.device)
                # print(gold_box_od_label_per_im[anchors_to_gt_values != -INF])

            matched_gts = bboxes_per_im[anchors_to_gt_indexs]

            reg_targets_per_im = self.box_coder.encode(matched_gts, anchors_per_im.bbox)
            cls_labels.append(cls_labels_per_im)
            reg_targets.append(reg_targets_per_im)
            if positive_map is not None:
                token_labels.append(token_labels_per_im)
            if positive_map is not None and proj_tokens is not None:
                map_labels.append(map_labels_per_im)
                gold_box_od_labels.append(gold_box_od_label_per_im)
                od_label_of_tokens_labels.append(od_label_of_token_per_im)
                positive_indices.append(positive_index)
        # print([len(x) for x in positive_indices])

        return cls_labels, reg_targets, token_labels, map_labels, gold_box_od_labels, od_label_of_tokens_labels, positive_indices

    def compute_centerness_targets(self, reg_targets, anchors):
        gts = self.box_coder.decode(reg_targets, anchors)
        anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
        anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
        l = anchors_cx - gts[:, 0]
        t = anchors_cy - gts[:, 1]
        r = gts[:, 2] - anchors_cx
        b = gts[:, 3] - anchors_cy
        left_right = torch.stack([l, r], dim=1)
        top_bottom = torch.stack([t, b], dim=1)
        centerness = torch.sqrt((left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
                                (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
        assert not torch.isnan(centerness).any()
        return centerness

    def __call__(self, box_cls, box_regression, centerness, targets, anchors,
                 captions=None,
                 positive_map=None,
                 token_logits=None,
                 proj_tokens=None,
                 contrastive_logits=None,
                 dot_product_logits=None,
                 text_masks=None,
                 shallow_img_emb_feats=None
                 ):

        tokenized = None
        if captions is not None:
            # tokenized = self.tokenizer.batch_encode_plus(captions, padding="longest", return_tensors="pt")
            if self.cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
                tokenized = self.tokenizer.batch_encode_plus(captions,
                                                             max_length=self.cfg.MODEL.LANGUAGE_BACKBONE.MAX_QUERY_LEN,
                                                             padding='max_length' if self.cfg.MODEL.LANGUAGE_BACKBONE.PAD_MAX else "longest",
                                                             return_tensors='pt',
                                                             truncation=True)
            else:
                tokenized = self.tokenizer.batch_encode_plus(captions, padding="longest", return_tensors="pt")

        labels, reg_targets, token_labels, map_labels, gold_box_od_labels, od_label_of_tokens_labels, positive_indices = \
            self.prepare_targets(targets, anchors,
                                 tokenized,
                                 positive_map,
                                 proj_tokens
                                 )

        N = len(labels)

        box_regression_flatten, box_cls_flatten, token_logits_stacked = concat_box_prediction_layers(
            box_regression,
            box_cls,
            token_logits,
        )

        # contrastive logits
        if positive_map is not None and contrastive_logits is not None:
            contrastive_logits = torch.cat(contrastive_logits, dim=1)

        # dot product soft token logits
        if dot_product_logits is not None:
            dot_product_logits = torch.cat(dot_product_logits, dim=1)

        centerness_flatten = [ct.permute(0, 2, 3, 1).reshape(N, -1, 1) for ct in centerness]
        centerness_flatten = torch.cat(centerness_flatten, dim=1).reshape(-1)

        labels_flatten = torch.cat(labels, dim=0)
        reg_targets_flatten = torch.cat(reg_targets, dim=0)
        anchors_flatten = torch.cat([cat_boxlist(anchors_per_image).bbox for anchors_per_image in anchors], dim=0)

        if positive_map is not None:
            token_labels_stacked = torch.stack(token_labels, dim=0)

        if positive_map is not None and proj_tokens is not None:
            positive_map_box_to_self_text = None
            shallow_positive_map = None
            bs = proj_tokens.shape[0]
            device = proj_tokens.device

            # NOTE: 0. setup env
            if dist.is_dist_avail_and_initialized():
                world_size = dist.get_world_size()
                rank = torch.distributed.get_rank()
            else:
                world_size = 1
                rank = 0
            if contrastive_logits is not None:
                positive_map_box_to_self_text = torch.stack(map_labels, dim=0)

            if shallow_img_emb_feats is not None:
                '''
                Goal: build the positive map F between every selected (predicted) box and
                every text token across the whole distributed batch:
                    F: (N*B*max_anchor_num) x (N*B*T)
                where
                    X: B x max_anchor_num   od label of each selected box, e.g. [0, 20, 30, ...]
                    Y: N*B*T                od label of every token
                    F[i, j] = (X[i] == Y[j])
                '''
                with torch.no_grad():
                    # NOTE: 1. get X (predicted_box_od_label), the detection label of every selected box
                    # predicted_box_od_label: B x A
                    # check memory limitation: prevent the number of positives from exceeding the maximum
                    new_positive_indices = []
                    # print([len(positive_index) for positive_index in positive_indices])
                    for positive_index in positive_indices:
                        if len(positive_index) >= self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_MAX_POSITIVE_ANCHORS:
                            import random
                            positive_index = sorted(random.sample(positive_index,
                                                                  self.cfg.MODEL.DYHEAD.FUSE_CONFIG.SHALLOW_MAX_POSITIVE_ANCHORS))
                        new_positive_indices.append(positive_index)
                    # print([len(positive_index) for positive_index in positive_indices])

                    max_len = max([len(positive_index) for positive_index in new_positive_indices])
                    max_anchor_num = max_len
                    if world_size > 1:
                        num_anchors = torch.tensor(max_len, device=positive_map.device)
                        num_anchors_full = [torch.zeros_like(num_anchors) for _ in range(world_size)]
                        torch.distributed.all_gather(num_anchors_full, num_anchors)
                        max_anchor_num = max([anchor.item() for anchor in num_anchors_full])

                    new_negative_pad_indices = []
                    # if not PAD_ZEROS, select random negative paddings
                    if not self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_ZERO_PADS:
                        for (positive_index, old_positive_index) in zip(new_positive_indices, positive_indices):
                            negative_index = [i for i in range(len(cat_boxlist(anchors[0]))) if i not in old_positive_index]
                            import random
                            negative_pad_index = sorted(random.sample(negative_index,
                                                                      max_anchor_num - len(positive_index)))
                            new_negative_pad_indices.append(negative_pad_index)

                    predicted_box_od_label = []
                    for i in range(bs):
                        predicted_box_od_label.append(
                            pad_tensor_given_dim_length(gold_box_od_labels[i][new_positive_indices[i]],
                                                        dim=0,
                                                        length=max_anchor_num,
                                                        padding_value=-100,
                                                        batch_first=False
                                                        ))
                    predicted_box_od_label = torch.stack(predicted_box_od_label, dim=0)

                    # if padding, we need image masks to filter out the paddings
                    image_masks = None
                    if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_ZERO_PADS:
                        image_masks = torch.zeros((bs, max_anchor_num), dtype=torch.long).to(text_masks.device)
                        for i in range(bs):
                            image_masks[i, :len(new_positive_indices[i])] = 1

                    # NOTE: 2. get Y (od_label_of_tokens)
                    # od_label_of_tokens: N x B x T
                    od_label_of_tokens = torch.stack(od_label_of_tokens_labels, dim=0).long()
                    od_label_of_tokens = gather_tensors(od_label_of_tokens)

                    # NOTE: 3. get F
                    # F: B*A x N*B*T
                    mapping_predicted_box_to_all_text = (predicted_box_od_label.view(-1).unsqueeze(1) ==
                                                         od_label_of_tokens.view(-1).unsqueeze(0))

                    # NOTE: 4. we also need the mapping between each predicted box and its own text
                    # positive_map_box_to_self_text: B x A x T, kept for the vanilla contrastive alignment loss
                    positive_map_box_to_self_text = []
                    for i in range(bs):
                        positive_map_box_to_self_text.append(
                            pad_tensor_given_dim_length(map_labels[i][new_positive_indices[i]],
                                                        dim=0,
                                                        length=max_anchor_num,
                                                        padding_value=False,
                                                        batch_first=False
                                                        ))
                    positive_map_box_to_self_text = torch.stack(positive_map_box_to_self_text, dim=0)

                    # fill in the block corresponding to our own batch
                    for i in range(bs):
                        mapping_predicted_box_to_all_text[
                            i * max_anchor_num: (i + 1) * max_anchor_num,
                            (rank * bs + i) * 256: (rank * bs + i + 1) * 256] = positive_map_box_to_self_text[i]

                    # NOTE: 5. communicate and get the full positive map
                    # mapping_predicted_box_to_all_text: N*B*A x N*B*T
                    mapping_predicted_box_to_all_text = gather_tensors(mapping_predicted_box_to_all_text).view(
                        -1, mapping_predicted_box_to_all_text.size(-1))

                    shallow_positive_map = mapping_predicted_box_to_all_text  # this is the true positive map
                    shallow_positive_map = shallow_positive_map.unsqueeze(0)

                    # get text attention masks
                    text_attention_mask = torch.zeros((bs, 256), dtype=torch.long)  # B x 256
                    for i in range(bs):
                        text_attention_mask[i, :len(text_masks[i])] = text_masks[i]
                    text_attention_mask = gather_tensors(text_attention_mask.bool().to(device))  # N x B x 256

                    # if PAD_ZEROS, get image masks
                    if image_masks is not None:
                        image_attention_mask = torch.zeros((bs, max_anchor_num), dtype=torch.long)  # B x max_anchor
                        for i in range(bs):
                            image_attention_mask[i, :len(image_masks[i])] = image_masks[i]
                        image_attention_mask = gather_tensors(image_attention_mask.bool().to(device))  # N x B x max_anchor
                # NOTE: 6. calculate shallow contrastive logits
                shallow_proj_tokens = F.normalize(self.shallow_contrastive_projection_text(proj_tokens), p=2, dim=-1)

                shallow_normalized_img_embs = []
                if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
                    # choice 1: use features from the SWINT backbone layer (c4) before VL fusion
                    from maskrcnn_benchmark.layers.roi_align import ROIAlignV2
                    pooler = ROIAlignV2((1, 1), 1. / 16, 0)
                    # get positive features
                    for i in range(bs):
                        rois = convert_to_roi_format(cat_boxlist(anchors[i])[new_positive_indices[i]])
                        roi_feature = pooler(shallow_img_emb_feats[i].unsqueeze(0), rois)
                        roi_feature = roi_feature.squeeze(-1).squeeze(-1)
                        shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(roi_feature)
                        shallow_normalized_img_emb = F.normalize(shallow_contrastive_proj_queries, p=2, dim=-1)
                        if image_masks is not None:
                            # pad zeros
                            shallow_normalized_img_embs.append(
                                pad_tensor_given_dim_length(shallow_normalized_img_emb,
                                                            dim=0,
                                                            length=max_anchor_num,
                                                            padding_value=0.0,
                                                            batch_first=False
                                                            ))
                        else:
                            # pad negatives
                            negative_rois = convert_to_roi_format(cat_boxlist(anchors[i])[new_negative_pad_indices[i]])
                            negative_roi_feature = pooler(shallow_img_emb_feats[i].unsqueeze(0), negative_rois)
                            negative_roi_feature = negative_roi_feature.squeeze(-1).squeeze(-1)
                            negative_shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(negative_roi_feature)
                            negative_shallow_normalized_img_emb = F.normalize(negative_shallow_contrastive_proj_queries,
                                                                              p=2, dim=-1)
                            shallow_normalized_img_embs.append(
                                pad_random_negative_tensor_given_length(shallow_normalized_img_emb,
                                                                        negative_shallow_normalized_img_emb,
                                                                        length=max_anchor_num
                                                                        )
                            )
                elif self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS:
                    # choice 2: use features after the FPN
                    shallow_img_embs = torch.cat(shallow_img_emb_feats, dim=1)
                    # get positive features
                    for i in range(bs):
                        shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(
                            shallow_img_embs[i, new_positive_indices[i], :])
                        shallow_normalized_img_emb = F.normalize(shallow_contrastive_proj_queries, p=2, dim=-1)
                        if image_masks is not None:
                            # pad zeros
                            shallow_normalized_img_embs.append(
                                pad_tensor_given_dim_length(shallow_normalized_img_emb,
                                                            dim=0,
                                                            length=max_anchor_num,
                                                            padding_value=0.0,
                                                            batch_first=False
                                                            ))
                        else:
                            # pad negatives
                            negative_shallow_contrastive_proj_queries = self.shallow_contrastive_projection_image(
                                shallow_img_embs[i, new_negative_pad_indices[i], :])
                            negative_shallow_normalized_img_emb = F.normalize(negative_shallow_contrastive_proj_queries,
                                                                              p=2, dim=-1)
                            shallow_normalized_img_embs.append(
                                pad_random_negative_tensor_given_length(shallow_normalized_img_emb,
                                                                        negative_shallow_normalized_img_emb,
                                                                        length=max_anchor_num
                                                                        )
                            )

                shallow_normalized_img_embs = torch.stack(shallow_normalized_img_embs, dim=0)
                shallow_normalized_text_emb = shallow_proj_tokens
                shallow_normalized_text_emb = pad_tensor_given_dim_length(shallow_normalized_text_emb,
                                                                          dim=1,
                                                                          length=256,
                                                                          padding_value=0.0)

                gathered_shallow_normalized_img_emb = gather_tensors(shallow_normalized_img_embs)
                gathered_shallow_normalized_text_emb = gather_tensors(shallow_normalized_text_emb)
                gathered_shallow_normalized_img_emb = gathered_shallow_normalized_img_emb.view(
                    -1, gathered_shallow_normalized_img_emb.size(-1))
                gathered_shallow_normalized_text_emb = gathered_shallow_normalized_text_emb.view(
                    -1, gathered_shallow_normalized_text_emb.size(-1))

                shallow_contrastive_logits = (
                    torch.matmul(gathered_shallow_normalized_img_emb,
                                 gathered_shallow_normalized_text_emb.transpose(-1, -2)) / self.shallow_log_scale.exp())
                shallow_contrastive_logits = shallow_contrastive_logits.unsqueeze(0)

                # apply text mask
                text_attention_mask = text_attention_mask.view(-1).unsqueeze(0).unsqueeze(0)
                text_attention_mask = text_attention_mask.repeat(1, shallow_contrastive_logits.size(1),
                                                                 1)  # copy along the image feature dimension
                shallow_contrastive_logits = shallow_contrastive_logits.masked_fill(~text_attention_mask, -1000000)

                # if PAD_ZEROS, apply image mask
                if image_masks is not None:
                    image_attention_mask = image_attention_mask.view(-1).unsqueeze(0).unsqueeze(-1)
                    image_attention_mask = image_attention_mask.repeat(1, 1,
                                                                       shallow_contrastive_logits.size(2))  # copy along the text feature dimension
                    shallow_contrastive_logits = shallow_contrastive_logits.masked_fill(~image_attention_mask, -1000000)

                # NOTE: 7. calculate image and text logits and maps
                shallow_image_logits = shallow_contrastive_logits[:,
                                       (rank * bs) * max_anchor_num: (rank * bs + bs) * max_anchor_num, :]
                shallow_image_positive_map = normalized_positive_map(
                    shallow_positive_map[:, (rank * bs) * max_anchor_num: (rank * bs + bs) * max_anchor_num, :])
                shallow_text_logits = shallow_contrastive_logits[:, :,
                                      (rank * bs) * 256: (rank * bs + bs) * 256].transpose(1, 2)
                shallow_text_positive_map = normalized_positive_map(
                    shallow_positive_map[:, :, (rank * bs) * 256: (rank * bs + bs) * 256].transpose(1, 2))

        pos_inds = torch.nonzero(labels_flatten > 0).squeeze(1)

        num_gpus = get_world_size()
        total_num_pos = reduce_sum(pos_inds.new_tensor([pos_inds.numel()])).item()
        num_pos_avg_per_gpu = max(total_num_pos / float(num_gpus), 1.0)
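
        # All of the losses below are normalized by the average number of positive anchors
        # per GPU, obtained by summing the positive counts across processes with reduce_sum
        # and dividing by the number of GPUs (clamped to at least 1).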
        cls_loss = self.cls_loss_func(box_cls_flatten, labels_flatten.int()) / num_pos_avg_per_gpu

        token_logits_loss = None
        contrastive_align_loss = None
        dot_product_token_loss = None
        shallow_contrastive_loss = None

        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_TOKEN_LOSS:
            token_logits_loss = self.token_loss_func(token_logits_stacked,
                                                     token_labels_stacked, text_masks=text_masks,
                                                     version="binary") / num_pos_avg_per_gpu

        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_CONTRASTIVE_ALIGN_LOSS:
            contrastive_align_loss = self.ContrastiveAlignLoss(contrastive_logits,
                                                               positive_map_box_to_self_text) / num_pos_avg_per_gpu

        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_DOT_PRODUCT_TOKEN_LOSS:
            dot_product_token_loss = self.token_loss_func(dot_product_logits,
                                                          token_labels_stacked, text_masks=text_masks,
                                                          version="binary") / num_pos_avg_per_gpu

        if self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_SHALLOW_CONTRASTIVE_LOSS or \
                self.cfg.MODEL.DYHEAD.FUSE_CONFIG.USE_BACKBONE_SHALLOW_CONTRASTIVE_LOSS:
            box_to_token_loss = self.NllSoftMaxLoss(shallow_image_logits, shallow_image_positive_map).sum()
            token_to_box_loss = self.NllSoftMaxLoss(shallow_text_logits, shallow_text_positive_map).sum()
            tot_loss = (box_to_token_loss + token_to_box_loss) / 2
            shallow_contrastive_loss = tot_loss / num_pos_avg_per_gpu

        box_regression_flatten = box_regression_flatten[pos_inds]
        reg_targets_flatten = reg_targets_flatten[pos_inds]
        anchors_flatten = anchors_flatten[pos_inds]
        centerness_flatten = centerness_flatten[pos_inds]

        if pos_inds.numel() > 0:
            centerness_targets = self.compute_centerness_targets(reg_targets_flatten, anchors_flatten)
            sum_centerness_targets_avg_per_gpu = reduce_sum(centerness_targets.sum()).item() / float(num_gpus)
            reg_loss = self.GIoULoss(box_regression_flatten, reg_targets_flatten, anchors_flatten,
                                     weight=centerness_targets) / sum_centerness_targets_avg_per_gpu
            centerness_loss = self.centerness_loss_func(centerness_flatten, centerness_targets) / num_pos_avg_per_gpu
        else:
            reg_loss = box_regression_flatten.sum()
            reduce_sum(centerness_flatten.new_tensor([0.0]))
            centerness_loss = centerness_flatten.sum()

        return cls_loss, reg_loss * self.cfg.MODEL.ATSS.REG_LOSS_WEIGHT, centerness_loss, \
            token_logits_loss, \
            contrastive_align_loss, \
            dot_product_token_loss, \
            shallow_contrastive_loss


def generate_anchor_labels(matched_targets):
    labels_per_image = matched_targets.get_field("labels")
    return labels_per_image


def make_focal_loss_evaluator(cfg, box_coder):
    matcher = Matcher(
        cfg.MODEL.FOCAL.FG_IOU_THRESHOLD,
        cfg.MODEL.FOCAL.BG_IOU_THRESHOLD,
        allow_low_quality_matches=True,
    )

    sigmoid_focal_loss = SigmoidFocalLoss(
        cfg.MODEL.FOCAL.LOSS_GAMMA,
        cfg.MODEL.FOCAL.LOSS_ALPHA
    )

    loss_evaluator = FocalLossComputation(
        matcher,
        box_coder,
        generate_anchor_labels,
        sigmoid_focal_loss,
        bbox_reg_beta=cfg.MODEL.FOCAL.BBOX_REG_BETA,
        regress_norm=cfg.MODEL.FOCAL.BBOX_REG_WEIGHT,
    )
    return loss_evaluator


def make_rpn_loss_evaluator(cfg, box_coder):
    matcher = Matcher(
        cfg.MODEL.RPN.FG_IOU_THRESHOLD,
        cfg.MODEL.RPN.BG_IOU_THRESHOLD,
        allow_low_quality_matches=True,
    )

    fg_bg_sampler = BalancedPositiveNegativeSampler(
        cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE, cfg.MODEL.RPN.POSITIVE_FRACTION
    )

    loss_evaluator = RPNLossComputation(matcher, fg_bg_sampler, box_coder)
    return loss_evaluator


def make_fcos_loss_evaluator(cfg):
    loss_evaluator = FCOSLossComputation(cfg)
    return loss_evaluator


def make_atss_loss_evaluator(cfg, box_coder):
    loss_evaluator = ATSSLossComputation(cfg, box_coder)
    return loss_evaluator