File size: 4,415 Bytes
d1a84ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_

#include "sparse_matmul/compute/ar_inputs.h"
#include "sparse_matmul/numerics/fast_transcendentals.h"

namespace csrblocksparse {

// Width constant associated with the generic (non-intrinsics) code path. It is
// not referenced anywhere in this header.
// NOTE(review): presumably consumed by files that include this header to size
// or align their work ranges -- confirm against the callers.
constexpr int kGenericSIMDWidth = 4;

// TODO(b/188702959): Rename arguments to match gru_gates.h.
// Scalar gate computation over the state elements in [start, end).
//
// For every index i in that range the function accumulates reset, update and
// cell pre-activations and then overwrites |gru_h_ptr|[i] with
//   hbar + (gru_h_ptr[i] - hbar) * update
// where
//   reset  = fast_sigmoid(reset_sum + conditioning_ptr[i])
//   update = fast_sigmoid(update_sum + conditioning_ptr[proj_size + i])
//   hbar   = fast_tanh(ar_cell + reset * cell +
//                      conditioning_ptr[2 * proj_size + i])
//
// Buffer indexing used below:
//   * |gru_gates_ptr|, |gru_gates_other_ptr|, |conditioning_ptr| and |w_hat|
//     are read at i (reset), proj_size + i (update), 2 * proj_size + i (cell).
//   * |qr_ptr| is read as pairs at 2 * i, 2 * proj_size + 2 * i and
//     4 * proj_size + 2 * i; the first entry of a pair multiplies
//     |coarse_at_sminus1|[0] and the second multiplies |fine_at_sminus1|[0].
//
// Template switches:
//   * |kInputsMode| == k0ARInputs: reset/update come from |gru_gates_ptr|
//     alone; |qr_ptr|, |w_hat| and the sample pointers are never read.
//   * any other mode: the |qr_ptr| terms are added to all three gates.
//   * |kInputsMode| == k3ARInputs: additionally adds |w_hat| times
//     |coarse_at_s|[0], so |coarse_at_s| must be non-null in that mode.
//   * |SplitGates|: adds the matching entries of |gru_gates_other_ptr|, which
//     is only read when this flag is true.
template <typename GRUStateType, typename GRUMatMulOutType, typename QR_W_Type,
          typename SampleType, ARInputsMode kInputsMode,
          bool SplitGates = false>
void GoThroughGates(int start, int end, const QR_W_Type* qr_ptr,
                    const GRUMatMulOutType* gru_gates_ptr,
                    const GRUMatMulOutType* gru_gates_other_ptr,
                    const GRUMatMulOutType* conditioning_ptr,
                    GRUStateType* gru_h_ptr, const QR_W_Type* w_hat,
                    int proj_size, const SampleType* coarse_at_sminus1,
                    const SampleType* fine_at_sminus1,
                    const SampleType* coarse_at_s = nullptr) {
  // Start of the update and cell sections in the per-gate buffers.
  const int update_offset = proj_size;
  const int cell_offset = 2 * proj_size;

  for (int idx = start; idx < end; ++idx) {
    float reset_gate = 0.0f;
    float update_gate = 0.0f;
    // Autoregressive contribution to the cell; stays zero for k0ARInputs.
    float ar_cell = 0.0f;

    if (kInputsMode != ARInputsMode::k0ARInputs) {
      const float prev_coarse = static_cast<float>(coarse_at_sminus1[0]);
      const float prev_fine = static_cast<float>(fine_at_sminus1[0]);

      // Positions of this element's weight pair for each gate.
      const int qr_reset_at = 2 * idx;
      const int qr_update_at = 2 * update_offset + qr_reset_at;
      const int qr_cell_at = 2 * cell_offset + qr_reset_at;

      reset_gate = static_cast<float>(qr_ptr[qr_reset_at]) * prev_coarse +
                   static_cast<float>(qr_ptr[qr_reset_at + 1]) * prev_fine;
      update_gate = static_cast<float>(qr_ptr[qr_update_at]) * prev_coarse +
                    static_cast<float>(qr_ptr[qr_update_at + 1]) * prev_fine;
      ar_cell = static_cast<float>(qr_ptr[qr_cell_at]) * prev_coarse +
                static_cast<float>(qr_ptr[qr_cell_at + 1]) * prev_fine;

      if (kInputsMode == ARInputsMode::k3ARInputs) {
        // Third input: the coarse sample of the current step.
        const float cur_coarse = static_cast<float>(coarse_at_s[0]);
        reset_gate += static_cast<float>(w_hat[idx]) * cur_coarse;
        update_gate +=
            static_cast<float>(w_hat[update_offset + idx]) * cur_coarse;
        ar_cell += static_cast<float>(w_hat[cell_offset + idx]) * cur_coarse;
      }

      reset_gate += static_cast<float>(gru_gates_ptr[idx]);
      update_gate += static_cast<float>(gru_gates_ptr[update_offset + idx]);
    } else {
      reset_gate = static_cast<float>(gru_gates_ptr[idx]);
      update_gate = static_cast<float>(gru_gates_ptr[update_offset + idx]);
    }

    float cell_gate = static_cast<float>(gru_gates_ptr[cell_offset + idx]);
    if (SplitGates) {
      // Second partial result for the same gates.
      reset_gate += static_cast<float>(gru_gates_other_ptr[idx]);
      update_gate +=
          static_cast<float>(gru_gates_other_ptr[update_offset + idx]);
      cell_gate += static_cast<float>(gru_gates_other_ptr[cell_offset + idx]);
    }

    const float reset_cond = static_cast<float>(conditioning_ptr[idx]);
    const float update_cond =
        static_cast<float>(conditioning_ptr[update_offset + idx]);
    const float cell_cond =
        static_cast<float>(conditioning_ptr[cell_offset + idx]);

    const float reset_act = fast_sigmoid(reset_gate + reset_cond);
    const float update_act = fast_sigmoid(update_gate + update_cond);
    const float candidate =
        fast_tanh(ar_cell + reset_act * cell_gate + cell_cond);

    // Blend the previous state with the candidate, weighted by the update
    // gate, and store it back in the state's own type.
    const float previous = static_cast<float>(gru_h_ptr[idx]);
    const float delta = previous - candidate;
    const float blended = candidate + delta * update_act;
    gru_h_ptr[idx] = static_cast<GRUStateType>(blended);
  }
}

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_