File size: 5,986 Bytes
81efcf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_extended_async_copies : enable

// 3x3 convolution over FP16 data.
//
// Work distribution, as established by the copies below:
//   - get_group_id(0) selects the output row: the input window starts at row
//     get_group_id(0) * stride_y and the result is written at row get_group_id(0).
//   - get_group_id(1) selects the output channel: it picks the IC*3*3 weight slice
//     and the OW*OH output plane.
// Each work-group therefore produces one row of OW output values.
//
// Parameters:
//   in_param   input tensor; indexed as IC planes of IW*IH elements
//   out        output tensor; indexed as planes of OW*OH elements
//              NOTE(review): declared const although it is the destination of the
//              final async copy; signature left untouched to keep the interface stable.
//   w          weights; IC*3*3 elements per output channel
//   IW, IH, IC input width, height, channel count
//   OW, OH     output width, height
//   OC, KX, KY, pad_y  not read by this kernel (kernel size is hard-coded to 3x3)
//   stride_x, stride_y  the vectorized path runs only for (1,1) and (2,2)
//   pad_x, dilation_x, dilation_y  used only by the scalar tail loop
//
// NOTE(review): the local buffers hold 8*1024 halves each and no bounds are checked
// here; presumably the host guarantees 3*IW*IC, IC*9 and OW fit - confirm with caller.
__kernel void Convolution3x3(
    const __global half *in_param,
    const __global half *out,
    const __global half *w,
    int IW,
    int IH,
    int IC,
    int OW,
    int OH,
    int OC,
    int KX,
    int KY,
    int stride_x,
    int stride_y,
    int pad_x,
    int pad_y,
    int dilation_x,
    int dilation_y)
{
    __local half in_local[8 * 1024];
    __local half out_local[8 * 1024];
    __local half w_local[8 * 1024];

    // Stage three consecutive input rows of every input channel into local memory:
    // IC "lines" of 3 * IW elements each.
    event_t e1 = async_work_group_copy_2D2D(
        in_local, // dst
        in_param + get_group_id(0) * stride_y * IW, // src
        3 * IW, // num_elements_per_line,
        IC, // num_lines,
        IW * IH - 3 * IW, // src_line_stride,
        0, // dst_line_stride,
        0);
    wait_group_events(1, &e1);

    // Stage the 3x3xIC weights of this work-group's output channel.
    const int sizeWeight = IC * 3 * 3;
    e1 = async_work_group_copy(w_local, w + get_group_id(1) * sizeWeight, sizeWeight, 0);
    wait_group_events(1, &e1);

    // All input reads below go through a base shifted by one element.
    __local half *in = (__local half *)in_local + 1;

    // Number of 8-wide vector iterations. Stays 0 for unsupported strides so the
    // vector loop is skipped instead of running on an indeterminate bound.
    int stride       = 0;
    int write_output = 0;
    __local half *src;

    if ((stride_x == 1) && (stride_y == 1)) {
        stride       = OW / 8;
        write_output = 1;
    }
    if ((stride_x == 2) && (stride_y == 2)) {
        stride       = OW / 4;
        write_output = 2;
    }

    // Vectorized path: every iteration evaluates 8 adjacent input positions.
    for (int ow = 0; ow < stride; ow++) {
        float8 val = {0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f};
        for (int ic = 0; ic < IC; ++ic) {
            src             = (__local half *)((__local half8 *)(in + ic * IW * 3) + ow);
            __local half *k = (__local half *)(w_local + ic * 3 * 3);

            // Load the previous, current and next half8 of each of the three rows.
            half8 aux_in00 = *((__local half8 *)src - 1);
            half8 aux_in01 = *((__local half8 *)src + 0);
            half8 aux_in02 = *((__local half8 *)src + 1);
            half8 aux_in10 = *((__local half8 *)(src + IW) - 1);
            half8 aux_in11 = *((__local half8 *)(src + IW) + 0);
            half8 aux_in12 = *((__local half8 *)(src + IW) + 1);
            half8 aux_in20 = *((__local half8 *)(src + IW * 2) - 1);
            half8 aux_in21 = *((__local half8 *)(src + IW * 2) + 0);
            half8 aux_in22 = *((__local half8 *)(src + IW * 2) + 1);

            // Reinterpret as short8, the operand type taken by the align intrinsic.
            short8 in00 = *((short8 *)&aux_in00);
            short8 in01 = *((short8 *)&aux_in01);
            short8 in02 = *((short8 *)&aux_in02);
            short8 in10 = *((short8 *)&aux_in10);
            short8 in11 = *((short8 *)&aux_in11);
            short8 in12 = *((short8 *)&aux_in12);
            short8 in20 = *((short8 *)&aux_in20);
            short8 in21 = *((short8 *)&aux_in21);
            short8 in22 = *((short8 *)&aux_in22);

            // NOTE(review): the intrinsic presumably extracts 8 lanes from the
            // concatenation of its two operands at the given byte offset, yielding the
            // left (14) and right (2) neighbour vectors - confirm against the
            // toolchain documentation.
            short8 aux_aux00 = __builtin_shave_cmu_alignvec_rri_short8(in00, in01, 14);
            short8 aux_aux01 = in01;
            short8 aux_aux02 = __builtin_shave_cmu_alignvec_rri_short8(in01, in02, 2);
            short8 aux_aux10 = __builtin_shave_cmu_alignvec_rri_short8(in10, in11, 14);
            short8 aux_aux11 = in11;
            short8 aux_aux12 = __builtin_shave_cmu_alignvec_rri_short8(in11, in12, 2);
            short8 aux_aux20 = __builtin_shave_cmu_alignvec_rri_short8(in20, in21, 14);
            short8 aux_aux21 = in21;
            short8 aux_aux22 = __builtin_shave_cmu_alignvec_rri_short8(in21, in22, 2);

            half8 aux00 = *((half8 *)&aux_aux00);
            half8 aux01 = *((half8 *)&aux_aux01);
            half8 aux02 = *((half8 *)&aux_aux02);
            half8 aux10 = *((half8 *)&aux_aux10);
            half8 aux11 = *((half8 *)&aux_aux11);
            half8 aux12 = *((half8 *)&aux_aux12);
            half8 aux20 = *((half8 *)&aux_aux20);
            half8 aux21 = *((half8 *)&aux_aux21);
            half8 aux22 = *((half8 *)&aux_aux22);

            // Broadcast each of the nine weights of this input channel to all lanes.
            half8 w00 = (half8)(*(k + 0));
            half8 w01 = (half8)(*(k + 1));
            half8 w02 = (half8)(*(k + 2));
            half8 w10 = (half8)(*(k + 3));
            half8 w11 = (half8)(*(k + 4));
            half8 w12 = (half8)(*(k + 5));
            half8 w20 = (half8)(*(k + 6));
            half8 w21 = (half8)(*(k + 7));
            half8 w22 = (half8)(*(k + 8));

            // Accumulate in float.
            val += convert_float8(aux00) * convert_float8(w00);
            val += convert_float8(aux01) * convert_float8(w01);
            val += convert_float8(aux02) * convert_float8(w02);
            val += convert_float8(aux10) * convert_float8(w10);
            val += convert_float8(aux11) * convert_float8(w11);
            val += convert_float8(aux12) * convert_float8(w12);
            val += convert_float8(aux20) * convert_float8(w20);
            val += convert_float8(aux21) * convert_float8(w21);
            val += convert_float8(aux22) * convert_float8(w22);
        }
        // Stride 2 keeps every other lane (4 outputs); stride 1 keeps all 8.
        if (write_output == 2) *((__local half4 *)(out_local) + ow) = convert_half4(val.s0246);
        if (write_output == 1) *((__local half8 *)(out_local) + ow) = convert_half8(val);
    }

    // Scalar tail: output columns from the last multiple of 8 up to OW.
    for (int ow = OW & ~(0x7); ow < OW; ow++) {
        float val = 0.0f;
        for (int ic = 0; ic < IC; ++ic) {
            for (int ky = 0; ky < 3; ++ky) {
                for (int kx = 0; kx < 3; ++kx) {
                    int iw = ow * stride_x - pad_x + kx * dilation_x;

                    val += convert_float(in[ic * IW * 3 + (ky * dilation_y) * IW + iw])
                           * convert_float(w_local[ic * 3 * 3 + ky * 3 + kx]);
                }
            }
        }
        out_local[ow] = convert_half(val);
    }

    barrier(CLK_LOCAL_MEM_FENCE);

    // Write the finished row of OW values to its place in the output plane.
    event_t e2 = async_work_group_copy(
        out + get_group_id(1) * OW * OH + get_group_id(0) * OW,
        out_local,
        OW,
        0);
    wait_group_events(1, &e2);
}