|
|
|
|
|
#include <torch/extension.h> |
|
#include "ROIAlignRotated/ROIAlignRotated.h" |
|
#include "box_iou_rotated/box_iou_rotated.h" |
|
#include "cocoeval/cocoeval.h" |
|
#include "deformable/deform_conv.h" |
|
#include "nms_rotated/nms_rotated.h" |
|
|
|
namespace detectron2 { |
|
|
|
// Forward declaration of the runtime-version query used by get_cuda_version().
// Only declared for GPU builds (WITH_CUDA or WITH_HIP); the definition is not
// in this file -- presumably it lives in the GPU-compiled sources (TODO confirm).
#if defined(WITH_CUDA) || defined(WITH_HIP) |

extern int get_cudart_version(); |

#endif |
|
|
|
/// Describes the GPU runtime this extension was compiled against.
///
/// @return "CUDA <major>.<minor>[.<patch>]" for WITH_CUDA builds,
///         "HIP <major>.<minor>[.<patch>]" for WITH_HIP-only builds, and
///         "not available" when neither macro is defined.
std::string get_cuda_version() {
#if !defined(WITH_CUDA) && !defined(WITH_HIP)
  // CPU-only build: there is no runtime to query.
  return std::string("not available");
#else

#if defined(WITH_CUDA)
  const char* const runtime_label = "CUDA ";
#else
  const char* const runtime_label = "HIP ";
#endif

  // The integer is decoded CUDA-style: major * 1000 + minor * 10 + patch.
  const int encoded = get_cudart_version();
  const int major = encoded / 1000;
  const int minor = encoded / 10 % 100;
  const int patch = encoded % 10;

  std::ostringstream text;
  text << runtime_label << major << "." << minor;
  // The patch component is only shown when it is non-zero.
  if (patch != 0) {
    text << "." << patch;
  }
  return text.str();
#endif
}
|
|
|
/// Reports whether this extension was compiled with WITH_CUDA defined.
///
/// Note that only WITH_CUDA is consulted; a build that defines WITH_HIP alone
/// reports false.
bool has_cuda() {
#if defined(WITH_CUDA)
  constexpr bool kCompiledWithCuda = true;
#else
  constexpr bool kCompiledWithCuda = false;
#endif
  return kCompiledWithCuda;
}
|
|
|
|
|
|
|
/// Returns a short description of the compiler that built this extension.
///
/// The string is assembled from the compiler's predefined macros:
///   - "GCC <major>.<minor>" for GCC (skipped when clang is emulating GCC),
///   - "clang <major>.<minor>.<patch>" for clang,
///   - "MSVC <_MSC_FULL_VER>" for MSVC.
/// If none of the macros is defined the result is an empty string.
///
/// Compilation is rejected outright on GCC older than 4.9.
std::string get_compiler_version() {
  std::ostringstream ss;

#if defined(__GNUC__)
#ifndef __clang__

// Enforce the minimum GCC version. The (major, minor) pair is compared
// lexicographically: testing `major <= 4 && minor <= 8` independently would
// let an old major with a large minor number (e.g. 2.95) slip through.
#if (__GNUC__ < 4) || ((__GNUC__ == 4) && (__GNUC_MINOR__ < 9))
#error "GCC >= 4.9 is required!"
#endif

  ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__;
#endif
#endif

#if defined(__clang_major__)
  ss << "clang " << __clang_major__ << "." << __clang_minor__ << "."
     << __clang_patchlevel__;
#endif

#if defined(_MSC_VER)
  ss << "MSVC " << _MSC_FULL_VER;
#endif

  return ss.str();
}
|
|
|
// Python bindings for the extension module named by TORCH_EXTENSION_NAME.
// Each function is registered under the Python-visible name given as the first
// argument of m.def(); the last argument is its docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  namespace py = pybind11;

  // Build / environment introspection helpers defined in this file.
  m.def("get_compiler_version", &get_compiler_version, "get_compiler_version");
  m.def("get_cuda_version", &get_cuda_version, "get_cuda_version");
  m.def("has_cuda", &has_cuda, "has_cuda");

  // Deformable convolution entry points (plain and modulated variants).
  m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward");
  m.def("deform_conv_backward_input", &deform_conv_backward_input,
        "deform_conv_backward_input");
  m.def("deform_conv_backward_filter", &deform_conv_backward_filter,
        "deform_conv_backward_filter");
  m.def("modulated_deform_conv_forward", &modulated_deform_conv_forward,
        "modulated_deform_conv_forward");
  m.def("modulated_deform_conv_backward", &modulated_deform_conv_backward,
        "modulated_deform_conv_backward");

  // COCO evaluation functions and the value types they exchange with Python.
  m.def("COCOevalAccumulate", &COCOeval::Accumulate, "COCOeval::Accumulate");
  m.def("COCOevalEvaluateImages", &COCOeval::EvaluateImages,
        "COCOeval::EvaluateImages");

  py::class_<COCOeval::InstanceAnnotation>(m, "InstanceAnnotation")
      .def(py::init<uint64_t, double, double, bool, bool>());
  py::class_<COCOeval::ImageEvaluation>(m, "ImageEvaluation")
      .def(py::init<>());
}
|
|
|
// Operator registrations for the `detectron2` TORCH_LIBRARY namespace.
// These go through the PyTorch operator registry rather than the pybind11
// module above.
TORCH_LIBRARY(detectron2, m) {
  // Rotated-box NMS and pairwise IoU.
  m.def("nms_rotated", &nms_rotated);
  m.def("box_iou_rotated", &box_iou_rotated);

  // Rotated ROIAlign, forward and backward.
  m.def("roi_align_rotated_forward", &ROIAlignRotated_forward);
  m.def("roi_align_rotated_backward", &ROIAlignRotated_backward);
}
|
} |
|
|