/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_

#if defined __ARM_NEON || defined __aarch64__
#include <arm_neon.h>
#endif
#include <cstdint>
#include <utility>

#include "sparse_matmul/compute/ar_inputs.h"
#include "sparse_matmul/numerics/fast_transcendentals.h"

namespace csrblocksparse {

static constexpr int kNeonSIMDWidth = 4;

// ------ Scalar calculation --------
// See "Efficient Neural Audio Synthesis" for a description of the calculation.
// https://arxiv.org/abs/1802.08435
//
// NOTE:
// |sample| = (|coarse_at_sminus1|, |fine_at_sminus1|,
//             |coarse_at_sminus1|, |fine_at_sminus1|)
// |w_sample| = (|coarse_at_s|, |coarse_at_s|, |coarse_at_s|, |coarse_at_s|)
//
// CHEATSHEET:
// vld1q_f32 = load 4 32-bit floats
// vmulq_f32(a, b) : return a * b;
// vaddq_f32(a, b) : return a + b;
// vmlaq_f32(c, a, b) : return c + a * b;
// vpaddq_f32(a, b) : return (a0 + a1, a2 + a3, b0 + b1, b2 + b3)
// vsubq_f32(a, b) : return a - b;
// vst1q_f32 = store 4 32-bit floats
#if defined __ARM_NEON || defined __aarch64__

#if !defined __aarch64__
// Backport of vpaddq_f32 to ARM32.
inline float32x4_t vpaddq_f32(float32x4_t a, float32x4_t b) {
  float32x2_t a10 = vget_low_f32(a);
  float32x2_t a32 = vget_high_f32(a);
  float32x2_t b10 = vget_low_f32(b);
  float32x2_t b32 = vget_high_f32(b);
  return vcombine_f32(vpadd_f32(a10, a32), vpadd_f32(b10, b32));
}
#endif

template <ARInputsMode kInputsMode, bool SplitGates>
void GoThroughGatesFloat(int start, int end, const float* qr_ptr,
                         const float* gru_gates_ptr,
                         const float* gru_gates_other_ptr,
                         const float* conditioning_ptr, float* gru_h_ptr,
                         const float* w_hat, int proj_size,
                         const float* coarse_at_sminus1,
                         const float* fine_at_sminus1,
                         const float* coarse_at_s) {
  // Increment all the pointers to save on pointer arithmetic in the loop.
  conditioning_ptr += start;
  gru_h_ptr += start;
  gru_gates_ptr += start;
  if (SplitGates) {
    DCHECK_NE(gru_gates_other_ptr, nullptr);
    gru_gates_other_ptr += start;
  }
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    DCHECK_NE(qr_ptr, nullptr);
    qr_ptr += 2 * start;
    DCHECK_NE(coarse_at_sminus1, nullptr);
    DCHECK_NE(fine_at_sminus1, nullptr);
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      DCHECK_NE(w_hat, nullptr);
      DCHECK_NE(coarse_at_s, nullptr);
      w_hat += start;
    }
  }
  for (int i = start; i < end; i += kNeonSIMDWidth) {
    float32x4_t reset = vld1q_f32(gru_gates_ptr);
    float32x4_t update = vld1q_f32(gru_gates_ptr + proj_size);
    float32x4_t cell = vld1q_f32(gru_gates_ptr + 2 * proj_size);
    float32x4_t qr_cell;
    if (SplitGates) {
      reset = vaddq_f32(reset, vld1q_f32(gru_gates_other_ptr));
      update = vaddq_f32(update, vld1q_f32(gru_gates_other_ptr + proj_size));
      cell = vaddq_f32(cell, vld1q_f32(gru_gates_other_ptr + 2 * proj_size));
    }
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      // Setup the sample vector.
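      // The QR weights store the (coarse, fine) weights of each output row
      // interleaved, so a 4-float load covers two rows:
      // (wc[i], wf[i], wc[i+1], wf[i+1]). Multiplying by |sample| =
      // (coarse, fine, coarse, fine) and pairwise-adding with vpaddq_f32 then
      // yields one 2-input dot product per row, four rows at a time, e.g.
      // (wc[i] * coarse + wf[i] * fine, wc[i+1] * coarse + wf[i+1] * fine, ...).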
      float32x4_t sample = vdupq_n_f32(*coarse_at_sminus1);
      sample = vsetq_lane_f32(*fine_at_sminus1, sample, 1);
      sample = vsetq_lane_f32(*fine_at_sminus1, sample, 3);
      // All auto types are float32x4_t, auto used to fit statements on one
      // line for readability. Do two rows of QR at once.
      auto qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample);
      auto qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample);
      auto qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);
      auto qr_update_0 = vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample);
      auto qr_update_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample);
      auto qr_update = vpaddq_f32(qr_update_0, qr_update_1);
      auto qr_cell_0 = vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample);
      auto qr_cell_1 = vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample);
      qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);
      if (kInputsMode == ARInputsMode::k3ARInputs) {
        float32x4_t w_sample = vdupq_n_f32(*coarse_at_s);
        qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
        qr_update =
            vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
        qr_cell =
            vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
      }
      reset = vaddq_f32(reset, qr_reset);
      update = vaddq_f32(update, qr_update);
    }
    auto reset_conditioning = vld1q_f32(conditioning_ptr);
    auto update_conditioning = vld1q_f32(conditioning_ptr + proj_size);
    auto cell_conditioning = vld1q_f32(conditioning_ptr + 2 * proj_size);
    reset = fast_sigmoid(vaddq_f32(reset, reset_conditioning));
    update = fast_sigmoid(vaddq_f32(update, update_conditioning));
    if (kInputsMode == ARInputsMode::k0ARInputs) {
      cell = vmulq_f32(reset, cell);
    } else {
      cell = vmlaq_f32(qr_cell, reset, cell);
    }
    auto hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));
    auto prev_h = vld1q_f32(gru_h_ptr);
    auto diff = vsubq_f32(prev_h, hbar);
    auto new_h = vmlaq_f32(hbar, diff, update);
    vst1q_f32(gru_h_ptr, new_h);
    // Increment all the pointers.
    conditioning_ptr += kNeonSIMDWidth;
    gru_h_ptr += kNeonSIMDWidth;
    gru_gates_ptr += kNeonSIMDWidth;
    if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      qr_ptr += 2 * kNeonSIMDWidth;
      if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
    }
  }
}
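
// Illustrative sketch only (not part of the library): with all three AR inputs
// and a single (unsplit) set of gate outputs, a caller might invoke the float
// routine above roughly as follows; the variable names are hypothetical.
//
//   GoThroughGatesFloat<ARInputsMode::k3ARInputs, /*SplitGates=*/false>(
//       /*start=*/0, /*end=*/proj_size, qr_weights, gru_gates,
//       /*gru_gates_other_ptr=*/nullptr, conditioning, gru_h, w_hat, proj_size,
//       &coarse_prev, &fine_prev, &coarse_current);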

// This version should only be used if all of the 32-bit fixed point
// representations have the same number of mantissa bits.
// |ar_at_sminus1| packs sample 0 and 1 into a pair because the QR weights are
// formatted with the weights interleaved for sample 0 and 1. The two samples
// represent coarse and fine for WaveRNN.
template <typename GRUStateType, typename GRUMatMulOutType,
          ARInputsMode kInputsMode, bool SplitGates>
void GoThroughGatesFixed(int start, int end, const float* qr_ptr,
                         const int32_t* gru_gates_ptr,
                         const int32_t* gru_gates_other_ptr,
                         const int32_t* conditioning_ptr, int16_t* gru_h_ptr,
                         const float* w_hat, int proj_size,
                         const std::pair<float, float>* ar_at_sminus1,
                         const float* coarse_at_s) {
  // Increment all the pointers to save on pointer arithmetic in the loop.
  conditioning_ptr += start;
  gru_h_ptr += start;
  gru_gates_ptr += start;
  if (SplitGates) {
    DCHECK_NE(gru_gates_other_ptr, nullptr);
    gru_gates_other_ptr += start;
  }
  float32x4_t sample01;
  float32x4_t w_sample;
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    DCHECK_NE(qr_ptr, nullptr);
    qr_ptr += 2 * start;
    DCHECK_NE(ar_at_sminus1, nullptr);
    sample01 = vdupq_n_f32(ar_at_sminus1->first);
    sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 1);
    sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 3);
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      DCHECK_NE(w_hat, nullptr);
      DCHECK_NE(coarse_at_s, nullptr);
      w_hat += start;
      w_sample = vdupq_n_f32(*coarse_at_s);
    }
  }
  for (int i = start; i < end; i += kNeonSIMDWidth) {
    auto reset = vld1q_s32(gru_gates_ptr);
    auto update = vld1q_s32(gru_gates_ptr + proj_size);
    // vcvtq_n_f32_s32 = convert 32-bit fixed point to fp32
    auto cell_int = vld1q_s32(gru_gates_ptr + 2 * proj_size);
    if (SplitGates) {
      reset = vaddq_s32(reset, vld1q_s32(gru_gates_other_ptr));
      update = vaddq_s32(update, vld1q_s32(gru_gates_other_ptr + proj_size));
      cell_int =
          vaddq_s32(cell_int, vld1q_s32(gru_gates_other_ptr + 2 * proj_size));
    }
    float32x4_t cell =
        vcvtq_n_f32_s32(cell_int, GRUMatMulOutType::kMantissaBits);
    float32x4_t qr_cell;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      // Do two rows of QR at once.
      float32x4_t qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample01);
      float32x4_t qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample01);
      float32x4_t qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);
      float32x4_t qr_update_0 =
          vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample01);
      float32x4_t qr_update_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample01);
      float32x4_t qr_update = vpaddq_f32(qr_update_0, qr_update_1);
      float32x4_t qr_cell_0 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample01);
      float32x4_t qr_cell_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample01);
      qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);
      if (kInputsMode == ARInputsMode::k3ARInputs) {
        qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
        qr_update =
            vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
        qr_cell =
            vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
      }
      reset = vaddq_s32(
          reset, vcvtq_n_s32_f32(qr_reset, GRUMatMulOutType::kMantissaBits));
      update = vaddq_s32(
          update, vcvtq_n_s32_f32(qr_update, GRUMatMulOutType::kMantissaBits));
    }
    auto reset_conditioning = vld1q_s32(conditioning_ptr);
    auto update_conditioning = vld1q_s32(conditioning_ptr + proj_size);
    float32x4_t cell_conditioning =
        vcvtq_n_f32_s32(vld1q_s32(conditioning_ptr + 2 * proj_size),
                        GRUMatMulOutType::kMantissaBits);
    float32x4_t reset_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
        vaddq_s32(reset, reset_conditioning));
    float32x4_t update_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
        vaddq_s32(update, update_conditioning));
    if (kInputsMode == ARInputsMode::k0ARInputs) {
      cell = vmulq_f32(reset_f32, cell);
    } else {
      cell = vmlaq_f32(qr_cell, reset_f32, cell);
    }
    float32x4_t hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));
    float32x4_t prev_h = vcvtq_n_f32_s32(vmovl_s16(vld1_s16(gru_h_ptr)),
                                         GRUStateType::kMantissaBits);
    float32x4_t diff = vsubq_f32(prev_h, hbar);
    float32x4_t new_h = vmlaq_f32(hbar, diff, update_f32);
    // vcvtq_n_s32_f32 = convert fp32 to signed 32-bit fixed point
    // vqrshrn_n_s32 = saturating, rounding, narrowing right shift - used to
    // convert a 32-bit fixed point value to a 16-bit fixed point value
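    // Converting with GRUStateType::kMantissaBits + 16 fractional bits and
    // then right-shifting by 16 with vqrshrn_n_s32 rounds to nearest and
    // saturates while narrowing, leaving an int16 state value with
    // GRUStateType::kMantissaBits fractional bits.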
    vst1_s16(gru_h_ptr,
             vqrshrn_n_s32(
                 vcvtq_n_s32_f32(new_h, GRUStateType::kMantissaBits + 16), 16));
    // Increment all the pointers.
    conditioning_ptr += kNeonSIMDWidth;
    gru_h_ptr += kNeonSIMDWidth;
    gru_gates_ptr += kNeonSIMDWidth;
    if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      qr_ptr += 2 * kNeonSIMDWidth;
      if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
    }
  }
}

#endif  // defined __ARM_NEON || defined __aarch64__

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_