Spaces:
Runtime error
Runtime error
File size: 1,704 Bytes
128757a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 |
// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
#pragma once
#include "cpu/vision.h"
#ifdef WITH_CUDA
#include "cuda/vision.h"
#endif
// Interface for Python
at::Tensor ROIAlign_forward(const at::Tensor& input,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int sampling_ratio) {
if (input.device().is_cuda()) {
#ifdef WITH_CUDA
return ROIAlign_forward_cuda(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
return ROIAlign_forward_cpu(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio);
}
at::Tensor ROIAlign_backward(const at::Tensor& grad,
const at::Tensor& rois,
const float spatial_scale,
const int pooled_height,
const int pooled_width,
const int batch_size,
const int channels,
const int height,
const int width,
const int sampling_ratio) {
if (grad.device().is_cuda()) {
#ifdef WITH_CUDA
return ROIAlign_backward_cuda(grad, rois, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width, sampling_ratio);
#else
AT_ERROR("Not compiled with GPU support");
#endif
}
AT_ERROR("Not implemented on the CPU");
}
|