|
|
|
#include "../box_iou_rotated/box_iou_rotated_utils.h" |
|
#include "nms_rotated.h" |
|
|
|
namespace detectron2 { |
|
|
|
template <typename scalar_t> |
|
at::Tensor nms_rotated_cpu_kernel( |
|
const at::Tensor& dets, |
|
const at::Tensor& scores, |
|
const double iou_threshold) { |
|
|
|
|
|
|
|
|
|
AT_ASSERTM(dets.device().is_cpu(), "dets must be a CPU tensor"); |
|
AT_ASSERTM(scores.device().is_cpu(), "scores must be a CPU tensor"); |
|
AT_ASSERTM( |
|
dets.scalar_type() == scores.scalar_type(), |
|
"dets should have the same type as scores"); |
|
|
|
if (dets.numel() == 0) { |
|
return at::empty({0}, dets.options().dtype(at::kLong)); |
|
} |
|
|
|
auto order_t = std::get<1>(scores.sort(0, true)); |
|
|
|
auto ndets = dets.size(0); |
|
at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); |
|
at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); |
|
|
|
auto suppressed = suppressed_t.data_ptr<uint8_t>(); |
|
auto keep = keep_t.data_ptr<int64_t>(); |
|
auto order = order_t.data_ptr<int64_t>(); |
|
|
|
int64_t num_to_keep = 0; |
|
|
|
for (int64_t _i = 0; _i < ndets; _i++) { |
|
auto i = order[_i]; |
|
if (suppressed[i] == 1) { |
|
continue; |
|
} |
|
|
|
keep[num_to_keep++] = i; |
|
|
|
for (int64_t _j = _i + 1; _j < ndets; _j++) { |
|
auto j = order[_j]; |
|
if (suppressed[j] == 1) { |
|
continue; |
|
} |
|
|
|
auto ovr = single_box_iou_rotated<scalar_t>( |
|
dets[i].data_ptr<scalar_t>(), dets[j].data_ptr<scalar_t>()); |
|
if (ovr >= iou_threshold) { |
|
suppressed[j] = 1; |
|
} |
|
} |
|
} |
|
return keep_t.narrow(0, 0, num_to_keep); |
|
} |
|
|
|
// Public CPU entry point for rotated NMS: dispatches on the floating-point
// dtype of `dets` and forwards to the templated kernel.
//
// dets:   (N, 5) float/double tensor of rotated boxes.
// scores: (N,) tensor with the same dtype as dets.
// iou_threshold: suppression threshold passed through to the kernel.
//
// Returns the int64 indices of the kept boxes.
at::Tensor nms_rotated_cpu(
    const at::Tensor& dets,
    const at::Tensor& scores,
    const double iou_threshold) {
  // Placeholder; overwritten inside the dispatch lambda below.
  auto keep_indices = at::empty({0}, dets.options());

  AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated", [&] {
    keep_indices = nms_rotated_cpu_kernel<scalar_t>(dets, scores, iou_threshold);
  });

  return keep_indices;
}
|
|
|
} |
|
|