Spaces:
Runtime error
Runtime error
/* | |
* Copyright 2021 Google LLC | |
* | |
* Licensed under the Apache License, Version 2.0 (the "License"); | |
* you may not use this file except in compliance with the License. | |
* You may obtain a copy of the License at | |
* | |
* http://www.apache.org/licenses/LICENSE-2.0 | |
* | |
* Unless required by applicable law or agreed to in writing, software | |
* distributed under the License is distributed on an "AS IS" BASIS, | |
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
* See the License for the specific language governing permissions and | |
* limitations under the License. | |
*/ | |
// TODO(b/188702959): Remove fast_transcendentals with GRU refactor. | |
namespace csrblocksparse { | |
namespace detail { | |
// True exactly when the weight, rhs and output types are all plain float.
// ShouldEnableGenericKernel uses it to tell whether the float AVX kernels
// cover a type combination. (W == R, R == O and O == float together are the
// same as all three being float.)
template <typename WeightType, typename RhsType, typename OutType>
struct IsAllowableFloatTypes
    : std::integral_constant<bool, std::is_same<OutType, float>::value &&
                                       std::is_same<WeightType, RhsType>::value &&
                                       std::is_same<RhsType, OutType>::value> {};
// 16-bit inputs, 32-bit output exponent matches sum of input exponents | |
// OR | |
// 16-bit inputs, 16-bit output - will shift to match exponent | |
template <typename WeightType, typename RhsType, typename OutType> | |
struct IsAllowableFixedTypes | |
: std::integral_constant<bool, (IsFixed16Type<WeightType>::value && | |
IsFixed16Type<RhsType>::value) && | |
(IsFixed32Type<OutType>::value || | |
IsFixed16Type<OutType>::value)> {}; | |
template <typename WeightType, typename RhsType, typename OutType> | |
struct ShouldEnableGenericKernel | |
: std::integral_constant< | |
bool, | |
!IsAllowableFloatTypes<WeightType, RhsType, OutType>::value && | |
!IsAllowableFixedTypes<WeightType, RhsType, OutType>::value> {}; | |
template <typename Type> | |
struct IsAddableFixedTypes | |
: std::integral_constant<bool, IsFixed32Type<Type>::value || | |
IsFixed16Type<Type>::value> {}; | |
template <typename Type> | |
struct ShouldEnableGenericAdd | |
: std::integral_constant<bool, !IsAddableFixedTypes<Type>::value> {}; | |
template <typename WeightType, typename RhsType, typename OutType> | |
struct ShouldEnableGenericKernel | |
: std::integral_constant< | |
bool, !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value> {}; | |
template <typename Type> | |
struct ShouldEnableGenericAdd : std::true_type {}; | |
template <typename WeightType, typename RhsType, typename OutType> | |
struct ShouldEnableGenericSpMV_4x4 | |
: ShouldEnableGenericKernel<WeightType, RhsType, OutType> {}; | |
template <typename WeightType, typename RhsType, typename OutType> | |
struct ShouldEnableGenericSpMM5_4x4 | |
: ShouldEnableGenericKernel<WeightType, RhsType, OutType> {}; | |
template <typename WeightType, typename RhsType, typename OutType> | |
struct ShouldEnableGenericSpMV_1x1 : std::true_type {}; | |
template <typename WeightType, typename RhsType, typename OutType> | |
struct ShouldEnableGenericSpMM5_1x1 : std::true_type {}; | |
// The computational routines do NO error checking for speed. It is assumed
// that this has been handled by CSRBlockSparseMatrix.
//
// Reduces two accumulator registers to 4 float results, optionally clamps them
// at zero (|relu|), stores them at |*out_ptr| and advances |*out_ptr| by 4.
// Lanes 0-3 of |sum1| are summed into result 0 and lanes 4-7 into result 1;
// likewise |sum2| yields results 2 and 3. Both registers are passed by
// non-const reference and are clobbered. The store is _mm_store_ps, so
// |*out_ptr| must be 16-byte aligned.
inline void Extract4Results(bool relu, __m256& sum1, __m256& sum2,
                            float** out_ptr) {
  // Two rounds of horizontal adds fold each group of 4 lanes into a single
  // value. hadd never crosses the 128-bit halves, so afterwards |sum1| is
  // [res0 res2 res0 res2 | res1 res3 res1 res3].
  const __m256 pairs = _mm256_hadd_ps(sum1, sum2);
  sum1 = _mm256_hadd_ps(pairs, pairs);
  if (relu) sum1 = _mm256_max_ps(sum1, _mm256_setzero_ps());
  // Swapping the 128-bit halves is the only way to move data across lanes in
  // AVX: |sum2| becomes [res1 res3 res1 res3 | res0 res2 res0 res2] ...
  sum2 = _mm256_permute2f128_ps(sum1, sum1, 1);
  // ... and interleaving the low elements gives [res0 res1 res2 res3] in the
  // bottom half.
  sum1 = _mm256_unpacklo_ps(sum1, sum2);
  _mm_store_ps(*out_ptr, _mm256_castps256_ps128(sum1));
  *out_ptr += 4;
}
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 | |
// blocked pattern, x is a vector and b is vector. Weights are stored for this | |
// routine by making each 4x4 block contiguous. Blocks are ordered in standard | |
// row-major format. column indices are converted to deltas and then multiplied | |
// by 2 to convert to bytes, so that the value can be used directly to offset | |
// the pointer into the rhs vector. | |
// | |
// NOTE: The bias is expected to have be multiplied by .25f prior to calling | |
// this function. This is automatically taken care of in SparseLinearLayer. | |
// The bias is reconstructed through horizontal additions, leads to a small | |
// speedup by reducing latencies at the end of the loop. | |
template <typename WeightType, typename RhsType, typename OutType> | |
typename std::enable_if<std::is_same<WeightType, float>::value && | |
std::is_same<RhsType, float>::value && | |
std::is_same<OutType, float>::value>::type | |
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, | |
const int32_t* nnz_per_row, const RhsType* rhs_ptr, | |
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, | |
OutType* out_ptr, int64_t assigned_rows, | |
int64_t rows /* only used in SpMM variants */, | |
int64_t cols /* only used in SpMM variants */, int relu) { | |
for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { | |
// Broadcast the biases by 4 to undo the division by 4 in the input biases. | |
__m256 sum1 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1), | |
_mm_broadcast_ss(bias_ptr)); | |
bias_ptr += 2; | |
__m256 sum2 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1), | |
_mm_broadcast_ss(bias_ptr)); | |
bias_ptr += 2; | |
int reduced_col_count = *nnz_per_row++; | |
for (int c = 0; c < reduced_col_count; ++c) { | |
int col_delta = *col_deltas_bytes++ / sizeof(RhsType); | |
rhs_ptr += col_delta; | |
// Multiply this 4x4 block. | |
__m256 rhs = | |
_mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr)); | |
__m256 weights1 = _mm256_load_ps(weights_ptr); | |
weights_ptr += 8; | |
sum1 = _mm256_add_ps(sum1, _mm256_mul_ps(weights1, rhs)); | |
__m256 weights2 = _mm256_load_ps(weights_ptr); | |
weights_ptr += 8; | |
sum2 = _mm256_add_ps(sum2, _mm256_mul_ps(weights2, rhs)); | |
} | |
Extract4Results(relu, sum1, sum2, &out_ptr); | |
} | |
} | |
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4 | |
// blocked pattern, x is a fat vector with 5 columns and b is vector. b is | |
// broadcast. Weights are stored for this routine by making each 4x4 block | |
// contiguous. Blocks are ordered in standard row-major format. column indices | |
// are converted to deltas and then multiplied by 2 to convert to bytes, so | |
// that the value can be used directly to offset the pointer into the rhs | |
// vector. | |
// | |
// NOTE: The bias is expected to have be multiplied by .25f prior to calling | |
// this function. This is automatically taken care of in SparseLinearLayer. | |
// The bias is reconstructed through horizontal additions, leads to a small | |
// speedup by reducing latencies at the end of the loop. | |
template <typename WeightType, typename RhsType, typename OutType> | |
typename std::enable_if<std::is_same<WeightType, float>::value && | |
std::is_same<RhsType, float>::value && | |
std::is_same<OutType, float>::value>::type | |
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes, | |
const int32_t* nnz_per_row, const RhsType* rhs_ptr, | |
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr, | |
OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols, | |
int relu) { | |
const RhsType* rhs_ptrs[5]; | |
for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols; | |
OutType* out_ptrs[5]; | |
for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows; | |
for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) { | |
// We will acumulate the results in 10 registers, |sum1_0| to |sum2_4|. | |
// Broadcast the biases by 4 to undo the division by 4 in the input biases. | |
__m256 sum1_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1), | |
_mm_broadcast_ss(bias_ptr)); | |
bias_ptr += 2; | |
__m256 sum2_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1), | |
_mm_broadcast_ss(bias_ptr)); | |
bias_ptr += 2; | |
__m256 sum1_1 = sum1_0; | |
__m256 sum2_1 = sum2_0; | |
__m256 sum1_2 = sum1_0; | |
__m256 sum2_2 = sum2_0; | |
__m256 sum1_3 = sum1_0; | |
__m256 sum2_3 = sum2_0; | |
__m256 sum1_4 = sum1_0; | |
__m256 sum2_4 = sum2_0; | |
int reduced_col_count = *nnz_per_row++; | |
for (int c = 0; c < reduced_col_count; ++c) { | |
int col_delta = *col_deltas_bytes++ / sizeof(RhsType); | |
for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta; | |
// Multiply this 4x4 block. | |
__m256 rhs = | |
_mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[0])); | |
__m256 weights1 = _mm256_load_ps(weights_ptr); | |
weights_ptr += 8; | |
sum1_0 = _mm256_add_ps(sum1_0, _mm256_mul_ps(weights1, rhs)); | |
__m256 weights2 = _mm256_load_ps(weights_ptr); | |
weights_ptr += 8; | |
sum2_0 = _mm256_add_ps(sum2_0, _mm256_mul_ps(weights2, rhs)); | |
rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[1])); | |
sum1_1 = _mm256_add_ps(sum1_1, _mm256_mul_ps(weights1, rhs)); | |
sum2_1 = _mm256_add_ps(sum2_1, _mm256_mul_ps(weights2, rhs)); | |
rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[2])); | |
sum1_2 = _mm256_add_ps(sum1_2, _mm256_mul_ps(weights1, rhs)); | |
sum2_2 = _mm256_add_ps(sum2_2, _mm256_mul_ps(weights2, rhs)); | |
rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[3])); | |
sum1_3 = _mm256_add_ps(sum1_3, _mm256_mul_ps(weights1, rhs)); | |
sum2_3 = _mm256_add_ps(sum2_3, _mm256_mul_ps(weights2, rhs)); | |
rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[4])); | |
sum1_4 = _mm256_add_ps(sum1_4, _mm256_mul_ps(weights1, rhs)); | |
sum2_4 = _mm256_add_ps(sum2_4, _mm256_mul_ps(weights2, rhs)); | |
} | |
Extract4Results(relu, sum1_0, sum2_0, &out_ptrs[0]); | |
Extract4Results(relu, sum1_1, sum2_1, &out_ptrs[1]); | |
Extract4Results(relu, sum1_2, sum2_2, &out_ptrs[2]); | |
Extract4Results(relu, sum1_3, sum2_3, &out_ptrs[3]); | |
Extract4Results(relu, sum1_4, sum2_4, &out_ptrs[4]); | |
} | |
} | |
// In-line function to finish the computation of 4x int32 results in |sum|.
// On entry |sum| holds pairs of partial sums laid out [0 0 1 1 2 2 3 3]; on
// return it holds [res0 res1 res2 res3 res0 res1 res2 res3]. When
// |kShiftAmount| is positive each result is shifted right by that many bits
// with round-half-up, and when |relu| is set negatives are clamped to zero.
inline void Compute4Results(bool relu, int kShiftAmount, __m256i& sum) {
  // hadd stays inside each 128-bit half, giving [0 1 0 1 2 3 2 3] ...
  const __m256i pair_sums = _mm256_hadd_epi32(sum, sum);
  // ... and picking 64-bit elements 0, 2, 1, 3 (0xd8) brings the answers
  // together: [0 1 2 3 0 1 2 3].
  __m256i results = _mm256_permute4x64_epi64(pair_sums, 0xd8);
  if (kShiftAmount > 0) {
    // Add half of the weight of the lowest kept bit, then shift arithmetically,
    // to round to the output's number of mantissa bits.
    const __m256i half = _mm256_set1_epi32(1 << (kShiftAmount - 1));
    results = _mm256_srai_epi32(_mm256_add_epi32(results, half), kShiftAmount);
  }
  if (relu) {
    results = _mm256_max_epi32(results, _mm256_setzero_si256());
  }
  sum = results;
}
// In-line function that finishes the 4x int32 results in |sum| via
// Compute4Results, stores them at |*out_ptr| and advances |*out_ptr| by 4.
// |sum| is a non-const reference because it is rewritten in place. The store
// is _mm_store_si128, so |*out_ptr| must be 16-byte aligned.
inline void Extract4xint32(bool relu, int kShiftAmount, __m256i& sum,
                           int32_t** out_ptr) {
  Compute4Results(relu, kShiftAmount, sum);
  int32_t* const dst = *out_ptr;
  // The bottom 128 bits hold [res0 res1 res2 res3]; taking them is a plain
  // reinterpretation of the register.
  _mm_store_si128(reinterpret_cast<__m128i*>(dst), _mm256_castsi256_si128(sum));
  *out_ptr = dst + 4;
}
// In-line function that finishes the 4x int32 results in |sum| via
// Compute4Results, saturates them to int16, stores the 4 values at |*out_ptr|
// and advances |*out_ptr| by 4.
// Non-const reference for |sum| as it is a register; it is left holding the
// packed values.
inline void Extract4xint16(bool relu, int kShiftAmount, __m256i& sum,
                           int16_t** out_ptr) {
  Compute4Results(relu, kShiftAmount, sum);
  // Clip to 16 bit range (with saturation) and pack in the bottom 64 bits.
  // Converts the lower 4x int32 in bottom 128 bits to 4x int16 in bottom 64
  // bits, replicated in the next 64 bits.
  sum = _mm256_packs_epi32(sum, sum);
  // Store the bottom 64 bits (4x int16). _mm_storel_epi64 has no alignment
  // requirement and, going through __m128i, does not break strict aliasing.
  // Writing an int64_t through the int16_t buffer would. It also avoids
  // _mm256_extract_epi64, which only exists in 64-bit mode.
  _mm_storel_epi64(reinterpret_cast<__m128i*>(*out_ptr),
                   _mm256_castsi256_si128(sum));
  *out_ptr += 4;
}
#if defined __AVX2__
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a vector and b is vector. This overload handles 16-bit
// fixed-point weights and rhs with 16- or 32-bit fixed-point output. Weights
// are stored for this routine by making each 4x4 block contiguous. Blocks are
// ordered in standard row-major format. Column indices are converted to deltas
// and then multiplied by sizeof(RhsType) to convert to bytes, so that the
// value can be used directly to offset the pointer into the rhs vector.
//
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
// this function. This is automatically taken care of in SparseLinearLayer.
// Here each quarter-bias is duplicated into 2 lanes and doubled, and the
// horizontal add in Compute4Results restores the full bias.
//
// Products carry kShiftAmount more mantissa bits than OutType; they are
// removed with a rounding right shift before storing. |bias_ptr| is read with
// _mm_load_si128 (16-byte alignment) and weights with _mm256_load_si256
// (32-byte alignment). |rows| and |cols| are unused; they keep the signature
// identical to the SpMM variants.
//
// Only compiled with AVX2, which the 256-bit integer instructions used here
// require; otherwise fixed-point types are left to the generic kernel.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<
    IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
    (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
         const int32_t* nnz_per_row, const RhsType* rhs_ptr,
         const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
         OutType* out_ptr, int64_t assigned_rows,
         int64_t rows /* only used in SpMM variants */,
         int64_t cols /* only used in SpMM variants */, int relu) {
  constexpr int kShiftAmount =
      TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
      OutType::kMantissaBits;
  static_assert(kShiftAmount >= 0,
                "Result must have fewer mantissa bits than product");
  static_assert(
      IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
      "AVX2 kernel only supports fixed16 and fixed32 types");
  // Signed element size: dividing a (possibly negative) int16 delta by the
  // unsigned sizeof() would first convert it to size_t, and the quotient would
  // only come out right through wrap-around when narrowed back to int.
  constexpr int kRhsBytes = static_cast<int>(sizeof(RhsType));
  for (int64_t reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
    __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
    __m256i biases = _mm256_set_m128i(bias, bias);
    bias_ptr += 4;
    // Swap the top two pairs: [0 1 2 3 2 3 0 1]
    // TODO(b/188702959): consider |_mm256_permutevar8x32|, and set the index
    // register outside the row loop.
    biases = _mm256_permute4x64_epi64(biases, 0xb4);
    // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
    biases = _mm256_unpacklo_epi32(biases, biases);
    // Double the results to make up for the division by 4.
    // TODO(b/188702959): consider moving this to where the biases are computed.
    __m256i sum = _mm256_add_epi32(biases, biases);
    // TODO(b/188702959): People don't like the old-fashioned, close-to-the-
    // metal notation of *|nnz_per_row|++, so measure the effect of putting the
    // increment in the for loop.
    int reduced_col_count = *nnz_per_row;
    ++nnz_per_row;
    for (int c = 0; c < reduced_col_count; ++c) {
      const int col_delta = *col_deltas_bytes++ / kRhsBytes;
      rhs_ptr += col_delta;
      // Multiply this 4x4 block.
      // Get the 4x int16 into the bottom of rhs_64.
      __m128i rhs_64 =
          _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptr));
      // Load all 16 weights.
      __m256i weights =
          _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
      // Broadcast the rhs, pretending that each is a 64-bit unit:
      // [0123 0123 0123 0123].
      __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
      weights_ptr += 16;
      // |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally
      // adds adjacent pairs to make 8x32 bit results. Add these to the sum.
      sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs));
    }
    // The only significant difference between fixed16 and fixed32 is the size
    // of the storage unit. The registers have to be repacked accordingly.
    // The output pointer is converted through a local copy. Reinterpreting
    // &out_ptr as int32_t** / int16_t** would modify the OutType* object
    // through an incompatible pointer type, which breaks strict aliasing.
    if (IsFixed32Type<OutType>::value) {
      int32_t* out_int32 = reinterpret_cast<int32_t*>(out_ptr);
      Extract4xint32(relu, kShiftAmount, sum, &out_int32);
      out_ptr = reinterpret_cast<OutType*>(out_int32);
    } else {
      int16_t* out_int16 = reinterpret_cast<int16_t*>(out_ptr);
      Extract4xint16(relu, kShiftAmount, sum, &out_int16);
      out_ptr = reinterpret_cast<OutType*>(out_int16);
    }
  }
}
#endif  // __AVX2__
#if defined __AVX2__
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a fat vector with 5 columns and b is vector. b is
// broadcast. This overload handles 16-bit fixed-point weights and rhs with
// 16- or 32-bit fixed-point output. Weights are stored for this routine by
// making each 4x4 block contiguous. Blocks are ordered in standard row-major
// format. Column indices are converted to deltas and then multiplied by
// sizeof(RhsType) to convert to bytes, so that the value can be used directly
// to offset the pointer into the rhs vector.
//
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
// this function. This is automatically taken care of in SparseLinearLayer.
// Here each quarter-bias is duplicated into 2 lanes and doubled, and the
// horizontal add in Compute4Results restores the full bias.
//
// Column k of x starts at |rhs_ptr + k * cols| and column k of y at
// |out_ptr + k * rows|. Products carry kShiftAmount more mantissa bits than
// OutType; they are removed with a rounding right shift before storing.
//
// Only compiled with AVX2, which the 256-bit integer instructions used here
// require; otherwise fixed-point types are left to the generic kernel.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<
    IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
    (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
          const int32_t* nnz_per_row, const RhsType* rhs_ptr,
          const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
          OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
          int relu) {
  constexpr int kShiftAmount =
      TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
      OutType::kMantissaBits;
  static_assert(kShiftAmount >= 0,
                "Result must have fewer mantissa bits than product");
  static_assert(
      IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
      "AVX2 kernel only supports fixed16 and fixed32 types");
  // Signed element size: dividing a (possibly negative) int16 delta by the
  // unsigned sizeof() would first convert it to size_t, and the quotient would
  // only come out right through wrap-around when narrowed back to int.
  constexpr int kRhsBytes = static_cast<int>(sizeof(RhsType));
  const RhsType* rhs_ptrs[5];
  for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
  OutType* out_ptrs[5];
  for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
  // Finishes one column's 4 results in |sum| and stores them at |out|,
  // advancing |out|. The only significant difference between fixed16 and
  // fixed32 is the size of the storage unit. The pointer is converted through
  // a local copy, because reinterpreting &out as int32_t** / int16_t** would
  // modify an OutType* object through an incompatible pointer type
  // (a strict-aliasing violation).
  auto extract = [&](__m256i& sum, OutType*& out) {
    if (IsFixed32Type<OutType>::value) {
      int32_t* out_int32 = reinterpret_cast<int32_t*>(out);
      Extract4xint32(relu, kShiftAmount, sum, &out_int32);
      out = reinterpret_cast<OutType*>(out_int32);
    } else {
      int16_t* out_int16 = reinterpret_cast<int16_t*>(out);
      Extract4xint16(relu, kShiftAmount, sum, &out_int16);
      out = reinterpret_cast<OutType*>(out_int16);
    }
  };
  for (int64_t reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // We will accumulate the results in 5 registers, sum_0 to sum_4, one per
    // rhs column.
    // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
    __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
    __m256i biases = _mm256_set_m128i(bias, bias);
    bias_ptr += 4;
    // Swap the top two pairs: [0 1 2 3 2 3 0 1]
    biases = _mm256_permute4x64_epi64(biases, 0xb4);
    // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
    biases = _mm256_unpacklo_epi32(biases, biases);
    // Double the results to make up for the division by 4.
    __m256i sum_0 = _mm256_add_epi32(biases, biases);
    __m256i sum_1 = sum_0;
    __m256i sum_2 = sum_0;
    __m256i sum_3 = sum_0;
    __m256i sum_4 = sum_0;
    int reduced_col_count = *nnz_per_row;
    ++nnz_per_row;
    for (int c = 0; c < reduced_col_count; ++c) {
      const int col_delta = *col_deltas_bytes++ / kRhsBytes;
      for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
      // Multiply this 4x4 block.
      // Get the 4x int16 into the bottom of |rhs_64|.
      __m128i rhs_64 =
          _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[0]));
      // Load all 16 weights.
      __m256i weights =
          _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
      // Broadcast the rhs, pretending that each is a 64-bit unit:
      // [0123 0123 0123 0123].
      __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
      weights_ptr += 16;
      // |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally
      // adds adjacent pairs to make 8x32 bit results. Add these to the sum.
      sum_0 = _mm256_add_epi32(sum_0, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[1]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_1 = _mm256_add_epi32(sum_1, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[2]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_2 = _mm256_add_epi32(sum_2, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[3]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_3 = _mm256_add_epi32(sum_3, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[4]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_4 = _mm256_add_epi32(sum_4, _mm256_madd_epi16(weights, rhs));
    }
    extract(sum_0, out_ptrs[0]);
    extract(sum_1, out_ptrs[1]);
    extract(sum_2, out_ptrs[2]);
    extract(sum_3, out_ptrs[3]);
    extract(sum_4, out_ptrs[4]);
  }
}
#endif  // __AVX2__
// Processes one GRU gate input with sigmoid. | |
template <int InputMantissaBits, int StateMantissaBits, bool SplitGates> | |
inline __m256i GRUGateSigmoid(const void* gate_ptr, const void* gate_other_ptr, | |
const __m256i& input, | |
const int32_t* sigmoid_table) { | |
__m256i gate = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_ptr)); | |
if (SplitGates) { | |
__m256i other = | |
_mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_other_ptr)); | |
gate = _mm256_add_epi32(gate, other); | |
} | |
gate = _mm256_add_epi32(gate, input); | |
// Compute sigmoids on reset and update. | |
return csrblocksparse::fixed32_sigmoid_fixed16<InputMantissaBits, | |
StateMantissaBits>( | |
sigmoid_table, gate); | |
} | |
// Processes the tanh and the final combination, returning the new GRU state. | |
template <int InputMantissaBits, int StateMantissaBits, bool SplitGates = false> | |
inline __m256i GRUGateState(const __m256i& cell, const __m256i& reset, | |
const __m256i& update, | |
const __m256i& rounding_offset, | |
const void* gate_ptr, const void* gate_other_ptr, | |
const void* gru_h_ptr, const int32_t* tanh_table) { | |
// Multiply the cell GRU output and the reset. There is a slight danger of | |
// loss of precision here, so use 32x32=64 bit and shift back after. | |
__m256i gru = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_ptr)); | |
if (SplitGates) { | |
__m256i other_gru = | |
_mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_other_ptr)); | |
gru = _mm256_add_epi32(gru, other_gru); | |
} | |
// This only computes the products of the low-order 32 bits of each pair. | |
__m256i gru_lo = _mm256_mul_epi32(gru, reset); | |
// Swap odd and even 32-bit units and do it again to get the high products. | |
gru = _mm256_shuffle_epi32(gru, 0xb1); | |
__m256i gru_hi = _mm256_mul_epi32(gru, _mm256_shuffle_epi32(reset, 0xb1)); | |
// Now shift right to compensate for the multiply and re-interleave the | |
// 32-bit results. | |
// NOTE: There is no shift right arithmetic for 64 bit values until AVX512! | |
// Fortunately it doesn't matter, as the results are being truncated to 32 | |
// bits and we aren't shifting right by more than 32 bits here. | |
gru_lo = _mm256_srli_epi64(gru_lo, StateMantissaBits); | |
// The upper results are shifted LEFT, so we can use blend to recombine in | |
// a single instruction. | |
gru_hi = _mm256_slli_epi64(gru_hi, 32 - StateMantissaBits); | |
// Recombine the 32 bit results from lo and hi, alternating. | |
gru = _mm256_blend_epi32(gru_lo, gru_hi, 0xaa); | |
gru = _mm256_add_epi32(cell, gru); | |
// Compute tanh on the result. Although this instantly discards a bunch of | |
// bits, there were only 7 surplus bits for the multiply, which isn't enough | |
// to do it as 16x16=32. | |
__m256i hbar = | |
csrblocksparse::fixed32_tanh_fixed16<InputMantissaBits, | |
StateMantissaBits>(tanh_table, gru); | |
// Load the 16-bit previous GRU state and sign-extend to 32 bits. | |
gru = _mm256_cvtepi16_epi32( | |
_mm_load_si128(reinterpret_cast<__m128i const*>(gru_h_ptr))); | |
gru = _mm256_sub_epi32(gru, hbar); | |
// Since |gru| is 16 bit sign-extended to 32, and |update| is the output of | |
// sigmoid, it is always contained within 16 bits and never negative, we can | |
// use |madd_epi16| to do 16x16=32 multiply with horizontal adding as the | |
// addend will always be zero, and this is twice as fast as full blown | |
// 32x32=32. The only possible problem is if the subtract above caused | |
// overflow. | |
gru = _mm256_madd_epi16(gru, update); | |
// Renormalize to fixed16. This time rounding is critical, as this is the | |
// output GRU state. | |
gru = _mm256_add_epi32(gru, rounding_offset); | |
gru = _mm256_srai_epi32(gru, StateMantissaBits); | |
return _mm256_add_epi32(gru, hbar); | |
} | |
template <typename Type> | |
typename std::enable_if<IsFixed32Type<Type>::value>::type SumVectors( | |
int start, int end, const Type* add1, const Type* add2, Type* result) { | |
constexpr int kSIMDWidth = 8; | |
for (int i = start; i < end; i += kSIMDWidth) { | |
__m256i data1 = | |
_mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i)); | |
__m256i data2 = | |
_mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i)); | |
data1 = _mm256_add_epi32(data1, data2); | |
_mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1); | |
} | |
} | |
template <typename Type> | |
typename std::enable_if<IsFixed16Type<Type>::value>::type SumVectors( | |
int start, int end, const Type* add1, const Type* add2, Type* result) { | |
constexpr int kSIMDWidth = 16; | |
for (int i = start; i < end; i += kSIMDWidth) { | |
__m256i data1 = | |
_mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i)); | |
__m256i data2 = | |
_mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i)); | |
data1 = _mm256_add_epi16(data1, data2); | |
_mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1); | |
} | |
} | |
} // namespace detail | |
} // namespace csrblocksparse | |