# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import Tensor

# Right-multiplying (cx, cy, w, h) boxes by this matrix converts them to
# (x1, y1, x2, y2) corner format: x1 = cx - w / 2, x2 = cx + w / 2, etc.
_XYWH2XYXY = torch.tensor([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, 1.0],
                           [-0.5, 0.0, 0.5, 0.0], [0.0, -0.5, 0.0, 0.5]],
                          dtype=torch.float32)
def select_nms_index(scores: Tensor,
                     boxes: Tensor,
                     nms_index: Tensor,
                     batch_size: int,
                     keep_top_k: int = -1):
    """Transform the flat (num_selected, 3) output of ONNX NonMaxSuppression
    into batched, zero-padded detections.

    Note: ``keep_top_k`` is accepted for interface compatibility but is not
    applied in this function; all padded detections are returned sorted by
    score.
    """
    batch_inds, cls_inds = nms_index[:, 0], nms_index[:, 1]
    box_inds = nms_index[:, 2]

    # gather the scores and boxes selected by NMS
    scores = scores[batch_inds, cls_inds, box_inds].unsqueeze(1)
    boxes = boxes[batch_inds, box_inds, ...]
    dets = torch.cat([boxes, scores], dim=1)

    # broadcast to (batch_size, num_selected, 5) and zero out detections
    # that belong to other images in the batch
    batched_dets = dets.unsqueeze(0).repeat(batch_size, 1, 1)
    batch_template = torch.arange(
        0, batch_size, dtype=batch_inds.dtype, device=batch_inds.device)
    batched_dets = batched_dets.where(
        (batch_inds == batch_template.unsqueeze(1)).unsqueeze(-1),
        batched_dets.new_zeros(1))

    # labels of detections from other images are padded with -1
    batched_labels = cls_inds.unsqueeze(0).repeat(batch_size, 1)
    batched_labels = batched_labels.where(
        (batch_inds == batch_template.unsqueeze(1)),
        batched_labels.new_ones(1) * -1)

    N = batched_dets.shape[0]

    # append one all-zero detection and one -1 label per image so every
    # image has at least one entry
    batched_dets = torch.cat((batched_dets, batched_dets.new_zeros((N, 1, 5))),
                             1)
    batched_labels = torch.cat((batched_labels, -batched_labels.new_ones(
        (N, 1))), 1)

    # sort the detections of each image by score (last column), descending
    _, topk_inds = batched_dets[:, :, -1].sort(dim=1, descending=True)
    topk_batch_inds = torch.arange(
        batch_size, dtype=topk_inds.dtype,
        device=topk_inds.device).view(-1, 1)
    batched_dets = batched_dets[topk_batch_inds, topk_inds, ...]
    batched_labels = batched_labels[topk_batch_inds, topk_inds, ...]
    batched_dets, batched_scores = batched_dets.split([4, 1], 2)
    batched_scores = batched_scores.squeeze(-1)

    # number of valid (non-padding) detections per image
    num_dets = (batched_scores > 0).sum(1, keepdim=True)
    return num_dets, batched_dets, batched_scores, batched_labels
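

# --- Illustration (not part of the original file): a minimal eager sketch of
# --- what select_nms_index produces, assuming 2 images, 1 class and 4
# --- candidate boxes per image. The helper name `_demo_select_nms_index`
# --- is hypothetical.
def _demo_select_nms_index():
    scores = torch.rand(2, 1, 4)  # (batch, num_classes, num_boxes)
    boxes = torch.rand(2, 4, 4)   # (batch, num_boxes, 4) in xyxy
    # NonMaxSuppression-style rows of (batch_idx, class_idx, box_idx):
    # two boxes kept in image 0, one box kept in image 1
    nms_index = torch.tensor([[0, 0, 1], [0, 0, 3], [1, 0, 0]])
    num_dets, dets, kept_scores, labels = select_nms_index(
        scores, boxes, nms_index, batch_size=2)
    # num_dets is typically tensor([[2], [1]]); dets/labels are zero-/-1-padded
    # and sorted by score within each image
    print(num_dets.tolist(), dets.shape, kept_scores.shape, labels.shape)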
class ONNXNMSop(torch.autograd.Function):
    """Wrap the ONNX ``NonMaxSuppression`` op as an autograd Function.

    ``forward`` only returns dummy indices of the right shape and dtype so the
    surrounding code can be traced; ``symbolic`` emits the real
    NonMaxSuppression node into the exported graph.
    """

    @staticmethod
    def forward(
            ctx,
            boxes: Tensor,
            scores: Tensor,
            max_output_boxes_per_class: Tensor = torch.tensor([100]),
            iou_threshold: Tensor = torch.tensor([0.5]),
            score_threshold: Tensor = torch.tensor([0.05])
    ) -> Tensor:
        device = boxes.device
        batch = scores.shape[0]
        num_det = 20
        # fabricate (batch_idx, class_idx, box_idx) triplets so the
        # downstream indexing logic sees plausible values during tracing
        batches = torch.randint(0, batch, (num_det, )).sort()[0].to(device)
        idxs = torch.arange(100, 100 + num_det).to(device)
        zeros = torch.zeros((num_det, ), dtype=torch.int64).to(device)
        selected_indices = torch.cat([batches[None], zeros[None], idxs[None]],
                                     0).T.contiguous()
        selected_indices = selected_indices.to(torch.int64)
        return selected_indices

    @staticmethod
    def symbolic(
            g,
            boxes: Tensor,
            scores: Tensor,
            max_output_boxes_per_class: Tensor = torch.tensor([100]),
            iou_threshold: Tensor = torch.tensor([0.5]),
            score_threshold: Tensor = torch.tensor([0.05]),
    ):
        # map this Function to the standard ONNX NonMaxSuppression op
        return g.op(
            'NonMaxSuppression',
            boxes,
            scores,
            max_output_boxes_per_class,
            iou_threshold,
            score_threshold,
            outputs=1)
def onnx_nms(
    boxes: torch.Tensor,
    scores: torch.Tensor,
    max_output_boxes_per_class: int = 100,
    iou_threshold: float = 0.5,
    score_threshold: float = 0.05,
    pre_top_k: int = -1,
    keep_top_k: int = 100,
    box_coding: int = 0,
):
    """Batched NMS head built around a single ONNX NonMaxSuppression node.

    ``boxes`` is expected as (batch, num_boxes, 4) and ``scores`` as
    (batch, num_boxes, num_classes). ``pre_top_k`` is accepted for interface
    compatibility but not used here.
    """
    max_output_boxes_per_class = torch.tensor([max_output_boxes_per_class])
    iou_threshold = torch.tensor([iou_threshold])
    score_threshold = torch.tensor([score_threshold])

    batch_size, _, _ = scores.shape
    if box_coding == 1:
        # boxes arrive as (cx, cy, w, h); convert to (x1, y1, x2, y2)
        boxes = boxes @ (_XYWH2XYXY.to(boxes.device))
    # NonMaxSuppression expects scores as (batch, num_classes, num_boxes)
    scores = scores.transpose(1, 2).contiguous()
    selected_indices = ONNXNMSop.apply(boxes, scores,
                                       max_output_boxes_per_class,
                                       iou_threshold, score_threshold)

    num_dets, batched_dets, batched_scores, batched_labels = select_nms_index(
        scores, boxes, selected_indices, batch_size, keep_top_k=keep_top_k)

    return num_dets, batched_dets, batched_scores, batched_labels.to(
        torch.int32)
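

# --- Usage sketch (not part of the original file): exporting this NMS head to
# --- ONNX. The wrapper name `NMSWrapper`, the file name `nms_head.onnx` and
# --- the input shapes are illustrative assumptions, not part of the source.
from torch import nn


class NMSWrapper(nn.Module):

    def forward(self, boxes: Tensor, scores: Tensor):
        # boxes: (batch, num_boxes, 4) in xyxy,
        # scores: (batch, num_boxes, num_classes)
        return onnx_nms(
            boxes, scores, iou_threshold=0.65, score_threshold=0.25,
            keep_top_k=100)


if __name__ == '__main__':
    dummy_boxes = torch.rand(1, 1000, 4)
    dummy_scores = torch.rand(1, 1000, 80)
    # NonMaxSuppression requires opset >= 10
    torch.onnx.export(
        NMSWrapper(), (dummy_boxes, dummy_scores), 'nms_head.onnx',
        input_names=['boxes', 'scores'],
        output_names=['num_dets', 'dets', 'scores', 'labels'],
        opset_version=11)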