/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#if defined __ARM_NEON || defined __aarch64__
#include <arm_neon.h>
#endif
#if defined __AVX2__
#include <immintrin.h>
#endif
#include <math.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

// NOTE: the fixed-point types fixed16<>/fixed32<> and the MantissaBitsOf<>
// trait used below are provided by this library's fixed-point type headers,
// which must be included ahead of this file.

namespace csrblocksparse {
// The input to exp is clipped to bounds that prevent overflow/underflow in a
// 32 bit float representation. e^80 ~ 6e34, which is close to maxfloat.
constexpr float kMaxExpInput = 80.f;
constexpr int kMaxExpInputInt = static_cast<int>(kMaxExpInput);
constexpr float kMinExpInput = -80.f;
// tanh(9) ~ 0.99999997, which cannot be resolved from 1 in a float32.
constexpr float kMaxTanhInput = 9.f;
constexpr float kMinTanhInput = -9.f;
// sigmoid(18) ~ 0.999999985, which cannot be resolved from 1 in a float32.
constexpr float kMaxSigmoidInput = 18.f;
constexpr float kMinSigmoidInput = -18.f;
// kAConstant ~= 2^23 / ln 2
constexpr uint32_t kAConstant = 0x4b38aa3b;
// kBConstant ~= (127 << 23) - 366000
constexpr uint32_t kBConstant = 0x4e7de9a9;
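// Worked check of the two constants above: 2^23 / ln 2 ~= 12102203.2, and the
// float 12102203.0f has bit pattern 0x4b38aa3b; likewise (127 << 23) - 366000
// = 1064987216, whose nearest float has bit pattern 0x4e7de9a9. The constants
// are stored as bit patterns so the code below can move them straight into
// float registers with a reinterpret rather than a conversion.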
// Coefficients of the rational approximation to tanh.
// Coefficients of the numerator polynomial (odd).
constexpr float kTanhAlpha1 = 4.89352455891786e-03;
constexpr float kTanhAlpha3 = 6.37261928875436e-04;
constexpr float kTanhAlpha5 = 1.48572235717979e-05;
constexpr float kTanhAlpha7 = 5.12229709037114e-08;
constexpr float kTanhAlpha9 = -8.60467152213735e-11;
constexpr float kTanhAlpha11 = 2.00018790482477e-13;
constexpr float kTanhAlpha13 = -2.76076847742355e-16;
// The monomial coefficients of the denominator polynomial (even).
constexpr float kTanhBeta0 = 4.89352518554385e-03;
constexpr float kTanhBeta2 = 2.26843463243900e-03;
constexpr float kTanhBeta4 = 1.18534705686654e-04;
constexpr float kTanhBeta6 = 1.19825839466702e-06;
// Coefficients of the rational approximation to sigmoid.
// Coefficients of the numerator polynomial (odd).
constexpr float kSigmoidAlpha1 = 2.48287947061529e-01;
constexpr float kSigmoidAlpha3 = 8.51377133304701e-03;
constexpr float kSigmoidAlpha5 = 6.08574864600143e-05;
constexpr float kSigmoidAlpha7 = 1.15627324459942e-07;
constexpr float kSigmoidAlpha9 = 4.37031012579801e-11;
// The monomial coefficients of the denominator polynomial (even).
constexpr float kSigmoidBeta0 = 9.93151921023180e-01;
constexpr float kSigmoidBeta2 = 1.16817656904453e-01;
constexpr float kSigmoidBeta4 = 1.70198817374094e-03;
constexpr float kSigmoidBeta6 = 6.29106785017040e-06;
constexpr float kSigmoidBeta8 = 5.76102136993427e-09;
constexpr float kSigmoidBeta10 = 6.10247389755681e-13;
// x is the first term of the Taylor series approximation of tanh near 0.
// Because the leading error term of tanh(x) - x is O(x^3), it is accurate over
// a wide interval, so we use it in this region, where the other approximation
// is inaccurate. tanh(x) = x - x^3 / 3 + 2x^5 / 15 - 17x^7 / 315 + ...
// Similarly for sigmoid, where the first term is .25x.
constexpr float kTanhLinearRegion = .15f;
constexpr float kSigmoidLinearRegion = .75f;
// Maximum shift factor for 1/log 2 to keep it inside int32.
constexpr int kMaxLog2Shift = 30;
static const int kLogFactor =
    static_cast<int>((1 << kMaxLog2Shift) / log(2.f));
static const float kOneOverLog2 = 1.0f / log(2.f);
// Number of real mantissa bits in IEEE float32.
constexpr int kFloatMantissaBits = 23;
// Offset to correct the exponent value in the resulting float.
constexpr int kFloatExponentOffset = 127 << kFloatMantissaBits;
// Mask for mantissa.
constexpr int kFloatMantissaMask = (1 << kFloatMantissaBits) - 1;
// Mask for exponent.
constexpr int kFloatExponentMask = (-1) ^ kFloatMantissaMask;
// ========== COMMON DOCUMENTATION FOR THE FLOATING EXPONENT TRICK ============
// Summary: Use the exponent-mantissa representation of a floating point number
// to give exponentiation of 2 for free. If we desire f(z) = e^z = 2^(x+n) (for
// some fixed-point z expressed as an integer with an imaginary binary point
// within it), then we have to compute x+n = z / ln 2. Splitting x+n into
// n = int(x+n) and x = fract(x+n) in [0, 1), we can use n and 2^x as the
// exponent and mantissa of a floating point number, and that float is equal to
// e^z. For original reference see:
// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.9.4508&rep=rep1&type=pdf
// Important detail:
// IEEE floats are stored normalized, ie 1.bbbbbbb... x 2^exponent. The leading
// 1 bit is not actually stored (as it is always 1), providing an extra bit of
// precision.
// Since 2^0=1 and 2^1=2, we can treat the problem as 2^x = 1 + u and we thus
// need a mapping x in [0, 1) -> u in [0, 1); the "1 +" is provided by the
// representation.
// In the original paper cited above, the mapping is u = x - c, where c is set
// to minimize the average error. The function to compute exp(x) this way is
// incredibly simple and computationally cheap, but not very accurate.
// Fortunately, the problem has been reduced to u = 2^x - 1 over [0, 1), for
// which it is far easier to construct accurate approximations with small
// polynomials than a full range exp(x), and this is what the cubic and quartic
// versions below do. An important feature of these functions is that they
// constrain the solution to be exact at 0 and 1, so there is continuity at
// each integer boundary where we wrap from 1 to 0 and increment the power of 2.
// Coefficients for the quartic representation of 2^x - 1 for x on [0, 1).
// The quartic representation is 2^x - 1 ~ x - x(1-x)(ax^2 + bx + c), hence the
// coefficients of a quadratic are all that is required.
// Coefficients came from numerical experiments.
constexpr float kExpQuarticFactor2 = 0.0135302434f;
constexpr float kExpQuarticFactor1 = 0.0656107542f;
constexpr float kExpQuarticFactor0 = 0.306963906f;
// Coefficients for the cubic representation of 2^x - 1 for x on [0, 1).
// The cubic representation is 2^x - 1 ~ x - x(1-x)(mx + c), hence the
// coefficients of a linear function are all that is required.
// Coefficients came from numerical experiments.
constexpr float kExpCubicFactor1 = 0.0780252018f;
constexpr float kExpCubicFactor0 = 0.304684167f;
// Coefficients are optimized to minimize the absolute error on
// tanh = (e^2x - 1) / (e^2x + 1) instead of on pure e^x.
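// The following helper is an illustrative sketch (it is not used elsewhere in
// this library, and its name is ours): it applies the exponent trick
// documented above with the quartic coefficients to compute 2^x for a small
// non-negative float x, entirely in scalar code. The vector implementations
// below do the same thing in fixed point and SIMD registers.
inline float Pow2TrickExampleScalar(float x) {
  // Split x into integer part n and fractional part frac in [0, 1).
  int n = static_cast<int>(x);
  float frac = x - static_cast<float>(n);
  // 2^frac - 1 ~ frac - frac * (1 - frac) * (a*frac^2 + b*frac + c).
  float poly = (kExpQuarticFactor2 * frac + kExpQuarticFactor1) * frac +
               kExpQuarticFactor0;
  float mantissa = frac - frac * (1.0f - frac) * poly;
  // Assemble an IEEE-754 float directly: biased exponent (127 + n) above the
  // 23 mantissa bits that hold 2^frac - 1.
  uint32_t bits = (static_cast<uint32_t>(127 + n) << kFloatMantissaBits) |
                  static_cast<uint32_t>(mantissa * (1 << kFloatMantissaBits));
  float result;
  memcpy(&result, &bits, sizeof(result));
  return result;  // ~= 2^x.
}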
// Enum that determines how a transcendental is computed.
enum TranscendentalMode {
  // Cubic using 16 bit integer arithmetic.
  TM_ORDER3_16BIT,
  // Quartic using 16 bit integer arithmetic.
  TM_ORDER4_16BIT,
  // Quartic using 32 bit float arithmetic.
  TM_ORDER4_FLOAT,
};

inline int FloatAsInt16(float x) {
  return static_cast<int>(x * (1 << 15) + 0.5f);
}

inline int FloatAsInt32(float x) {
  return static_cast<int>(x * (1 << 30) + 0.5f);
}
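// For example, FloatAsInt16(kExpQuarticFactor1) == 2150, i.e. the coefficient
// expressed as fixed16<0> (Q15); FloatAsInt32 does the same for fixed32<1>
// (Q30) values.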
#if defined __ARM_NEON || defined __aarch64__

constexpr int kMaxSigmoidInputInt = static_cast<int>(kMaxSigmoidInput);

// Computes and returns 2^(x>>23), ie 2^u where x = u << 23 bits.
// Uses the quartic floating point exponent trick, see COMMON DOCUMENTATION FOR
// THE FLOATING EXPONENT TRICK above for details.
// Returns the true value, ie not scaled.
inline float32x4_t float32_pow2(float32x4_t x) {
  // The input is already shifted left by 23 bits, so when we convert to int,
  // the bottom 23 bits are the fractional part, and the top bits are the
  // integer part. We want to compute a function of the fractional part, so
  // we will mask it off and manipulate it.
  int32x4_t exp_int_x = vcvtq_s32_f32(x);
  // Mask to allow conversion of just the fractional part of x to fixed16<0>.
  int32x4_t mantissa_mask16 = vdupq_n_s32(0x7fff00);
  // Mask to allow conversion of just the fractional part of x to fixed32<1>.
  int32x4_t mantissa_mask32 = vdupq_n_s32(0x7fffff);
  // Narrowing shift to convert to fixed16<0>.
  int16x4_t x_16 = vshrn_n_s32(vandq_s32(mantissa_mask16, exp_int_x), 8);
  // Shift to convert to fixed32<1>.
  int32x4_t x_32 = vshlq_n_s32(vandq_s32(mantissa_mask32, exp_int_x), 7);
  // Compute the polynomial x(x - 1)(ax^2 + bx + c) of the fractional part.
  // Ordering these lines carefully makes it faster, as some of the multiply
  // operations can pipeline instead of waiting for the previous result.
  int32x4_t x_squared = vmull_s16(x_16, x_16);
  int16x4_t b = vdup_n_s16(FloatAsInt16(kExpQuarticFactor1));
  int32x4_t c = vdupq_n_s32(FloatAsInt32(kExpQuarticFactor0));
  int32x4_t bx_plus_c = vmlal_s16(c, b, x_16);
  int16x4_t a = vdup_n_s16(FloatAsInt16(kExpQuarticFactor2));
  // Finish the quadratic: result = ax^2 + bx + c.
  int32x4_t result = vmlal_s16(bx_plus_c, a, vshrn_n_s32(x_squared, 15));
  int32x4_t x_squared_minus_x = vsubq_s32(x_squared, x_32);
  // Multiply by x^2 - x.
  result = vqrdmulhq_s32(result, x_squared_minus_x);
  // Shift back to mantissa position. vqrdmulhq_s32 took 2x 30-mantissa bit
  // inputs, made a 60-mantissa bit result, doubled it to 61 bits, then
  // discarded the bottom 32 making 29, so shift right 6 to get 23.
  result = vshrq_n_s32(result, 6);
  // Add the constant to normalize the exponent for IEEE format.
  int32x4_t exp_offset = vdupq_n_s32(kFloatExponentOffset);
  exp_int_x = vaddq_s32(exp_int_x, exp_offset);
  exp_int_x = vaddq_s32(exp_int_x, result);
  // Cast back to float, as we just computed the exponent and mantissa and
  // assembled them in IEEE format.
  return vreinterpretq_f32_s32(exp_int_x);
}
// Scaled float to float exp approximation, using a quartic refinement of
// the exponent trick. See COMMON DOCUMENTATION FOR THE FLOATING EXPONENT TRICK
// above for details. Input is a fixed32<31 - mantissa_bits> that has been
// converted to a float without any further shifting. MUST HAVE ALREADY BEEN
// CLIPPED to a suitable range for exp!
// Returns a vector of standard unscaled floats.
inline float32x4_t fixed32_exp_float_preclipped(const int mantissa_bits,
                                                float32x4_t x) {
  // Divide by log 2 to convert problem to 2^x, and scale to match the
  // mantissa bits required by IEEE floats.
  // This is the shift of the FP mantissa relative to the input mantissa.
  const int kXShift = kFloatMantissaBits - mantissa_bits;
  const float kLogFactor = static_cast<float>(1 << kXShift);
  float32x4_t factor = vdupq_n_f32(kLogFactor * kOneOverLog2);
  float32x4_t y = vmulq_f32(x, factor);
  // Now compute 2^x.
  return float32_pow2(y);
}
// Uses the trick that 2^x can be computed by shifting an integer into the
// exponent field; see the following reference for a derivation using double:
// goo.gl/aUVTK3
// Input x is clamped to [kMinExpInput, kMaxExpInput], even infinity and NaN.
// Accurate to within 3% relative across the entire range.
// Fully pipelined throughput is about 10 cycles per fast_exp call.
inline float32x4_t fast_exp(float32x4_t x) {
#if defined FAST_TRANSCENDENTALS
  // Uses vcvtnq_s32_f32, not available on ARM v7 NEON.
  // Load A and B, which are defined as integers, into float registers.
  float32x4_t A = vreinterpretq_f32_u32(vdupq_n_u32(kAConstant));
  float32x4_t res = vreinterpretq_f32_u32(vdupq_n_u32(kBConstant));
  // Make sure x is within the allowed range.
  x = vminq_f32(x, vdupq_n_f32(kMaxExpInput));
  x = vmaxq_f32(x, vdupq_n_f32(kMinExpInput));
  // res = A * x + B.
  // This shifts x into the exponent field and adds the bias.
  res = vmlaq_f32(res, A, x);
  // Convert back to an integer; this is what uses the floating point
  // unit to compute 2^x.
  int32x4_t x_int = vcvtnq_s32_f32(res);
  return vreinterpretq_f32_s32(x_int);
#else
  // Fallback: compute expf one lane at a time.
  float32x4_t return_val = vdupq_n_f32(0.f);
  float exponent = expf(vgetq_lane_f32(x, 0));
  return_val = vld1q_lane_f32(&exponent, return_val, 0);
  exponent = expf(vgetq_lane_f32(x, 1));
  return_val = vld1q_lane_f32(&exponent, return_val, 1);
  exponent = expf(vgetq_lane_f32(x, 2));
  return_val = vld1q_lane_f32(&exponent, return_val, 2);
  exponent = expf(vgetq_lane_f32(x, 3));
  return_val = vld1q_lane_f32(&exponent, return_val, 3);
  return return_val;
#endif
}
// This version does a conversion of the input to floating point, then calls
// the floating point fast_exp function. There is another version,
// fast_exp_fixed, that never does a conversion and is less accurate, but much
// faster.
template <int ExponentBits>
inline float32x4_t fast_exp(int32x4_t x) {
  return fast_exp(vcvtq_n_f32_s32(x, 31 - ExponentBits));
}

// Performs an exp estimate without doing any floating point operations. The
// result is a floating point number. See the scalar version for an
// explanation.
template <int ExponentBits>
inline float32x4_t fast_exp_fixed(int32x4_t x) {
  static_assert(ExponentBits > 8, "Must have more than 8 ExponentBits");
  constexpr int kA = 1.4426950408889634 * (1 << (ExponentBits - 8));
  constexpr int kB = (127 << 23) - 366000;
  constexpr int maxInput = 80 << (31 - ExponentBits);
  constexpr int minInput = -maxInput;
  int32x4_t A = vdupq_n_s32(kA);
  int32x4_t res = vdupq_n_s32(kB);
  // Make sure x is within the allowed range.
  x = vminq_s32(x, vdupq_n_s32(maxInput));
  x = vmaxq_s32(x, vdupq_n_s32(minInput));
  // res = A * x + B.
  // This shifts x into the exponent field and adds the bias.
  res = vmlaq_s32(res, A, x);
  return vreinterpretq_f32_s32(res);
}
// fast_exp_norange_check uses vcvtnq_s32_f32, not available on ARM v7 NEON.
namespace detail {
// tanh does the range check once, so this helper performs no clamping; the
// input must already be within safe exp bounds.
inline float32x4_t fast_exp_norange_check(float32x4_t x) {
  float32x4_t A = vreinterpretq_f32_u32(vdupq_n_u32(kAConstant));
  float32x4_t res = vreinterpretq_f32_u32(vdupq_n_u32(kBConstant));
  res = vmlaq_f32(res, A, x);
  int32x4_t x_int = vcvtnq_s32_f32(res);
  return vreinterpretq_f32_s32(x_int);
}
}  // namespace detail
// Clips float input to [-kLimit, kLimit].
inline float32x4_t ClipToFloatBounds(const float kLimit, const float32x4_t x) {
  // Clip to the input bounds for this approximation.
  float32x4_t clip_limit = vdupq_n_f32(kLimit);
  float32x4_t clipped_x = vminq_f32(x, clip_limit);
  clip_limit = vnegq_f32(clip_limit);
  return vmaxq_f32(clipped_x, clip_limit);
}
inline float32x4_t float_tanh_float(const float32x4_t& x) {
  float32x4_t clipped_x = ClipToFloatBounds(kMaxTanhInput, x);
  // Divide by log 2 to convert problem to 2^x, double (as we need exp(2x)) and
  // scale to the mantissa bits required by float32_pow2 all in one multiply.
  // Add one to double the input.
  const float kLogFactor = static_cast<float>(1 << (kFloatMantissaBits + 1));
  float32x4_t factor = vdupq_n_f32(kLogFactor * kOneOverLog2);
  clipped_x = vmulq_f32(clipped_x, factor);
  // Now compute 2^x.
  float32x4_t exp_result = float32_pow2(clipped_x);
  // Now compute tanh using (e^2x - 1) / (e^2x + 1).
  float32x4_t one = vdupq_n_f32(1.0f);
  float32x4_t numerator = vsubq_f32(exp_result, one);
  float32x4_t denominator = vaddq_f32(exp_result, one);
  float32x4_t recp = vrecpeq_f32(denominator);
  // Newton-Raphson iteration, accuracy is important for audio quality.
  recp = vmulq_f32(recp, vrecpsq_f32(recp, denominator));
  recp = vmulq_f32(recp, numerator);
  // Compute 3rd-order Taylor tanh ~ x - x^3/3 for high accuracy and thus low
  // relative error close to 0.
  float32x4_t third = vdupq_n_f32(1.0f / 3.0f);
  float32x4_t taylor = vmulq_f32(x, x);
  taylor = vmulq_f32(taylor, x);
  taylor = vmulq_f32(taylor, third);
  taylor = vsubq_f32(x, taylor);
  // Test |x| <= 1/9, roughly where the errors cross over, without needing yet
  // another constant.
  float32x4_t ninth = vmulq_f32(third, third);
  uint32x4_t cmp_results = vcaleq_f32(x, ninth);
  return vbslq_f32(cmp_results, taylor, recp);
}
// Calculates (exp(x) - exp(-x)) / (exp(x) + exp(-x)).
// Input x is clamped to [-9, 9], even infinity and NaN.
// See test program for bounds. Throughput of FAST is 334 Mega/sec,
// throughput of accurate is 232 Mega/sec.
inline float32x4_t fast_tanh(float32x4_t x) {
#if defined FASTER_TRANSCENDENTALS
  return float_tanh_float(x);
#elif defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX
  x = vminq_f32(x, vdupq_n_f32(kMaxTanhInput));
  x = vmaxq_f32(x, vdupq_n_f32(kMinTanhInput));
  // The monomial coefficients of the numerator polynomial (odd).
  const float32x4_t alpha_1 = vdupq_n_f32(kTanhAlpha1);
  const float32x4_t alpha_3 = vdupq_n_f32(kTanhAlpha3);
  const float32x4_t alpha_5 = vdupq_n_f32(kTanhAlpha5);
  const float32x4_t alpha_7 = vdupq_n_f32(kTanhAlpha7);
  const float32x4_t alpha_9 = vdupq_n_f32(kTanhAlpha9);
  const float32x4_t alpha_11 = vdupq_n_f32(kTanhAlpha11);
  const float32x4_t alpha_13 = vdupq_n_f32(kTanhAlpha13);
  // The monomial coefficients of the denominator polynomial (even).
  const float32x4_t beta_0 = vdupq_n_f32(kTanhBeta0);
  const float32x4_t beta_2 = vdupq_n_f32(kTanhBeta2);
  const float32x4_t beta_4 = vdupq_n_f32(kTanhBeta4);
  const float32x4_t beta_6 = vdupq_n_f32(kTanhBeta6);
  // Since the polynomials are odd/even, we need x^2.
  const float32x4_t x2 = vmulq_f32(x, x);
  // Evaluate the numerator polynomial p.
  float32x4_t p = vmlaq_f32(alpha_11, x2, alpha_13);
  p = vmlaq_f32(alpha_9, x2, p);
  p = vmlaq_f32(alpha_7, x2, p);
  p = vmlaq_f32(alpha_5, x2, p);
  p = vmlaq_f32(alpha_3, x2, p);
  p = vmlaq_f32(alpha_1, x2, p);
  p = vmulq_f32(x, p);
  // Evaluate the denominator polynomial q.
  float32x4_t q = vmlaq_f32(beta_4, x2, beta_6);
  q = vmlaq_f32(beta_2, x2, q);
  q = vmlaq_f32(beta_0, x2, q);
  // Divide the numerator by the denominator.
  float32x4_t recp = vrecpeq_f32(q);
  recp = vmulq_f32(recp, vrecpsq_f32(recp, q));
  return vmulq_f32(p, recp);
#elif defined FAST_TRANSCENDENTALS
  // Uses vcvtnq_s32_f32 (via fast_exp_norange_check), not available on ARM v7
  // NEON.
  x = vminq_f32(x, vdupq_n_f32(kMaxTanhInput));
  x = vmaxq_f32(x, vdupq_n_f32(kMinTanhInput));
  float32x4_t exp_est = detail::fast_exp_norange_check(x);
  float32x4_t neg_exp_est = detail::fast_exp_norange_check(-x);
  // If we're in the linear region.
  // caleq = compare absolute <=
  uint32x4_t cmp_results = vcaleq_f32(x, vdupq_n_f32(kTanhLinearRegion));
  float32x4_t diff = vsubq_f32(exp_est, neg_exp_est);
  float32x4_t sum = vaddq_f32(exp_est, neg_exp_est);
  float32x4_t recp = vrecpeq_f32(sum);
  recp = vmulq_f32(recp, vrecpsq_f32(recp, sum));
  float32x4_t tanh_estimate = vmulq_f32(diff, recp);
  // Based on the comparison, possibly copy x through instead of the calculated
  // value.
  // TODO(b/191497441): Is the compiler generating VBIT or VBSL? VBIT is one
  // cycle and VBSL is two... documentation suggests it can do either.
  return vbslq_f32(cmp_results, x, tanh_estimate);
#else
  // Fallback: compute tanhf one lane at a time.
  float32x4_t return_val = vdupq_n_f32(0.f);
  float tanh_value = tanhf(vgetq_lane_f32(x, 0));
  return_val = vld1q_lane_f32(&tanh_value, return_val, 0);
  tanh_value = tanhf(vgetq_lane_f32(x, 1));
  return_val = vld1q_lane_f32(&tanh_value, return_val, 1);
  tanh_value = tanhf(vgetq_lane_f32(x, 2));
  return_val = vld1q_lane_f32(&tanh_value, return_val, 2);
  tanh_value = tanhf(vgetq_lane_f32(x, 3));
  return_val = vld1q_lane_f32(&tanh_value, return_val, 3);
  return return_val;
#endif
}
// Input x is clamped to [-18, 18], even infinity and NaN.
// See tests for error bounds. Using SIGMOID_AS_TANH with
// ACCURATE_TRANSCENDENTAL_APPROX is both faster and more accurate. Using
// SIGMOID_AS_TANH with just FAST is slower, but more accurate.
// SIGMOID_AS_TANH, ACCURATE is 205 Mega/sec.
// SIGMOID_AS_TANH, FAST is 290 Mega/sec.
// FAST is 340 Mega/sec.
inline float32x4_t fast_sigmoid(float32x4_t x) {
#if defined SIGMOID_AS_TANH
  float32x4_t half = vdupq_n_f32(0.5f);
  return vmlaq_f32(half, half, fast_tanh(vmulq_f32(half, x)));
#elif defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX
  x = vminq_f32(x, vdupq_n_f32(kMaxSigmoidInput));
  x = vmaxq_f32(x, vdupq_n_f32(kMinSigmoidInput));
  // The monomial coefficients of the numerator polynomial (odd).
  const float32x4_t alpha_1 = vdupq_n_f32(kSigmoidAlpha1);
  const float32x4_t alpha_3 = vdupq_n_f32(kSigmoidAlpha3);
  const float32x4_t alpha_5 = vdupq_n_f32(kSigmoidAlpha5);
  const float32x4_t alpha_7 = vdupq_n_f32(kSigmoidAlpha7);
  const float32x4_t alpha_9 = vdupq_n_f32(kSigmoidAlpha9);
  // The monomial coefficients of the denominator polynomial (even).
  const float32x4_t beta_0 = vdupq_n_f32(kSigmoidBeta0);
  const float32x4_t beta_2 = vdupq_n_f32(kSigmoidBeta2);
  const float32x4_t beta_4 = vdupq_n_f32(kSigmoidBeta4);
  const float32x4_t beta_6 = vdupq_n_f32(kSigmoidBeta6);
  const float32x4_t beta_8 = vdupq_n_f32(kSigmoidBeta8);
  const float32x4_t beta_10 = vdupq_n_f32(kSigmoidBeta10);
  // Since the polynomials are odd/even, we need x^2.
  const float32x4_t x2 = vmulq_f32(x, x);
  // Evaluate the numerator polynomial p.
  float32x4_t p = vmlaq_f32(alpha_7, x2, alpha_9);
  p = vmlaq_f32(alpha_5, x2, p);
  p = vmlaq_f32(alpha_3, x2, p);
  p = vmlaq_f32(alpha_1, x2, p);
  p = vmulq_f32(x, p);
  // Evaluate the denominator polynomial q.
  float32x4_t q = vmlaq_f32(beta_8, x2, beta_10);
  q = vmlaq_f32(beta_6, x2, q);
  q = vmlaq_f32(beta_4, x2, q);
  q = vmlaq_f32(beta_2, x2, q);
  q = vmlaq_f32(beta_0, x2, q);
  // Divide the numerator by the denominator.
  float32x4_t recp = vrecpeq_f32(q);
  recp = vmulq_f32(recp, vrecpsq_f32(recp, q));
  return vmlaq_f32(vdupq_n_f32(0.5f), p, recp);
#elif defined FAST_TRANSCENDENTALS
  float32x4_t denom = vaddq_f32(fast_exp(vnegq_f32(x)), vdupq_n_f32(1.f));
  float32x4_t recp = vrecpeq_f32(denom);
  // Newton-Raphson iteration, accuracy is important for audio quality.
  recp = vmulq_f32(recp, vrecpsq_f32(recp, denom));
  float32x4_t half = vdupq_n_f32(0.5f);
  float32x4_t quarter = vdupq_n_f32(0.245f);
  float32x4_t linear_approx = vmlaq_f32(half, quarter, x);
  uint32x4_t cmp_results = vcaleq_f32(x, vdupq_n_f32(kSigmoidLinearRegion));
  return vbslq_f32(cmp_results, linear_approx, recp);
#else
  // Fallback: compute the sigmoid one lane at a time with expf.
  float32x4_t return_val = vdupq_n_f32(0.f);
  float result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 0)));
  return_val = vld1q_lane_f32(&result, return_val, 0);
  result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 1)));
  return_val = vld1q_lane_f32(&result, return_val, 1);
  result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 2)));
  return_val = vld1q_lane_f32(&result, return_val, 2);
  result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 3)));
  return_val = vld1q_lane_f32(&result, return_val, 3);
  return return_val;
#endif
}
// Scalar implementations, mainly useful for testing.
inline float fast_exp(float x) {
  return vgetq_lane_f32(fast_exp(vdupq_n_f32(x)), 0);
}

template <int ExponentBits>
inline float fast_exp(fixed32<ExponentBits> x) {
  return vgetq_lane_f32(fast_exp<ExponentBits>(vdupq_n_s32(x.raw_val())), 0);
}

// Returns the exponential of a fixed point number in floating point without
// ever doing any conversions. Less accurate than the version that does
// conversions, but still accurate to within 4% relative for x < 16.
template <int ExponentBits>
inline float fast_exp_fixed(fixed32<ExponentBits> x) {
  return vgetq_lane_f32(fast_exp_fixed<ExponentBits>(vdupq_n_s32(x.raw_val())),
                        0);
}

inline float fast_sigmoid(float x) {
  return vgetq_lane_f32(fast_sigmoid(vdupq_n_f32(x)), 0);
}

inline float fast_tanh(float x) {
  return vgetq_lane_f32(fast_tanh(vdupq_n_f32(x)), 0);
}
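// Example usage (illustrative only): these wrappers make it easy to spot-check
// the vector approximations against libm in a unit test, e.g. comparing
//   csrblocksparse::fast_tanh(0.5f)    with tanhf(0.5f), or
//   csrblocksparse::fast_sigmoid(-2.f) with 1.f / (1.f + expf(2.f)).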
// Clips integer input to [-|kLimit|, |kLimit|].
// Input: register containing 4x fixed32 with |mantissa_bits|.
// Output: register containing 4x fixed32 limited to
// [-|kLimit| << |mantissa_bits|, |kLimit| << |mantissa_bits|].
template <int kLimit>
inline int32x4_t ClipToBounds(const int mantissa_bits, const int32x4_t x) {
  // Clip to the input bounds for this approximation.
  int32x4_t clip_limit = vdupq_n_s32(-(kLimit << mantissa_bits));
  int32x4_t clipped_x = vmaxq_s32(x, clip_limit);
  clip_limit = vnegq_s32(clip_limit);
  return vminq_s32(clipped_x, clip_limit);
}
// Fixed32 sigmoid approximation via a quadratic refinement of the exponent
// trick.
// Input: Register containing 4x fixed32 with |mantissa_bits|.
// Output: Register containing 4x float results.
inline float32x4_t fixed32_sigmoid_float(const int mantissa_bits,
                                         const int32x4_t x) {
  int32x4_t input = vnegq_s32(x);
  float32x4_t y =
      vcvtq_f32_s32(ClipToBounds<kMaxSigmoidInputInt>(mantissa_bits, input));
  y = fixed32_exp_float_preclipped(mantissa_bits, y);
  float32x4_t one = vdupq_n_f32(1.0f);
  // Compute 1 / (1 + exp(-x)) using an estimated reciprocal refined with a
  // Newton-Raphson iteration; accuracy is important for audio quality.
  float32x4_t denom = vaddq_f32(y, one);
  float32x4_t recp = vrecpeq_f32(denom);
  recp = vmulq_f32(recp, vrecpsq_f32(recp, denom));
  return recp;
}
template <int ExponentBits>
inline float32x4_t fast_sigmoid(int32x4_t x) {
#if defined FASTER_TRANSCENDENTALS
  // Computation will fail to produce the right result if the input mantissa
  // bits exceed the number in a float.
  static_assert(kFloatMantissaBits >= fixed32<ExponentBits>::kMantissaBits,
                "Mantissa bits must be at most 23!");
  return fixed32_sigmoid_float(fixed32<ExponentBits>::kMantissaBits, x);
#else
  return fast_sigmoid(vcvtq_n_f32_s32(x, fixed32<ExponentBits>::kMantissaBits));
#endif
}

template <int ExponentBits>
inline float fast_sigmoid(fixed32<ExponentBits> x) {
  return vgetq_lane_f32(fast_sigmoid<ExponentBits>(vdupq_n_s32(x.raw_val())),
                        0);
}
#else  // defined __ARM_NEON || defined __aarch64__

// Scalar fallbacks for platforms without NEON.
inline float fast_exp(float x) {
#if defined FAST_TRANSCENDENTALS
  if (isnan(x)) return 0.0f;
  x = std::max(std::min(x, kMaxExpInput), kMinExpInput);
  float AConstant, BConstant;
  memcpy(&AConstant, &kAConstant, sizeof(int));
  memcpy(&BConstant, &kBConstant, sizeof(int));
  float y = x * AConstant + BConstant;
  int x_int = static_cast<int>(y);
  float ret;
  memcpy(&ret, &x_int, sizeof(float));
  return ret;
#else
  return expf(x);
#endif
}
template <int ExponentBits>
inline float fast_exp(fixed32<ExponentBits> x) {
  return fast_exp(static_cast<float>(x));
}

template <int ExponentBits>
inline float fast_exp_fixed(fixed32<ExponentBits> x) {
  static_assert(ExponentBits > 8, "Must have more than 8 ExponentBits");
  int matched_decimal =
      std::max(std::min(x.raw_val(), (80 << (31 - ExponentBits))),
               -(80 << (31 - ExponentBits)));
  // Convert 1 / log(2) to 16-bit fixed point with 1 exponent bit,
  // (1 / log(2)) * (1 << 14), but then right shift by the appropriate amount
  // to line the decimal point up with the 32-bit float representation:
  //   (MantissaBits of x) + (MantissaBits of constant) = 23
  //   23 - (MantissaBits of x) = MantissaBits of constant
  //   23 - (31 - ExponentBits of x) = ...
  //   (ExponentBits of x - 8) = MantissaBits of constant
  const int16_t A = (1.f / logf(2.f)) * (1 << (ExponentBits - 8));
  // Same rationale as for the floating point versions: bias the exponent and
  // subtract 366000 to reduce error by centering the approximation, instead of
  // it being one-sided.
  const int B = (127 << 23) - 366000;
  matched_decimal = A * matched_decimal + B;
  float ret_val;
  memcpy(&ret_val, &matched_decimal, sizeof(float));
  return ret_val;
}
inline float fast_tanh(float x) {
#if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX
  // Doesn't do anything fancy, just a 13/6-degree rational interpolant which
  // is accurate up to a couple of ulp in the range [-9, 9], outside of which
  // fl(tanh(x)) = +/-1.
  x = std::max(std::min(x, kMaxTanhInput), kMinTanhInput);
  // Since the polynomials are odd/even, we need x^2.
  float x2 = x * x;
  // Evaluate the numerator.
  float p = kTanhAlpha11 + x2 * kTanhAlpha13;
  p = kTanhAlpha9 + x2 * p;
  p = kTanhAlpha7 + x2 * p;
  p = kTanhAlpha5 + x2 * p;
  p = kTanhAlpha3 + x2 * p;
  p = kTanhAlpha1 + x2 * p;
  p = x * p;
  // Evaluate the denominator.
  float q = kTanhBeta4 + x2 * kTanhBeta6;
  q = kTanhBeta2 + x2 * q;
  q = kTanhBeta0 + x2 * q;
  return p / q;
#elif defined FAST_TRANSCENDENTALS
  if (std::abs(x) < kTanhLinearRegion) {
    return x;
  } else {
    x = std::max(std::min(x, kMaxTanhInput), kMinTanhInput);
    float positive = fast_exp(x);
    float negative = fast_exp(-x);
    return (positive - negative) / (positive + negative);
  }
#else
  return tanhf(x);
#endif
}
inline float fast_sigmoid(float x) {
#if defined SIGMOID_AS_TANH
  return .5f * fast_tanh(.5f * x) + .5f;
#elif defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX
  // Doesn't do anything fancy, just a 9/10-degree rational interpolant which
  // interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulp in the range
  // [-18, 18], outside of which fl(sigmoid(x)) = {0|1}. The shifted
  // sigmoid is interpolated because it was easier to make the fit converge.
  // See GenericPacketMath.h* in the open source Eigen library.
  x = std::max(std::min(x, kMaxSigmoidInput), kMinSigmoidInput);
  // Since the polynomials are odd/even, we need x^2.
  float x2 = x * x;
  // Evaluate the numerator.
  float p = kSigmoidAlpha7 + x2 * kSigmoidAlpha9;
  p = kSigmoidAlpha5 + x2 * p;
  p = kSigmoidAlpha3 + x2 * p;
  p = kSigmoidAlpha1 + x2 * p;
  p = x * p;
  // Evaluate the denominator.
  float q = kSigmoidBeta8 + x2 * kSigmoidBeta10;
  q = kSigmoidBeta6 + x2 * q;
  q = kSigmoidBeta4 + x2 * q;
  q = kSigmoidBeta2 + x2 * q;
  q = kSigmoidBeta0 + x2 * q;
  return p / q + 0.5f;
#elif defined FAST_TRANSCENDENTALS
  if (std::abs(x) < kSigmoidLinearRegion) {
    return .245 * x + .5;
  } else {
    return 1.f / (1.f + fast_exp(-x));
  }
#else
  return 1.f / (1.f + expf(-x));
#endif
}
template <int ExponentBits>
inline float fast_sigmoid(fixed32<ExponentBits> x) {
  return fast_sigmoid(static_cast<float>(x));
}

#endif  // defined __ARM_NEON || defined __aarch64__

// Number of exponent bits to use for tanh.
static constexpr int kNumTanhExpBits = 3;
// Number of exponent bits to use for sigmoid.
static constexpr int kNumSigmoidExpBits = 4;
// Number of extra bits to shift sigmoid, due to its low gradient.
static constexpr int kNumExtraSigmoidShiftBits = 1;
// Returns (and builds if not done yet) a static data table (that is never
// deleted, as per the style guide) that implements tanh on fixed32 input,
// returning another fixed32 with the given number of mantissa bits (which is
// assumed to be less than the input mantissa bits).
// NOTE that this function is intended to be used only with fixed16 outputs
// that are sign-extended to 32 bits for convenience, and will return a nullptr
// if asked for more than |kMaxMantissaBits| of precision in the output table.
const int* TanhTable(int num_mantissa_bits_out);
// As TanhTable, but for Sigmoid.
const int* SigmoidTable(int num_mantissa_bits_out);
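// (For the intended indexing of these tables, see fixed32_tanh_fixed16() and
// fixed32_sigmoid_fixed16() in the AVX2 section below.)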
// Scalar/generic function to compute and return the fast approximation to exp
// via a polynomial refinement of the floating point exponent trick.
// TM_ORDER4_16BIT: Max relative error < 5e-6, absolute error < 1e-5 for x < 1.
// TM_ORDER3_16BIT: Max relative error < 1.1e-4, absolute error < 3e-4 for
//                  x < 1.
template <int kExponentBits, TranscendentalMode kOrder = TM_ORDER4_16BIT>
float fixed32_exp(fixed32<kExponentBits> x) {
  constexpr int kMantissaBits = MantissaBitsOf<fixed32<kExponentBits>>::value;
  // Clip x to min/max exp input to avoid infinities.
  int64_t clipped_x =
      std::max(std::min(x.raw_val(), kMaxExpInputInt << kMantissaBits),
               -(kMaxExpInputInt << kMantissaBits));
  // First convert problem from e^x to 2^x by multiplying by 1/log(2).
  // To maximize precision, log_factor is shifted left the maximum amount to
  // keep within int32, and we shift x left a further amount such that the
  // binary point of the product sits in the correct place in the top 32 bits
  // of the result to be used directly as a float. We can't do that directly,
  // as x would overflow, so we have to shift by 1 bit less and shift the
  // result by 1 bit less to match.
  constexpr int kXShift =
      kFloatMantissaBits + 31 - kMaxLog2Shift - kMantissaBits;
  static_assert(kXShift >= 0,
                "Mantissa bits > kFloatMantissaBits + 31 - kMaxLog2Shift");
  clipped_x <<= kXShift;
  int float_as_int = (kLogFactor * clipped_x >> 31) + kFloatExponentOffset;
  // Separate the resulting fixed-point into integer and fractional parts.
  int int_part = float_as_int & kFloatExponentMask;
  int float_part = float_as_int & kFloatMantissaMask;
  float fraction = static_cast<float>(float_part) / (1 << kFloatMantissaBits);
  // Compute the mantissa = 2^fraction using:
  //   fraction - fraction * (1 - fraction) * (polynomial of fraction)
  // This guarantees exactness at 0 and 1, providing continuity of the error at
  // integer boundaries.
  float mantissa;
  if (kOrder == TM_ORDER4_16BIT || kOrder == TM_ORDER4_FLOAT) {
    mantissa = (kExpQuarticFactor2 * fraction + kExpQuarticFactor1) * fraction +
               kExpQuarticFactor0;
  } else if (kOrder == TM_ORDER3_16BIT) {
    mantissa = kExpCubicFactor1 * fraction + kExpCubicFactor0;
  }
  mantissa = fraction - fraction * (1.0f - fraction) * mantissa;
  // Since the function above is guaranteed to stay within [0, 1), we could do
  // all of the above in fixed point if necessary, in which case we can just
  // stuff the bottom kFloatMantissaBits in with the exponent and we are done.
  // In the floating point world, it is simpler to just multiply them together.
  float result;
  memcpy(&result, &int_part, sizeof(float));
  return result * (1.0f + mantissa);
}
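// Example usage (illustrative; it assumes the raw-value construction that
// fixed32_tanh below also relies on): with kExponentBits = 11 a fixed32 has
// 20 mantissa bits, so
//   auto one = static_cast<csrblocksparse::fixed32<11>>(1 << 20);  // == 1.0
//   float e = csrblocksparse::fixed32_exp(one);                    // ~2.718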
// Computes and returns tanh(x) fixed32->float using a polynomial refinement of
// the floating point exponent trick.
// kOrder=4: Absolute error < 1.8e-6. Relative error < 1.2e-4 for |x| > 0.01.
// kOrder=3: Absolute error < 6e-5. Relative error < 3e-3 for |x| > 0.01.
template <int kExponentBits, TranscendentalMode kOrder = TM_ORDER4_16BIT>
float fixed32_tanh(fixed32<kExponentBits> x) {
  float float_x = static_cast<float>(x);
  if (std::abs(float_x) < 1.0f / 9.0f) {
    return float_x * (1 - float_x * float_x / 3.0f);
  }
  x = static_cast<fixed32<kExponentBits>>(x.raw_val() * 2);
  float exp_2x = fixed32_exp<kExponentBits, kOrder>(x);
  return (exp_2x - 1.0f) / (exp_2x + 1.0f);
}

// Computes and returns sigmoid(x) fixed32->float using a polynomial refinement
// of the floating point exponent trick.
// TM_ORDER4_16BIT: Absolute error < 9e-7, relative < 4e-6.
// TM_ORDER3_16BIT: Absolute error < 3e-5, relative < 1.1e-4.
template <int kExponentBits, TranscendentalMode kOrder = TM_ORDER4_16BIT>
float fixed32_sigmoid(fixed32<kExponentBits> x) {
  x = static_cast<fixed32<kExponentBits>>(-x.raw_val());
  float exp_x = fixed32_exp<kExponentBits, kOrder>(x);
  return 1.0f / (exp_x + 1.0f);
}
#if defined __AVX2__

// Inline function to access an int32 data table by shifting |x| right by
// |kNumShiftBits|, and adding |kTableOffset| to the result. |x| contains 8
// indices and 8 results are returned. The data table is of size
// |kTableOffset| * 2 + 1.
template <int kNumShiftBits, int kTableOffset>
inline __m256i index_data_table(const int32_t* data_table, const __m256i& x) {
  // Shift right with rounding to match input and output precision.
  __m256i shifted = _mm256_set1_epi32(1 << (kNumShiftBits - 1));
  shifted = _mm256_add_epi32(x, shifted);
  shifted = _mm256_srai_epi32(shifted, kNumShiftBits);
  // Add the offset.
  __m256i addend = _mm256_set1_epi32(kTableOffset);
  shifted = _mm256_add_epi32(shifted, addend);
  // And clamp to the indices of the LUT.
  addend = _mm256_add_epi32(addend, addend);
  shifted = _mm256_min_epi32(shifted, addend);
  shifted = _mm256_max_epi32(shifted, _mm256_setzero_si256());
  // Lookup the results in the table.
  return _mm256_i32gather_epi32(data_table, shifted, 4);
}
// Fixed32 to fixed16-in-an-int32 tanh LUT function.
// Input: register containing 8x fixed32 with |NumInputMantissaBits|.
// Output: a register containing 8x fixed16 with |NumOutputMantissaBits|, but
// note that they are sign-extended to 32 bits and are therefore basically the
// same as fixed32 with |NumOutputMantissaBits|.
template <int NumInputMantissaBits, int NumOutputMantissaBits>
inline __m256i fixed32_tanh_fixed16(const int* tanh_table, const __m256i& x) {
  // Lose the unnecessary input precision.
  constexpr int kNumShiftBits = NumInputMantissaBits - NumOutputMantissaBits;
  constexpr int kTableOffset = 1 << (NumOutputMantissaBits + kNumTanhExpBits);
  return index_data_table<kNumShiftBits, kTableOffset>(tanh_table, x);
}
// Fixed32 to fixed16-in-an-int32 sigmoid LUT function.
// Input: register containing 8x fixed32 with |NumInputMantissaBits|.
// Output: a register containing 8x fixed16 with |NumOutputMantissaBits|, but
// note that they are sign-extended to 32 bits and are therefore basically the
// same as fixed32 with |NumOutputMantissaBits|.
template <int NumInputMantissaBits, int NumOutputMantissaBits>
inline __m256i fixed32_sigmoid_fixed16(const int* sigmoid_table,
                                       const __m256i& x) {
  // Lose the unnecessary input precision.
  constexpr int kNumShiftBits =
      kNumExtraSigmoidShiftBits + NumInputMantissaBits - NumOutputMantissaBits;
  constexpr int kTableOffset =
      1 << (NumOutputMantissaBits + kNumSigmoidExpBits -
            kNumExtraSigmoidShiftBits);
  return index_data_table<kNumShiftBits, kTableOffset>(sigmoid_table, x);
}
// Convert 2x registers of 8x float32 into 1 register of 16x 16-bit fixed int,
// assuming that the floats are already scaled up.
inline __m256i PackFloatsToFixed16(const __m256& x0, const __m256& x1) {
  __m256i int0 = _mm256_cvtps_epi32(x0);
  __m256i int1 = _mm256_cvtps_epi32(x1);
  int0 = _mm256_packs_epi32(int0, int1);
  // Swap the middle 64 bit elements so the results are in the right order.
  return _mm256_permute4x64_epi64(int0, 0xd8);
}
// Clips integer input to [-|kLimit|, |kLimit|].
// Input: register containing 8x fixed32 with |mantissa_bits|.
// Output: register containing 8x fixed32 limited to
// [-|kLimit| << |mantissa_bits|, |kLimit| << |mantissa_bits|].
template <int kLimit>
inline __m256i ClipToBounds(const int mantissa_bits, const __m256i& x) {
  // Clip to the input bounds for this approximation.
  __m256i clip_limit = _mm256_set1_epi32(-(kLimit << mantissa_bits));
  __m256i clipped_x = _mm256_max_epi32(x, clip_limit);
  // This quickly negates the limit without having to load another constant.
  clip_limit = _mm256_sign_epi32(clip_limit, clip_limit);
  return _mm256_min_epi32(clipped_x, clip_limit);
}
// Clips float input to [-|kLimit|, |kLimit|].
// Input: register containing 8x float.
// Output: register containing 8x float limited to [-|kLimit|, |kLimit|].
inline __m256 ClipToFloatBounds(const float kLimit, const __m256& x) {
  __m256 clip_limit = _mm256_set1_ps(kLimit);
  __m256 clipped_x = _mm256_min_ps(x, clip_limit);
  clip_limit = _mm256_set1_ps(-kLimit);
  return _mm256_max_ps(clipped_x, clip_limit);
}
// Float to float power of 2 approximation, using a quartic refinement of
// the exponent trick. For TM_ORDER4_16BIT and TM_ORDER3_16BIT, the
// implementation is entirely in integer, using 16x16=16 multiplication, using
// AVX2, which enables 16 elements to be computed in parallel, hence the double
// register input/output args.
// The price paid for this speed is an increase in error over the (scalar)
// int32 example implementations above by a variable factor of 4-10.
// For the TM_ORDER4_FLOAT case, the computation is all done in float, solving
// this lower precision problem.
// NOTE: The input must have already been clipped to prevent overflow, which
// sets the practical limit to +/-126 << kFloatMantissaBits.
// NOTE: The input is a scaled float, as if converted raw from int, and the
// scale factor is fixed at kFloatMantissaBits!
// Input: 2x registers containing 8x float * (1 << kFloatMantissaBits).
// Output: 2x registers containing 8x float.
// TM_ORDER4_FLOAT: Max relative error < 8e-6, absolute error < 9e-6 for x < 1.
// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 6e-5 for x < 1.
// TM_ORDER3_16BIT: Max relative error < 6e-4, absolute error < 2e-3 for x < 1.
template <TranscendentalMode kOrder = TM_ORDER4_16BIT>
inline void float32_pow2(__m256& x0, __m256& x1) {
  // Convert straight to int.
  __m256i exp_int_x0 = _mm256_cvtps_epi32(x0);
  __m256i exp_int_x1 = _mm256_cvtps_epi32(x1);
  __m256i result_x0, result_x1;
  static_assert(kOrder == TM_ORDER4_FLOAT || kOrder == TM_ORDER4_16BIT ||
                    kOrder == TM_ORDER3_16BIT,
                "Invalid order.");
  if (kOrder == TM_ORDER4_FLOAT) {
    __m256i mantissa_mask = _mm256_set1_epi32(0x7fffff);
    __m256 float_factor =
        _mm256_set1_ps(1.0f / static_cast<float>(1 << kFloatMantissaBits));
    __m256i fract0 = _mm256_and_si256(mantissa_mask, exp_int_x0);
    __m256i fract1 = _mm256_and_si256(mantissa_mask, exp_int_x1);
    __m256 float0 = _mm256_mul_ps(_mm256_cvtepi32_ps(fract0), float_factor);
    __m256 float1 = _mm256_mul_ps(_mm256_cvtepi32_ps(fract1), float_factor);
    // Compute the polynomial of the fractional part.
    // Ordering these lines carefully makes it faster, as some of the multiply
    // operations can pipeline instead of waiting for the previous result.
    __m256 x_squared0 = _mm256_mul_ps(float0, float0);
    __m256 x_squared1 = _mm256_mul_ps(float1, float1);
    __m256 b = _mm256_set1_ps(kExpQuarticFactor1);
    __m256 b_x0 = _mm256_mul_ps(b, float0);
    __m256 b_x1 = _mm256_mul_ps(b, float1);
    __m256 a = _mm256_set1_ps(kExpQuarticFactor2);
    __m256 a_x_squared0 = _mm256_mul_ps(a, x_squared0);
    __m256 a_x_squared1 = _mm256_mul_ps(a, x_squared1);
    __m256 x_squared_minus_x0 = _mm256_sub_ps(x_squared0, float0);
    __m256 x_squared_minus_x1 = _mm256_sub_ps(x_squared1, float1);
    __m256 c = _mm256_set1_ps(kExpQuarticFactor0);
    b_x0 = _mm256_add_ps(b_x0, c);
    b_x1 = _mm256_add_ps(b_x1, c);
    float_factor = _mm256_set1_ps(static_cast<float>(1 << kFloatMantissaBits));
    a_x_squared0 = _mm256_add_ps(a_x_squared0, b_x0);
    a_x_squared1 = _mm256_add_ps(a_x_squared1, b_x1);
    a_x_squared0 = _mm256_mul_ps(a_x_squared0, x_squared_minus_x0);
    a_x_squared1 = _mm256_mul_ps(a_x_squared1, x_squared_minus_x1);
    result_x0 = _mm256_cvtps_epi32(_mm256_mul_ps(a_x_squared0, float_factor));
    result_x1 = _mm256_cvtps_epi32(_mm256_mul_ps(a_x_squared1, float_factor));
  } else {
    // Combine the fractional part of both inputs into a single register.
    // The representation is fixed16<0>, ie 15 mantissa bits.
    __m256i mantissa_mask = _mm256_set1_epi32(0x7fff00);
    __m256i x_01 =
        _mm256_srli_epi32(_mm256_and_si256(mantissa_mask, exp_int_x0), 8);
    x_01 = _mm256_or_si256(
        x_01,
        _mm256_slli_epi32(_mm256_and_si256(mantissa_mask, exp_int_x1), 8));
    // Compute the polynomial of the fractional part.
    // Ordering these lines carefully makes it faster, as some of the multiply
    // operations can pipeline instead of waiting for the previous result.
    __m256i x_squared = _mm256_mulhrs_epi16(x_01, x_01);
    __m256i result, x_squared_minus_x;
    if (kOrder == TM_ORDER4_16BIT) {
      __m256i b = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor1));
      __m256i b_x = _mm256_mulhrs_epi16(b, x_01);
      __m256i a = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor2));
      __m256i a_x_squared = _mm256_mulhrs_epi16(a, x_squared);
      x_squared_minus_x = _mm256_sub_epi16(x_squared, x_01);
      // LOG(INFO) << "x_squared_minus_x="
      //           << static_cast<int16>(
      //                  _mm256_extract_epi16(x_squared_minus_x, 0)) /
      //                  32768.0f;
      __m256i c = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor0));
      b_x = _mm256_add_epi16(b_x, c);
      // LOG(INFO) << "bx+c="
      //           << static_cast<int16>(_mm256_extract_epi16(b_x, 0)) /
      //                  32768.0f;
      result = _mm256_add_epi16(a_x_squared, b_x);
    } else {  // kOrder == TM_ORDER3_16BIT
      __m256i a = _mm256_set1_epi16(FloatAsInt16(kExpCubicFactor1));
      __m256i b = _mm256_set1_epi16(FloatAsInt16(kExpCubicFactor0));
      __m256i a_x = _mm256_mulhrs_epi16(a, x_01);
      x_squared_minus_x = _mm256_sub_epi16(x_squared, x_01);
      result = _mm256_add_epi16(a_x, b);
    }
    result = _mm256_mulhrs_epi16(result, x_squared_minus_x);
    // Extract the 16x 16-bit results back to the separate sets of 8x32.
    result_x0 = _mm256_slli_epi32(result, 16);
    result_x0 = _mm256_srai_epi32(result_x0, 8);
    result_x1 = _mm256_srai_epi32(result, 16);
    result_x1 = _mm256_slli_epi32(result_x1, 8);
  }
  // Add the constant to normalize the exponent.
  __m256i exp_offset = _mm256_set1_epi32(kFloatExponentOffset);
  exp_int_x0 = _mm256_add_epi32(exp_int_x0, exp_offset);
  exp_int_x0 = _mm256_add_epi32(exp_int_x0, result_x0);
  exp_int_x1 = _mm256_add_epi32(exp_int_x1, exp_offset);
  exp_int_x1 = _mm256_add_epi32(exp_int_x1, result_x1);
  // Cast back to float, as we just computed the exponent and mantissa and
  // assembled them in IEEE format.
  x0 = _mm256_castsi256_ps(exp_int_x0);
  x1 = _mm256_castsi256_ps(exp_int_x1);
}
// Fixed32 to float exp approximation, using a quartic/cubic refinement of
// the exponent trick. The implementation is entirely in integer, using
// 16x16=16 multiplication, using AVX2, which enables 16 elements to be
// computed in parallel, hence the double register input/output args.
// The price paid for this speed is an increase in error over the (scalar)
// int32 example implementations above by a variable factor of 4-10.
// The TM_ORDER4_FLOAT version uses floats and improves the precision.
// Input: 2x registers containing 8x fixed32 with kMantissaBits.
// Output: 2x registers containing 8x float32.
// TM_ORDER4_FLOAT: Max relative error < 8e-6, absolute error < 9e-6 for x < 1.
// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 6e-5 for x < 1.
// TM_ORDER3_16BIT: Max relative error < 6e-4, absolute error < 2e-3 for x < 1.
template <int kInputMantissaBits, TranscendentalMode kOrder = TM_ORDER4_16BIT>
inline void float_exp_float_preclipped(__m256& y0, __m256& y1) {
  // Divide by log 2 to convert problem to 2^x, and scale to match the
  // mantissa bits required by IEEE floats. Without a _mm256_mulhrs_epi32, it
  // is much easier to do this in float, even with the double conversion, as
  // 16 bit is not precise enough here.
  // This is the shift of the FP mantissa relative to the input mantissa.
  constexpr int kXShift = kFloatMantissaBits - kInputMantissaBits;
  constexpr float kLogFactor = static_cast<float>(1 << kXShift);
  __m256 factor = _mm256_set1_ps(kLogFactor * kOneOverLog2);
  y0 = _mm256_mul_ps(y0, factor);
  y1 = _mm256_mul_ps(y1, factor);
  // Now compute 2^x.
  float32_pow2<kOrder>(y0, y1);
}

template <int kInputMantissaBits, TranscendentalMode kOrder = TM_ORDER4_16BIT>
inline void fixed32_exp_float(const __m256i& x0, const __m256i& x1, __m256& y0,
                              __m256& y1) {
  // Clip to acceptable bounds to prevent overflow, and convert to float.
  y0 = _mm256_cvtepi32_ps(
      ClipToBounds<kMaxExpInputInt>(kInputMantissaBits, x0));
  y1 = _mm256_cvtepi32_ps(
      ClipToBounds<kMaxExpInputInt>(kInputMantissaBits, x1));
  float_exp_float_preclipped<kInputMantissaBits, kOrder>(y0, y1);
}
// Float->float tanh approximation via the exponent trick.
// Note that the input is scaled floats, as if converted raw from fixed16/32.
// Input: 2x registers containing 8x float scaled by input_mantissa_bits.
// Output: two registers containing 8x float.
// TM_ORDER4_FLOAT: Max relative error < 2.1e-5, absolute error < 2.3e-6.
// TM_ORDER4_16BIT: Max relative error < 1e-4, absolute error < 1.3e-5.
// TM_ORDER3_16BIT: Max relative error < 2.1e-3, absolute error < 3e-4.
template <int kInputMantissaBits, TranscendentalMode kOrder = TM_ORDER4_FLOAT>
inline void float_tanh_float(const __m256& x0, const __m256& x1, __m256& y0,
                             __m256& y1) {
  // Divide by log 2 to convert problem to 2^x, double (as we need exp(2x)) and
  // scale to the mantissa bits required by float32_pow2 all in one multiply.
  // This is the shift of the FP mantissa relative to the input mantissa.
  // Add one to double the input.
  const float kLogFactor =
      static_cast<float>(1 << (kFloatMantissaBits - kInputMantissaBits + 1));
  __m256 factor = _mm256_set1_ps(kLogFactor * kOneOverLog2);
  // Clip to suitable input bounds for tanh.
  __m256 clip_limit = _mm256_set1_ps(kMaxTanhInput * (1 << kInputMantissaBits));
  __m256 clip0 = _mm256_min_ps(x0, clip_limit);
  __m256 clip1 = _mm256_min_ps(x1, clip_limit);
  clip_limit = _mm256_set1_ps(-kMaxTanhInput * (1 << kInputMantissaBits));
  clip0 = _mm256_max_ps(clip0, clip_limit);
  clip1 = _mm256_max_ps(clip1, clip_limit);
  __m256 exp0 = _mm256_mul_ps(clip0, factor);
  __m256 exp1 = _mm256_mul_ps(clip1, factor);
  // Now compute 2^x.
  float32_pow2<kOrder>(exp0, exp1);
  // Now compute tanh using (e^2x - 1) / (e^2x + 1).
  __m256 one = _mm256_set1_ps(1.0f);
  __m256 numerator = _mm256_sub_ps(exp0, one);
  __m256 denominator = _mm256_add_ps(exp0, one);
  // Approximate reciprocal is not accurate enough - use full division.
  exp0 = _mm256_div_ps(numerator, denominator);
  numerator = _mm256_sub_ps(exp1, one);
  denominator = _mm256_add_ps(exp1, one);
  exp1 = _mm256_div_ps(numerator, denominator);
  // Compute 3rd-order Taylor tanh ~ x - x^3/3 for high accuracy and thus low
  // relative error close to 0.
  // Normalize the inputs back to proper floats.
  factor = _mm256_set1_ps(1.0f / (1 << kInputMantissaBits));
  clip0 = _mm256_mul_ps(clip0, factor);
  clip1 = _mm256_mul_ps(clip1, factor);
  __m256 third = _mm256_set1_ps(-1.0f / 3.0f);
  __m256 taylor0 = _mm256_mul_ps(clip0, clip0);
  __m256 taylor1 = _mm256_mul_ps(clip1, clip1);
  taylor0 = _mm256_mul_ps(taylor0, clip0);
  taylor1 = _mm256_mul_ps(taylor1, clip1);
  // TODO(b/191497441): The next two pairs of instructions could be combined to
  // _mm256_fmadd_ps, but that requires the -mfma compilation option, eg:
  // taylor0 = _mm256_fmadd_ps(taylor0, third, clip0);
  taylor0 = _mm256_mul_ps(taylor0, third);
  taylor1 = _mm256_mul_ps(taylor1, third);
  taylor0 = _mm256_add_ps(clip0, taylor0);
  taylor1 = _mm256_add_ps(clip1, taylor1);
  // Test |x| <= 1/9, roughly where the errors cross over, without needing yet
  // another constant.
  third = _mm256_mul_ps(third, third);
  __m256 neg_zero = _mm256_set1_ps(-0.0f);
  clip0 = _mm256_andnot_ps(neg_zero, clip0);
  clip1 = _mm256_andnot_ps(neg_zero, clip1);
  __m256 cmp_results0 = _mm256_cmp_ps(clip0, third, _CMP_LE_OQ);
  __m256 cmp_results1 = _mm256_cmp_ps(clip1, third, _CMP_LE_OQ);
  y0 = _mm256_blendv_ps(exp0, taylor0, cmp_results0);
  y1 = _mm256_blendv_ps(exp1, taylor1, cmp_results1);
}
// Fixed32 sigmoid approximation via the AVX2 implementation of the exponent
// trick.
// Input: 2x registers containing 8x float holding converted fixed32 scaled
// with kInputMantissaBits.
// Output: 2x registers containing 8x float.
// TM_ORDER4_FLOAT: Max relative error < 4e-6, absolute error < 1e-6.
// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 7e-6.
// TM_ORDER3_16BIT: Max relative error < 5.4e-4, absolute error < 1.4e-4.
template <int kInputMantissaBits, TranscendentalMode kOrder = TM_ORDER4_FLOAT>
inline void float_sigmoid_float(__m256& y0, __m256& y1) {
  constexpr float kInputFactor = static_cast<float>(1 << kInputMantissaBits);
  // Negate the inputs.
  __m256 minus_zero = _mm256_set1_ps(-0.0f);
  y0 = _mm256_xor_ps(y0, minus_zero);
  y1 = _mm256_xor_ps(y1, minus_zero);
  y0 = ClipToFloatBounds(kMaxSigmoidInput * kInputFactor, y0);
  y1 = ClipToFloatBounds(kMaxSigmoidInput * kInputFactor, y1);
  float_exp_float_preclipped<kInputMantissaBits, kOrder>(y0, y1);
  __m256 one = _mm256_set1_ps(1.0f);
  // Approximate reciprocal is not accurate enough - use full division.
  y0 = _mm256_div_ps(one, _mm256_add_ps(y0, one));
  y1 = _mm256_div_ps(one, _mm256_add_ps(y1, one));
}
#endif  // defined __AVX2__

}  // namespace csrblocksparse