Spaces:
Sleeping
Sleeping
from itertools import count | |
import torch | |
import torch.nn as nn | |
from torch.nn.parallel import DistributedDataParallel | |
from detectron2.modeling import GeneralizedRCNNWithTTA, DatasetMapperTTA | |
from detectron2.modeling.roi_heads.fast_rcnn import fast_rcnn_inference_single_image | |
from detectron2.structures import Instances, Boxes | |
class RegionSpotWithTTA(GeneralizedRCNNWithTTA):
    """
    Test-time augmentation wrapper around a RegionSpot model.

    Every augmented copy produced by ``tta_mapper`` is run through the wrapped
    model. The per-augmentation predictions are then merged, using one of two
    strategies:

    * class-wise NMS (``_merge_detections``), or
    * cvpods-style box voting (``_merge_detections_cvpods_tta``), when
      ``cfg.TEST.AUG.CVPODS_TTA`` is enabled.
    """

    def __init__(self, cfg, model, tta_mapper=None, batch_size=3):
        """
        Args:
            cfg (CfgNode): config. Reads ``TEST.AUG.CVPODS_TTA``,
                ``TEST.AUG.SCALE_FILTER``, ``TEST.AUG.SCALE_RANGES`` and
                ``MODEL.RegionSpot.NUM_PROPOSALS``. A clone is stored.
            model (RegionSpot): a RegionSpot to apply TTA on. A
                ``DistributedDataParallel`` wrapper is unwrapped first.
            tta_mapper (callable): takes a dataset dict and returns a list of
                augmented versions of the dataset dict. Defaults to
                `DatasetMapperTTA(cfg)`.
            batch_size (int): batch the augmented images into this batch size
                for inference. Unused by the cvpods TTA path, which runs one
                image at a time.
        """
        # The base-class __init__ is intentionally not called (presumably it
        # expects a GeneralizedRCNN model -- confirm). nn.Module.__init__ must
        # still run before any submodule (self.model) is assigned, otherwise
        # torch raises "cannot assign module before Module.__init__() call".
        nn.Module.__init__(self)
        if isinstance(model, DistributedDataParallel):
            model = model.module
        self.cfg = cfg.clone()
        self.model = model
        if tta_mapper is None:
            tta_mapper = DatasetMapperTTA(cfg)
        self.tta_mapper = tta_mapper
        self.batch_size = batch_size
        # cvpods tta.
        self.enable_cvpods_tta = cfg.TEST.AUG.CVPODS_TTA
        self.enable_scale_filter = cfg.TEST.AUG.SCALE_FILTER
        self.scale_ranges = cfg.TEST.AUG.SCALE_RANGES
        self.max_detection = cfg.MODEL.RegionSpot.NUM_PROPOSALS

    def _batch_inference(self, batched_inputs, detected_instances=None):
        """
        Execute inference on a list of inputs,
        using batch size = self.batch_size, instead of the length of the list.

        Args:
            batched_inputs (list[dict]): augmented dataset dicts.
            detected_instances: accepted for signature compatibility with the
                base class. It is not forwarded to the model.

        Returns:
            list: the model outputs (``do_postprocess=False``), one per input,
            in input order.
        """
        factors = 2 if self.tta_mapper.flip else 1
        if self.enable_scale_filter:
            # Consecutive groups of `factors` inputs share one scale range
            # (see `idx // factors` below).
            assert len(batched_inputs) == len(self.scale_ranges) * factors, (
                "number of augmented inputs does not match SCALE_RANGES"
            )

        outputs = []
        pending = []
        last_idx = len(batched_inputs) - 1
        for idx, inp in enumerate(batched_inputs):
            if self.enable_cvpods_tta:
                # Run each image on its own so that output `idx` really belongs
                # to input `idx` (and to its scale range below).
                output = self.model.forward([inp], do_postprocess=False)[0]
                if self.enable_scale_filter:
                    pred_boxes = output.get("pred_boxes")
                    keep = self.filter_boxes(pred_boxes.tensor, *self.scale_ranges[idx // factors])
                    output = Instances(
                        image_size=output.image_size,
                        pred_boxes=Boxes(pred_boxes.tensor[keep]),
                        pred_classes=output.pred_classes[keep],
                        scores=output.scores[keep],
                    )
                outputs.append(output)
                continue

            pending.append(inp)
            if len(pending) == self.batch_size or idx == last_idx:
                outputs.extend(self.model.forward(pending, do_postprocess=False))
                pending = []
        return outputs

    @staticmethod
    def filter_boxes(boxes, min_scale, max_scale):
        """
        Build a keep-mask for boxes whose area is strictly between
        ``min_scale ** 2`` and ``max_scale ** 2``.

        Args:
            boxes (Tensor): (N, 4) boxes, assumed (x1, y1, x2, y2).
            min_scale (float): lower side-length bound (exclusive, on area).
            max_scale (float): upper side-length bound (exclusive, on area).

        Returns:
            Tensor: (N,) boolean mask.
        """
        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        areas = widths * heights
        return (areas > min_scale * min_scale) & (areas < max_scale * max_scale)

    def _inference_one_image(self, input):
        """
        Args:
            input (dict): one dataset dict with "image" field being a CHW tensor;
                its "height" and "width" give the output image size.
        Returns:
            dict: one output dict, ``{"instances": Instances}``.
        """
        orig_shape = (input["height"], input["width"])
        augmented_inputs, tfms = self._get_augmented_inputs(input)
        # Detect boxes from all augmented versions
        all_boxes, all_scores, all_classes = self._get_augmented_boxes(augmented_inputs, tfms)
        # merge all detected boxes to obtain final predictions for boxes
        if self.enable_cvpods_tta:
            merged_instances = self._merge_detections_cvpods_tta(
                all_boxes, all_scores, all_classes, orig_shape
            )
        else:
            merged_instances = self._merge_detections(all_boxes, all_scores, all_classes, orig_shape)
        return {"instances": merged_instances}

    def _merge_detections(self, all_boxes, all_scores, all_classes, shape_hw):
        """
        Merge the union of all augmented predictions with class-wise NMS.

        Args:
            all_boxes (Tensor): (N, 4) boxes.
            all_scores (sequence): N per-box scores.
            all_classes (sequence): N per-box class ids.
            shape_hw (tuple[int, int]): output image (height, width).

        Returns:
            Instances: result of ``fast_rcnn_inference_single_image`` using
            ``ROI_HEADS.NMS_THRESH_TEST`` and ``TEST.DETECTIONS_PER_IMAGE``.
        """
        num_boxes = len(all_boxes)
        num_classes = self.cfg.MODEL.RegionSpot.NUM_CLASSES
        # +1 because fast_rcnn_inference expects background scores as well
        all_scores_2d = torch.zeros(num_boxes, num_classes + 1, device=all_boxes.device)
        for idx, (cls, score) in enumerate(zip(all_classes, all_scores)):
            all_scores_2d[idx, cls] = score

        merged_instances, _ = fast_rcnn_inference_single_image(
            all_boxes,
            all_scores_2d,
            shape_hw,
            1e-8,
            self.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST,
            self.cfg.TEST.DETECTIONS_PER_IMAGE,
        )
        return merged_instances

    def _merge_detections_cvpods_tta(self, all_boxes, all_scores, all_classes, shape_hw):
        """
        Merge the union of all augmented predictions with soft box voting.

        Args:
            all_boxes (Tensor): (N, 4) boxes.
            all_scores (sequence): N per-box scores (converted with ``torch.tensor``).
            all_classes (sequence): N per-box class ids.
            shape_hw (tuple[int, int]): output image (height, width); boxes are
                clipped to it.

        Returns:
            Instances: voted detections, truncated to ``self.max_detection``
            when that is > 0.
        """
        all_scores = torch.tensor(all_scores).to(all_boxes.device)
        all_classes = torch.tensor(all_classes).to(all_boxes.device)
        all_boxes, all_scores, all_classes = self.merge_result_from_multi_scales(
            all_boxes, all_scores, all_classes,
            nms_type="soft_vote", vote_thresh=0.65,
            max_detection=self.max_detection,
        )
        all_boxes = Boxes(all_boxes)
        all_boxes.clip(shape_hw)

        result = Instances(shape_hw)
        result.pred_boxes = all_boxes
        result.scores = all_scores
        result.pred_classes = all_classes.long()
        return result

    def merge_result_from_multi_scales(
        self, boxes, scores, labels, nms_type="soft-vote", vote_thresh=0.65, max_detection=100
    ):
        """
        Run class-aware box voting, then keep at most ``max_detection`` results.

        Args:
            boxes (Tensor): (N, 4) boxes.
            scores (Tensor): (N,) scores.
            labels (Tensor): (N,) class ids.
            nms_type (str): voting type, forwarded to ``bbox_vote``.
            vote_thresh (float): IoU threshold for clustering.
            max_detection (int): cap over all classes; <= 0 disables it.

        Returns:
            tuple[Tensor, Tensor, Tensor]: boxes, scores and (float) labels,
            sorted by descending score.
        """
        boxes, scores, labels = self.batched_vote_nms(
            boxes, scores, labels, nms_type, vote_thresh
        )
        # Limit to max_per_image detections **over all classes**
        if boxes.shape[0] > max_detection > 0:
            boxes = boxes[:max_detection]
            scores = scores[:max_detection]
            labels = labels[:max_detection]
        return boxes, scores, labels

    def batched_vote_nms(self, boxes, scores, labels, vote_type, vote_thresh=0.65):
        """
        Class-aware box voting.

        Each box is shifted by ``label * (boxes.max() + 1)`` so that boxes of
        different classes cannot overlap (for coordinates >= -1). A single
        class-agnostic ``bbox_vote`` pass therefore only merges boxes of the
        same class. The shift is removed afterwards.

        Returns:
            tuple[Tensor, Tensor, Tensor]: boxes, scores and float labels.
        """
        labels = labels.float()
        if boxes.numel() == 0:
            # `boxes.max()` raises on an empty tensor; nothing to vote on.
            return boxes, scores, labels
        max_coordinates = boxes.max() + 1
        offsets = labels.reshape(-1, 1) * max_coordinates
        boxes, scores, labels = self.bbox_vote(
            boxes + offsets, scores, labels, vote_thresh, vote_type
        )
        boxes -= labels.reshape(-1, 1) * max_coordinates
        return boxes, scores, labels

    def bbox_vote(self, boxes, scores, labels, vote_thresh, vote_type="softvote"):
        """
        Greedy box voting.

        Repeatedly takes the highest-scoring remaining detection, pulls out
        every remaining detection whose IoU with it is >= ``vote_thresh`` and
        replaces that cluster with its voted result. Singleton clusters are
        kept unchanged.

        Args:
            boxes (Tensor): (N, 4) boxes (x1, y1, x2, y2).
            scores (Tensor): (N,) scores.
            labels (Tensor): (N,) labels.
            vote_thresh (float): IoU threshold for joining a cluster.
            vote_type (str): ``"vote"`` or ``"soft_vote"``. The spellings
                ``"soft-vote"`` / ``"softvote"`` used by the defaults are
                accepted as aliases of ``"soft_vote"``.

        Returns:
            tuple[Tensor, Tensor, Tensor]: boxes (M, 4), scores (M,), labels
            (M,), sorted by descending score.

        Raises:
            ValueError: if ``vote_type`` is not recognised.
        """
        assert boxes.shape[0] == scores.shape[0] == labels.shape[0]
        if vote_type in ("soft_vote", "soft-vote", "softvote"):
            soft = True
        elif vote_type == "vote":
            soft = False
        else:
            raise ValueError(f"Unknown vote_type: {vote_type!r}")

        # One row per detection: [x1, y1, x2, y2, score, label].
        det = torch.cat((boxes, scores.reshape(-1, 1), labels.reshape(-1, 1)), dim=1)
        vote_results = torch.zeros(0, 6, device=det.device)
        if det.numel() == 0:
            return vote_results[:, :4], vote_results[:, 4], vote_results[:, 5]

        det = det[scores.argsort(descending=True)]
        while det.shape[0] > 0:
            # IoU of the current top detection (row 0) with every remaining row.
            area = (det[:, 2] - det[:, 0]) * (det[:, 3] - det[:, 1])
            xx1 = torch.max(det[0, 0], det[:, 0])
            yy1 = torch.max(det[0, 1], det[:, 1])
            xx2 = torch.min(det[0, 2], det[:, 2])
            yy2 = torch.min(det[0, 3], det[:, 3])
            w = torch.clamp(xx2 - xx1, min=0.)
            h = torch.clamp(yy2 - yy1, min=0.)
            inter = w * h
            iou = inter / (area[0] + area - inter)

            # Pull the cluster out of the working set.
            merge_index = torch.where(iou >= vote_thresh)[0]
            vote_det = det[merge_index, :]
            det = det[iou < vote_thresh]

            if merge_index.shape[0] <= 1:
                vote_results = torch.cat((vote_results, vote_det), dim=0)
                continue
            if soft:
                det_accu_sum = self.get_soft_dets_sum(vote_det, iou[merge_index])
            else:
                det_accu_sum = self.get_dets_sum(vote_det)
            vote_results = torch.cat((vote_results, det_accu_sum), dim=0)

        order = vote_results[:, 4].argsort(descending=True)
        vote_results = vote_results[order, :]
        return vote_results[:, :4], vote_results[:, 4], vote_results[:, 5]

    @staticmethod
    def get_dets_sum(vote_det):
        """
        Fuse a cluster of detections into a single detection.

        Args:
            vote_det (Tensor): (K, 6) rows of [x1, y1, x2, y2, score, label].
                Not modified.

        Returns:
            Tensor: (1, 6) detection made of three parts:
            the score-weighted mean box, the cluster's maximum score, and the
            label of the first row.
        """
        weights = vote_det[:, 4:5]
        det_accu_sum = torch.zeros((1, 6), device=vote_det.device)
        det_accu_sum[:, :4] = (vote_det[:, :4] * weights).sum(dim=0) / weights.sum()
        det_accu_sum[:, 4] = vote_det[:, 4].max()
        det_accu_sum[:, 5] = vote_det[0, 5]
        return det_accu_sum

    @staticmethod
    def get_soft_dets_sum(vote_det, vote_det_iou):
        """
        Soft voting.

        The cluster is fused like ``get_dets_sum``. In addition, every member
        is kept with its score decayed by ``(1 - IoU)`` whenever the decayed
        score is still >= 0.05.

        Args:
            vote_det (Tensor): (K, 6) rows of [x1, y1, x2, y2, score, label].
            vote_det_iou (Tensor): (K,) IoU of each row with the cluster's top box.

        Returns:
            Tensor: (1 + S, 6): the fused detection followed by the S surviving
            soft-decayed detections.
        """
        INFERENCE_TH = 0.05
        soft_vote_det = vote_det.detach().clone()
        soft_vote_det[:, 4] *= (1 - vote_det_iou)
        soft_vote_det = soft_vote_det[soft_vote_det[:, 4] >= INFERENCE_TH]

        det_accu_sum = RegionSpotWithTTA.get_dets_sum(vote_det)
        if soft_vote_det.shape[0] > 0:
            det_accu_sum = torch.cat((det_accu_sum, soft_vote_det), dim=0)
        return det_accu_sum