Spaces:
Runtime error
Runtime error
/* | |
* Copyright 2021 Google LLC | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
// IWYU pragma: begin_exports | |
// IWYU pragma: end_exports | |
namespace csrblocksparse { | |
// The master template is really a catch-all for the unimplemented cases to | |
// run the generics. | |
template <typename GRUStateType, typename InputType, typename SampleType = void> | |
class GruGates : public MatmulBase { | |
public: | |
using SampleWeightType = float; | |
static constexpr int kSIMDWidth = kGenericSIMDWidth; | |
// Generic GRU function covers all uses for WaveRNN-like architectures and | |
// conditioning. | |
// Controlled by template parameters thus: | |
// - |kInputsMode| == |k0ARInputs|: There are no autoregressive inputs so | |
// |ar_sample0|, |ar_sample1|, |ar_sample2|, |ar_01_weights|, | |
// |ar_2_weights| are ignored. | |
// - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied | |
// by |ar_01_weights| and added to the (conditioning) input. | |
// - |kInputsMode| == |k3ARInputs|: |ar_sample2| is multiplied by | |
// |ar_2_weights| and added to the other two |ar_inputs| (and added to the | |
// conditioning input). | |
// - If |kSplitGates| is true: The |*gru_recurrent_other_ptr| is secondary | |
// recurrent input that must be added to |*gru_recurrent_ptr|. | |
// - |num_replicas| determines the number of duplicates of the output to be | |
// written, separated by |replica_stride|. | |
// - |start|, |end| are |rows| in [0, |state_size|] to be processed by this | |
// thread. | |
// | |
// Previous state is read from |*gru_state_ptr| and the new state is written | |
// to *(|gru_state_ptr| + i * |replica_stride| for i in [0, |num_replicas|)). | |
template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs, | |
bool kSplitGates = false> | |
void GruWithARInput(int start, int end, int state_size, | |
const InputType* gru_recurrent_ptr, | |
const InputType* input_ptr, GRUStateType* gru_state_ptr, | |
const SampleType* ar_sample0 = nullptr, | |
const SampleType* ar_sample1 = nullptr, | |
const SampleWeightType* ar_01_weights = nullptr, | |
int num_replicas = 1, int replica_stride = 0, | |
const SampleType* ar_sample2 = nullptr, | |
const SampleWeightType* ar_2_weights = nullptr, | |
const InputType* gru_recurrent_other_ptr = nullptr) { | |
CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica"; | |
GoThroughGates<GRUStateType, InputType, SampleWeightType, SampleType, | |
kInputsMode, kSplitGates>( | |
start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr, | |
input_ptr, gru_state_ptr, ar_2_weights, state_size, ar_sample0, | |
ar_sample1, ar_sample2); | |
} | |
// No AR inputs, no split gates, no batching, no replicated outputs. | |
// TODO(b/188702959): Redirect conditioning GRU here, removing code from | |
// gru_layer.h. | |
// Copy to specializations. | |
void PlainGru(int start, int end, int state_size, | |
const InputType* gru_recurrent_ptr, const InputType* input_ptr, | |
GRUStateType* gru_state_ptr) { | |
GruWithARInput<ARInputsMode::k0ARInputs>( | |
start, end, state_size, gru_recurrent_ptr, input_ptr, gru_state_ptr); | |
} | |
}; | |
// Partial specialization for float. | |
template <> | |
class GruGates<float, float, float> : public MatmulBase { | |
public: | |
static constexpr int kSIMDWidth = kNeonSIMDWidth; | |
// Generic GRU function covers all uses for WaveRNN-like architectures and | |
// conditioning. | |
template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs, | |
bool kSplitGates = false> | |
void GruWithARInput(int start, int end, int state_size, | |
const float* gru_recurrent_data, const float* input_data, | |
float* gru_state_data, const float* ar_sample0 = nullptr, | |
const float* ar_sample1 = nullptr, | |
const float* ar_01_weights = nullptr, | |
int num_replicas = 1, int replica_stride = 0, | |
const float* ar_sample2 = nullptr, | |
const float* ar_2_weights = nullptr, | |
const float* gru_recurrent_other_data = nullptr) { | |
DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica"; | |
GoThroughGatesFloat<kInputsMode, kSplitGates>( | |
start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data, | |
input_data, gru_state_data, ar_2_weights, state_size, ar_sample0, | |
ar_sample1, ar_sample2); | |
} | |
}; | |
// Partial specialization for fixed types. The sample weights are always float | |
// whatever the fixed type of the other weights. | |
template <int kGRUStateBits, int kInputBits, int kSampleBits> | |
class GruGates<fixed16<kGRUStateBits>, fixed32<kInputBits>, | |
fixed16<kSampleBits>> : public MatmulBase { | |
public: | |
static constexpr int kSIMDWidth = kNeonSIMDWidth; | |
static constexpr int kSIMDWidth = kAVX2SIMDWidth * 2; | |
static constexpr int kSIMDWidth = kGenericSIMDWidth; | |
using GRUStateType = fixed16<kGRUStateBits>; | |
using InputType = fixed32<kInputBits>; | |
using SampleType = fixed16<kSampleBits>; | |
using SampleWeightType = float; | |
static constexpr int kInputMantissaBits = InputType::kMantissaBits; | |
static constexpr int kSampleMantissaBits = SampleType::kMantissaBits; | |
static constexpr int kStateMantissaBits = GRUStateType::kMantissaBits; | |
// Generic GRU function covers all uses for WaveRNN-like architectures and | |
// conditioning. | |
template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs, | |
bool kSplitGates = false> | |
void GruWithARInput(int start, int end, int state_size, | |
const InputType* gru_recurrent_data, | |
const InputType* input_data, GRUStateType* gru_state_data, | |
const SampleType* ar_sample0 = nullptr, | |
const SampleType* ar_sample1 = nullptr, | |
const SampleWeightType* ar_01_weights = nullptr, | |
int num_replicas = 1, int replica_stride = 0, | |
const SampleType* ar_sample2 = nullptr, | |
const SampleWeightType* ar_2_weights = nullptr, | |
const InputType* gru_recurrent_other_data = nullptr) { | |
const int32_t* gru_recurrent_ptr = | |
reinterpret_cast<const int32_t*>(gru_recurrent_data); | |
const int32_t* gru_recurrent_other_ptr = | |
reinterpret_cast<const int32_t*>(gru_recurrent_other_data); | |
const int32_t* input_ptr = reinterpret_cast<const int32_t*>(input_data); | |
int16_t* gru_state_ptr = reinterpret_cast<int16_t*>(gru_state_data); | |
// The samples are fixed16, but we scale them up here and convert to float | |
// so that the product with the QR weights is always on the same scale as | |
// InputType, so we don't have to do any more scaling inside. | |
const float sample_factor = static_cast<float>(1 << kInputMantissaBits); | |
const float sample_factor = 1.0f; | |
// AR sample 0 and 1 are packed into a pair because the QR weights are | |
// formatted with the weights interleaved for sample 0 and 1. | |
std::pair<float, float> ar_sample01; | |
float ar_sample2_float = 0.0f; | |
if (kInputsMode == ARInputsMode::k2ARInputs || | |
kInputsMode == ARInputsMode::k3ARInputs) { | |
ar_sample01 = {static_cast<float>(*ar_sample0) * sample_factor, | |
static_cast<float>(*ar_sample1) * sample_factor}; | |
if (kInputsMode == ARInputsMode::k3ARInputs) { | |
ar_sample2_float = static_cast<float>(*ar_sample2) * sample_factor; | |
} | |
} | |
CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!"; | |
GruGatesAVXFixed<kInputMantissaBits, kStateMantissaBits, kInputsMode, | |
kSplitGates>( | |
start, end, state_size, gru_recurrent_ptr, input_ptr, &ar_sample01, | |
ar_01_weights, num_replicas, replica_stride, &ar_sample2_float, | |
ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr); | |
DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica"; | |
GoThroughGatesFixed<GRUStateType, InputType, kInputsMode, kSplitGates>( | |
start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr, | |
input_ptr, gru_state_ptr, ar_2_weights, state_size, &ar_sample01, | |
&ar_sample2_float); | |
CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica"; | |
GoThroughGates<GRUStateType, InputType, SampleWeightType, SampleType, | |
kInputsMode, kSplitGates>( | |
start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data, | |
input_data, gru_state_data, ar_2_weights, state_size, ar_sample0, | |
ar_sample1, ar_sample2); | |
} | |
}; | |
} // namespace csrblocksparse | |