Spaces:
Runtime error
Runtime error
// Copyright (C) 2018-2022 Intel Corporation | |
// SPDX-License-Identifier: Apache-2.0 | |
// | |
inline int out_to_in(float ox, float f) { return (int)((ox + 0.5f) * f); } | |
void interpolationCHW_nn(__local half *psrc, __local half *pdst, int OW, int IW, int C, float rw, float rh) | |
{ | |
float alpha = rh / 2.0f - 0.5f; | |
for (int w = 0; w < OW / 8; w++) { | |
float fw0 = rw * (w * 8 + 0) + alpha; | |
float fw1 = rw * (w * 8 + 1) + alpha; | |
float fw2 = rw * (w * 8 + 2) + alpha; | |
float fw3 = rw * (w * 8 + 3) + alpha; | |
float fw4 = rw * (w * 8 + 4) + alpha; | |
float fw5 = rw * (w * 8 + 5) + alpha; | |
float fw6 = rw * (w * 8 + 6) + alpha; | |
float fw7 = rw * (w * 8 + 7) + alpha; | |
int iw0 = min((int)ROUND(fw0), IW - 1); | |
int iw1 = min((int)ROUND(fw1), IW - 1); | |
int iw2 = min((int)ROUND(fw2), IW - 1); | |
int iw3 = min((int)ROUND(fw3), IW - 1); | |
int iw4 = min((int)ROUND(fw4), IW - 1); | |
int iw5 = min((int)ROUND(fw5), IW - 1); | |
int iw6 = min((int)ROUND(fw6), IW - 1); | |
int iw7 = min((int)ROUND(fw7), IW - 1); | |
for (int c = 0; c < C; c++) { | |
half8 val = { | |
*((__local half *)(psrc + c * IW + iw0)), | |
*((__local half *)(psrc + c * IW + iw1)), | |
*((__local half *)(psrc + c * IW + iw2)), | |
*((__local half *)(psrc + c * IW + iw3)), | |
*((__local half *)(psrc + c * IW + iw4)), | |
*((__local half *)(psrc + c * IW + iw5)), | |
*((__local half *)(psrc + c * IW + iw6)), | |
*((__local half *)(psrc + c * IW + iw7)), | |
}; | |
*((__local half8 *)(pdst + c * OW + w * 8)) = val; | |
} | |
} | |
for (int w = OW / 8 * 8; w < OW; w++) { | |
float fw = rw * w + alpha; | |
int iw0 = min((int)ROUND(fw), IW - 1); | |
for (int c = 0; c < C; c++) { | |
*((__local half *)(pdst + c * OW + w)) = *((__local half *)(psrc + c * IW + iw0)); | |
} | |
} | |
} | |
kernel void resample_nearest( | |
__global const half *restrict src, | |
__global half *restrict dst, | |
int iw, | |
int ih, | |
float factor, | |
int ow, | |
int oh, | |
int channels) | |
{ | |
__local half local_src[14 * 1024]; | |
__local half local_dst[14 * 1024]; | |
const int oy_first = get_group_id(1) * get_local_size(1); | |
const int oy_last = (get_group_id(1) + 1) * get_local_size(1) - 1; | |
const int iy_first = out_to_in(oy_first, 1.0 / factor); | |
const int iy_last = out_to_in(oy_last, 1.0 / factor); | |
const int iy_size = iy_last - iy_first + 1; | |
event_t e1 = async_work_group_copy_2D2D( | |
local_src, // dst | |
src + get_group_id(2) * channels * ih * iw + iy_first * iw, // src | |
iy_size * iw, // num_elements_per_line, | |
channels, // num_lines, | |
ih * iw - iy_size * iw, // src_line_stride, | |
0, // dst_line_stride, | |
0); | |
wait_group_events(1, &e1); | |
interpolationCHW_nn(local_src, local_dst, ow, iw, channels, 1.0 / factor, 1.0 / factor); | |
event_t e2 = async_work_group_copy_2D2D( | |
dst + get_group_id(2) * channels * get_global_size(1) * ow + get_group_id(1) * get_local_size(1) * ow, // dst | |
local_dst, // src | |
get_local_size(1) * ow, // size_t num_elements_per_line, | |
channels, // size_t num_lines, | |
0, // size_t src_line_stride, | |
get_global_size(1) * ow - get_local_size(1) * ow, // size_t dst_line_stride, | |
0); | |
wait_group_events(1, &e2); | |
} | |