import torch
from torch.nn import functional as F

from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.modeling.balanced_positive_negative_sampler import (
    BalancedPositiveNegativeSampler,
)
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou, cat_boxlist
from maskrcnn_benchmark.modeling.utils import cat
from maskrcnn_benchmark.layers import smooth_l1_loss
from maskrcnn_benchmark.structures.keypoint import keypoints_to_heat_map


def project_keypoints_to_heatmap(keypoints, proposals, discretization_size):
    """Discretize keypoint coordinates into per-proposal heatmap targets.

    Returns the (heatmaps, valid) pair produced by keypoints_to_heat_map.
    """
    proposals = proposals.convert("xyxy")
    return keypoints_to_heat_map(
        keypoints.keypoints, proposals.bbox, discretization_size
    )


def cat_boxlist_with_keypoints(boxlists):
    """Concatenate BoxLists, carrying their keypoint tensors along as a single
    "keypoints" field on the result."""
    assert all(boxlist.has_field("keypoints") for boxlist in boxlists)

    kp = [boxlist.get_field("keypoints").keypoints for boxlist in boxlists]
    kp = cat(kp, 0)

    fields = boxlists[0].get_fields()
    fields = [field for field in fields if field != "keypoints"]
    boxlists = [boxlist.copy_with_fields(fields) for boxlist in boxlists]

    boxlists = cat_boxlist(boxlists)
    boxlists.add_field("keypoints", kp)
    return boxlists


def _within_box(points, boxes):
    """Validate which keypoints are contained inside a given box.
    points: NxKx2
    boxes: Nx4
    output: NxK
    """
    x_within = (points[..., 0] >= boxes[:, 0, None]) & (
        points[..., 0] <= boxes[:, 2, None]
    )
    y_within = (points[..., 1] >= boxes[:, 1, None]) & (
        points[..., 1] <= boxes[:, 3, None]
    )
    return x_within & y_within
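
# Illustrative example (not in the original source): with N = 1 proposal
# [0, 0, 10, 10] and K = 2 keypoints [[[5, 5], [20, 5]]], _within_box returns
# [[True, False]]. boxes[:, 0, None] has shape Nx1, so it broadcasts against
# points[..., 0] (NxK), comparing every keypoint to its own proposal box.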


class KeypointRCNNLossComputation(object):
    def __init__(self, proposal_matcher, fg_bg_sampler, discretization_size):
        """
        Arguments:
            proposal_matcher (Matcher)
            fg_bg_sampler (BalancedPositiveNegativeSampler)
            discretization_size (int)
        """
        self.proposal_matcher = proposal_matcher
        self.fg_bg_sampler = fg_bg_sampler
        self.discretization_size = discretization_size

    def match_targets_to_proposals(self, proposal, target):
        match_quality_matrix = boxlist_iou(target, proposal)
        matched_idxs = self.proposal_matcher(match_quality_matrix)
        # Keypoint RCNN needs "labels" and "keypoints" fields for creating the targets
        target = target.copy_with_fields(["labels", "keypoints"])
        # get the ground-truth box corresponding to each proposal
        # NB: need to clamp the indices because we can have a single
        # GT in the image, and matched_idxs can be -2, which goes
        # out of bounds
        matched_targets = target[matched_idxs.clamp(min=0)]
        matched_targets.add_field("matched_idxs", matched_idxs)
        return matched_targets
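
    # Note on the sentinel values above: maskrcnn_benchmark's Matcher marks
    # unmatched proposals with BELOW_LOW_THRESHOLD (-1) and ambiguous ones with
    # BETWEEN_THRESHOLDS (-2), which is why matched_idxs is clamped to 0 before
    # indexing and negatives are identified via Matcher.BELOW_LOW_THRESHOLD below.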

    def prepare_targets(self, proposals, targets):
        labels = []
        keypoints = []
        for proposals_per_image, targets_per_image in zip(proposals, targets):
            matched_targets = self.match_targets_to_proposals(
                proposals_per_image, targets_per_image
            )
            matched_idxs = matched_targets.get_field("matched_idxs")

            labels_per_image = matched_targets.get_field("labels")
            labels_per_image = labels_per_image.to(dtype=torch.int64)

            # this can probably be removed, but is left here for clarity
            # and completeness
            # TODO: check that BELOW_LOW_THRESHOLD is the right sentinel here
            neg_inds = matched_idxs == Matcher.BELOW_LOW_THRESHOLD
            labels_per_image[neg_inds] = 0

            # proposals whose matched GT has no visible keypoint inside the box
            # are ignored (label -1) rather than treated as negatives
            keypoints_per_image = matched_targets.get_field("keypoints")
            within_box = _within_box(
                keypoints_per_image.keypoints, matched_targets.bbox
            )
            vis_kp = keypoints_per_image.keypoints[..., 2] > 0
            is_visible = (within_box & vis_kp).sum(1) > 0

            labels_per_image[~is_visible] = -1

            labels.append(labels_per_image)
            keypoints.append(keypoints_per_image)

        return labels, keypoints

    def subsample(self, proposals, targets):
        """
        This method performs the positive/negative sampling, and returns
        the sampled proposals.
        Note: this function keeps state (the sampled proposals are cached
        in self._proposals).

        Arguments:
            proposals (list[BoxList])
            targets (list[BoxList])
        """
        labels, keypoints = self.prepare_targets(proposals, targets)
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)

        proposals = list(proposals)
        # add the corresponding label and keypoint annotations to each proposal
        for labels_per_image, keypoints_per_image, proposals_per_image in zip(
            labels, keypoints, proposals
        ):
            proposals_per_image.add_field("labels", labels_per_image)
            proposals_per_image.add_field("keypoints", keypoints_per_image)

        # keep only the sampled positive proposals for each image
        for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
            zip(sampled_pos_inds, sampled_neg_inds)
        ):
            img_sampled_inds = torch.nonzero(pos_inds_img).squeeze(1)
            proposals_per_image = proposals[img_idx][img_sampled_inds]
            proposals[img_idx] = proposals_per_image

        self._proposals = proposals
        return proposals
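
    # Only the positive (foreground) indices from the sampler are kept above:
    # the keypoint loss is defined only for proposals matched to a ground-truth
    # instance, so background proposals never reach the keypoint head in training.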

    def __call__(self, proposals, keypoint_logits):
        heatmaps = []
        valid = []
        for proposals_per_image in proposals:
            kp = proposals_per_image.get_field("keypoints")
            heatmaps_per_image, valid_per_image = project_keypoints_to_heatmap(
                kp, proposals_per_image, self.discretization_size
            )
            heatmaps.append(heatmaps_per_image.view(-1))
            valid.append(valid_per_image.view(-1))

        keypoint_targets = cat(heatmaps, dim=0)
        valid = cat(valid, dim=0).to(dtype=torch.bool)
        valid = torch.nonzero(valid).squeeze(1)

        # torch.mean (in binary_cross_entropy_with_logits) doesn't
        # accept empty tensors, so handle it separately
        if keypoint_targets.numel() == 0 or len(valid) == 0:
            return keypoint_logits.sum() * 0

        N, K, H, W = keypoint_logits.shape
        keypoint_logits = keypoint_logits.view(N * K, H * W)

        keypoint_loss = F.cross_entropy(keypoint_logits[valid], keypoint_targets[valid])
        return keypoint_loss
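
    # The reshape above treats each keypoint's HxW heatmap as a single
    # (H*W)-way classification problem: keypoint_targets holds the flattened
    # index of the discretized ground-truth location, and F.cross_entropy
    # applies a softmax over all spatial positions of each visible keypoint,
    # as in the Mask R-CNN keypoint head.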


def make_roi_keypoint_loss_evaluator(cfg):
    matcher = Matcher(
        cfg.MODEL.ROI_HEADS.FG_IOU_THRESHOLD,
        cfg.MODEL.ROI_HEADS.BG_IOU_THRESHOLD,
        allow_low_quality_matches=False,
    )
    fg_bg_sampler = BalancedPositiveNegativeSampler(
        cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION
    )
    resolution = cfg.MODEL.ROI_KEYPOINT_HEAD.RESOLUTION
    loss_evaluator = KeypointRCNNLossComputation(matcher, fg_bg_sampler, resolution)
    return loss_evaluator
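
# Typical wiring (sketch, not part of this module): in maskrcnn_benchmark the
# ROI keypoint head builds this evaluator from the config and drives it during
# training, roughly as follows (the names are the head's, not defined here):
#
#     loss_evaluator = make_roi_keypoint_loss_evaluator(cfg)
#     proposals = loss_evaluator.subsample(proposals, targets)   # training only
#     kp_logits = predictor(feature_extractor(features, proposals))
#     loss_kp = loss_evaluator(proposals, kp_logits)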