File size: 3,290 Bytes
81efcf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_extended_async_copies : enable

// Returns bit number `bit` (0 = least significant) of the packed weight byte
// `val`, as the value 0 or 1.
ushort extract_weights(uchar val, int bit)
{
    const uchar mask = (uchar)(1u << bit);
    return (val & mask) ? 1 : 0;
}

// Binary convolution: half-precision input, 1-bit packed weights, one output
// row per work-group.
//
// Work-group (oh, oc) = (get_group_id(0), get_group_id(1)) writes the OW
// values of output row `oh` in output channel `oc` of the planar
// [OC][OH][OW] output:
//
//   dst[oc][oh][ow] = ICG - 2 * sum_ic( pos(src[gc*ICG + ic][SH*oh][SW*ow]) XOR w(oc, ic) )
//
// where ICG = IC / GC, `gc` is the group of `oc`, and the input is read as a
// planar [C][IH][IW] tensor. pos(x) is 1 when the half's bit pattern has the
// sign bit clear and is non-zero (so +Inf/+NaN count as positive), 0
// otherwise. w(oc, ic) is bit (oc*IC/GC + ic) % 8 of byte
// (oc*IC/GC + ic) / 8 of weights_data.
//
// Only one input position is read per (ow, ic), so the kernel acts as a 1x1
// convolution: DW, DH, KW, KH, PW and PH are accepted but not used here.
// If input row SH*oh lies past the image (> IH - 1), or an input column lies
// past IW - 1, that input is replaced by pad_value.
//
// NOTE(review): OH/OC come from get_global_size() while oh/oc come from
// get_group_id(). That is only consistent if the local size is 1 in
// dimensions 0 and 1. Confirm against the host-side work-size configuration.
__kernel void binary_convolution(
    const __global half *restrict src_data,
    const __global uchar *restrict weights_data,
    __global half *restrict dst_data,
    float pad_value,

    int IW,
    int IH,
    int IC,

    int DW,
    int DH,

    int GC,

    int KW,
    int KH,

    int PW,
    int PH,

    int SW,
    int SH,

    int OW)
{
    // Staging buffers. Nothing here checks that ICG * IW <= 32 * 1024 or
    // OW <= 2 * 1024; callers must guarantee both.
    __local half src_local[32 * 1024];
    __local half dst_local[2 * 1024];

    const int oh = get_group_id(0);
    const int oc = get_group_id(1);
    const int OH = get_global_size(0);
    const int OC = get_global_size(1);

    const int ICG = IC / GC;        // input channels per group
    const int gc  = oc / (OC / GC); // group that output channel oc belongs to
    const int ih  = SH * oh;        // the single input row this output row reads

    if (ih >= 0 && ih <= IH - 1) {
        // Gather row `ih` of each of the group's ICG channels into
        // consecutive IW-element rows of src_local.
        const __global half *src = src_data + (gc * IC / GC) * IW * IH + ih * IW;

        event_t e1 = async_work_group_copy_2D2D(
            src_local,    // dst
            src,          // src
            IW,           // num_elements_per_line
            ICG,          // num_lines
            IH * IW - IW, // src_line_stride
            0,            // dst_line_stride
            0);
        wait_group_events(1, &e1);
    }

    const half pad_value_half = convert_half(pad_value);
    // Bit pattern of the pad value, so it can be merged into ushort8 lanes
    // without pointer type-punning.
    const ushort pad_bits = as_ushort(pad_value_half);

    // Padding row: the requested input row lies below the image, so every
    // channel's staged row is filled with pad_value.
    if (ih > IH - 1) {
        __local half *dst = src_local;
        for (int c = 0; c < ICG; c++) {
            #pragma unroll 8
            for (int j = 0; j < IW; j++) {
                dst[j] = pad_value_half;
            }
            dst += IW;
        }
    }

    const ushort8 pad8     = (ushort8)(pad_bits);
    const short8  lane_ids = (short8)(0, 1, 2, 3, 4, 5, 6, 7);
    const int     OWS      = SW * OW; // input columns spanned by the output row

    // Process input columns in chunks of 8: x0 .. x0 + 7.
    for (int ows8 = 0; ows8 < (OWS + 7) / 8; ows8++) {
        const int x0 = ows8 * 8;
        // Number of lanes of this chunk that lie inside the row (0..8).
        const int in_row = clamp(IW - x0, 0, 8);
        // MSB set (-1) for in-row lanes, 0 for lanes that take pad_value.
        const short8 keep = lane_ids < (short8)((short)in_row);

        ushort8 mismatches = (ushort8)(0);
        for (int ic = 0; ic < ICG; ++ic) {
            const int weight_pos = oc * IC / GC + ic;
            const ushort8 w8 =
                (ushort8)(extract_weights(weights_data[weight_pos / 8], weight_pos % 8));

            ushort8 pixels = pad8;
            if (in_row > 0) {
                // vload8 only requires element (2-byte) alignment. The address
                // src_local + ic * IW + x0 is not 16-byte aligned in general,
                // so a direct ushort8 dereference would be a misaligned access.
                // Lanes past the row end may read the next channel's data (or
                // unstaged memory); they are replaced by pad_value via select.
                const ushort8 raw =
                    vload8(0, (const __local ushort *)(src_local + ic * IW + x0));
                pixels = select(pad8, raw, keep);
            }

            // 1 where the half is strictly positive by bit pattern (sign bit
            // clear, not +0), else 0.
            const ushort8 positive =
                (((pixels) < (ushort8)0x8000) && (pixels > (ushort8)0x0000)) ? (ushort8)(1) : (ushort8)(0);

            mismatches += positive ^ w8;
        }

        // Each output is ICG - 2 * mismatches, i.e. (#matches - #mismatches).
        const ushort8 twice_mismatches = mismatches << 1;
        const ushort *lanes = (const ushort *)&twice_mismatches;

        // Output columns ow whose input column ow * SW falls in [x0, x0 + 7].
        const int ow_first = (x0 + SW - 1) / SW;
        const int ow_last  = min((x0 + 7) / SW, OW - 1);
        for (int ow = ow_first; ow <= ow_last; ow++) {
            dst_local[ow] = (half)(ICG - lanes[ow * SW - x0]);
        }
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    event_t e2 = async_work_group_copy(dst_data + oc * OW * OH + oh * OW, dst_local, OW, 0);
    wait_group_events(1, &e2);
}