/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FAST_TRANSCENDENTALS_H_
#define LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FAST_TRANSCENDENTALS_H_

#include <math.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>

#if defined __ARM_NEON || defined __aarch64__
#include <arm_neon.h>
#endif
#if defined __AVX__ || defined __AVX2__
#include <immintrin.h>
#endif

#include "sparse_matmul/numerics/fixed_types.h"
#include "sparse_matmul/numerics/type_utils.h"

namespace csrblocksparse {

// The input to exp is clipped to bounds that prevent overflow/underflow in a
// 32 bit float representation. e^80 ~ 6e34, which is close to maxfloat.
constexpr float kMaxExpInput = 80.f;
constexpr int kMaxExpInputInt = static_cast<int>(kMaxExpInput);
constexpr float kMinExpInput = -80.f;
// tanh(9) ~ 0.99999997, which cannot be resolved from 1 in a float32.
constexpr float kMaxTanhInput = 9.f;
constexpr float kMinTanhInput = -9.f;
// sigmoid(18) ~ 0.999999985, which cannot be resolved from 1 in a float32.
constexpr float kMaxSigmoidInput = 18.f;
constexpr float kMinSigmoidInput = -18.f;
// kAConstant ~= 2^23 / ln 2
constexpr uint32_t kAConstant = 0x4b38aa3b;
// kBConstant ~= (127 << 23) - 366000
constexpr uint32_t kBConstant = 0x4e7de9a9;
// Coefficients of the rational approximation to tanh.
// Coefficients of the numerator polynomial (odd).
constexpr float kTanhAlpha1 = 4.89352455891786e-03;
constexpr float kTanhAlpha3 = 6.37261928875436e-04;
constexpr float kTanhAlpha5 = 1.48572235717979e-05;
constexpr float kTanhAlpha7 = 5.12229709037114e-08;
constexpr float kTanhAlpha9 = -8.60467152213735e-11;
constexpr float kTanhAlpha11 = 2.00018790482477e-13;
constexpr float kTanhAlpha13 = -2.76076847742355e-16;
// The monomial coefficients of the denominator polynomial (even).
constexpr float kTanhBeta0 = 4.89352518554385e-03;
constexpr float kTanhBeta2 = 2.26843463243900e-03;
constexpr float kTanhBeta4 = 1.18534705686654e-04;
constexpr float kTanhBeta6 = 1.19825839466702e-06;

// Coefficients of the rational approximation to sigmoid.
// Coefficients of the numerator polynomial (odd).
constexpr float kSigmoidAlpha1 = 2.48287947061529e-01;
constexpr float kSigmoidAlpha3 = 8.51377133304701e-03;
constexpr float kSigmoidAlpha5 = 6.08574864600143e-05;
constexpr float kSigmoidAlpha7 = 1.15627324459942e-07;
constexpr float kSigmoidAlpha9 = 4.37031012579801e-11;

// The monomial coefficients of the denominator polynomial (even).
constexpr float kSigmoidBeta0 = 9.93151921023180e-01;
constexpr float kSigmoidBeta2 = 1.16817656904453e-01;
constexpr float kSigmoidBeta4 = 1.70198817374094e-03;
constexpr float kSigmoidBeta6 = 6.29106785017040e-06;
constexpr float kSigmoidBeta8 = 5.76102136993427e-09;
constexpr float kSigmoidBeta10 = 6.10247389755681e-13;

// x is the first term of the Taylor series approximation of tanh near 0.
// Because the leading error term of tanh(x) - x is O(x^3), it is accurate over
// a wide interval, so we use it in the region where the other approximation is
// inaccurate. tanh(x) = x - x^3 / 3 + 2x^5 / 15 - 17x^7 / 315 + ...
// Similarly for sigmoid, where the first term is 0.25x.
constexpr float kTanhLinearRegion = .15f;
constexpr float kSigmoidLinearRegion = .75f;

// Maximum shift factor for 1/log 2 to keep it inside int32.
constexpr int kMaxLog2Shift = 30;
static const int kLogFactor = static_cast<int>((1 << kMaxLog2Shift) / log(2.f));
static const float kOneOverLog2 = 1.0f / log(2.f);
// Number of real mantissa bits in IEEE float32.
constexpr int kFloatMantissaBits = 23;
// Offset to correct the exponent value in the resulting float.
constexpr int kFloatExponentOffset = 127 << kFloatMantissaBits;
// Mask for mantissa.
constexpr int kFloatMantissaMask = (1 << kFloatMantissaBits) - 1;
// Mask for exponent.
constexpr int kFloatExponentMask = (-1) ^ kFloatMantissaMask;

// ========== COMMON DOCUMENTATION FOR THE FLOATING EXPONENT TRICK ============
// Summary: Use the exponent-mantissa representation of a floating point number
// to give exponentiation of 2 for free. If we desire f(z) = e^z = 2^(x+n), (for
// some fixed-point z expressed as an integer with imaginary binary point within
// it) then we have to compute x+n = z / ln 2 and then splitting x+n into
// n = int(x+n) and x = fract(x+n) in [0, 1), we can use n and 2^x as the
// exponent and mantissa of a floating point number, and that float is equal to
// e^z. For original reference see:
// http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.9.4508&rep=rep1&type=pdf
// Important detail:
// IEEE floats are stored normalized, ie 1.bbbbbbb... x 2^exponent. The leading
// 1 bit is not actually stored, (as it is always 1), providing an extra bit of
// precision.
// Since 2^0=1 and 2^1=2, we can treat the problem as 2^x = 1 + u and we thus
// need a mapping x in [0, 1) -> u in [0, 1) and the 1 + is provided by the
// representation.
// In the original paper cited above, the mapping is u = x - c, where c is set
// to minimize the average error. The function to compute exp(x) this way is
// incredibly simple and computationally cheap, but not very accurate.
// Fortunately, the problem has been reduced to u = 2^x - 1 over [0, 1) for
// which it is far easier to construct accurate approximations with small
// polynomials than a full range exp(x), and this is what the cubic and quartic
// versions below do. An important feature of these functions is that they
// constrain the solution to be exact at 0 and 1 so there is continuity at each
// integer boundary where we wrap from 1 to 0 and increment the power of 2.

// Coefficients for quartic representation of 2^x - 1 for x on [0,1).
// The quartic representation is 2^x - 1 ~ x - x(1-x)(ax^2 + bx + c), hence the
// coefficients of a quadratic are all that is required.
// Coefficients came from numerical experiments.
constexpr float kExpQuarticFactor2 = 0.0135302434f;
constexpr float kExpQuarticFactor1 = 0.0656107542f;
constexpr float kExpQuarticFactor0 = 0.306963906f;
// Coefficients for cubic representation of 2^x - 1 for x on [0,1]
// The cubic representation is 2^x - 1 ~ x - x(1-x)(mx + c), hence the
// coefficients of a linear function are all that is required.
// Coefficients came from numerical experiments.
constexpr float kExpCubicFactor1 = 0.0780252018f;
constexpr float kExpCubicFactor0 = 0.304684167f;
// Coefficients are optimized to minimize the absolute error on
// tanh = (e^2x - 1) / (e^2x + 1) instead of on pure e^x.
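
// A minimal scalar sketch of the trick above using the quartic coefficients
// (illustrative only; not used elsewhere in this file, and the function name
// is chosen just for this example). Assumes z has already been clipped to
// [kMinExpInput, kMaxExpInput].
inline float ExampleQuarticExpSketch(float z) {
  // Convert the problem from e^z to 2^t.
  float t = z * kOneOverLog2;
  // Split into integer exponent n and fraction x in [0, 1).
  float n = floorf(t);
  float x = t - n;
  // 2^x - 1 ~ x - x(1 - x)(ax^2 + bx + c), exact at x = 0 and x = 1.
  float poly = (kExpQuarticFactor2 * x + kExpQuarticFactor1) * x +
               kExpQuarticFactor0;
  float mantissa = x - x * (1.0f - x) * poly;
  // Assemble 2^n by writing the biased exponent directly into the float bits.
  int bits = (static_cast<int>(n) + 127) << kFloatMantissaBits;
  float pow2_n;
  memcpy(&pow2_n, &bits, sizeof(pow2_n));
  // e^z = 2^n * 2^x = 2^n * (1 + (2^x - 1)).
  return pow2_n * (1.0f + mantissa);
}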

// Enum that determines how a transcendental is computed.
enum TranscendentalMode {
  // Cubic using 16 bit integer arithmetic.
  TM_ORDER3_16BIT,
  // Quartic using 16 bit integer arithmetic.
  TM_ORDER4_16BIT,
  // Quartic using 32 bit float arithmetic.
  TM_ORDER4_FLOAT,
};

inline int FloatAsInt16(float x) {
  return static_cast<int>(x * (1 << 15) + 0.5f);
}

inline int FloatAsInt32(float x) {
  return static_cast<int>(x * (1 << 30) + 0.5f);
}

#if defined __ARM_NEON || defined __aarch64__

constexpr int kMaxSigmoidInputInt = static_cast<int>(kMaxSigmoidInput);

// Computes and returns 2^(x>>23) ie 2^u where x = u << 23 bits.
// Uses the quartic floating point exponent trick, see COMMON DOCUMENTATION FOR
// THE FLOATING EXPONENT TRICK above for details.
// Returns the true value, ie not scaled.
inline float32x4_t float32_pow2(float32x4_t x) {
  // The input is already shifted left by 23 bits, so when we convert to int,
  // the bottom 23 bits are the fractional part, and the top bits are the
  // integer part. We want to compute a function of the fractional part, so
  // we will mask it off and manipulate it.
  int32x4_t exp_int_x = vcvtq_s32_f32(x);
  // Mask to allow conversion of just the fractional part of x to fixed16<0>.
  int32x4_t mantissa_mask16 = vdupq_n_s32(0x7fff00);
  // Mask to allow conversion of just the fractional part of x to fixed32<1>.
  int32x4_t mantissa_mask32 = vdupq_n_s32(0x7fffff);
  // Narrowing shift to convert to fixed16<0>.
  int16x4_t x_16 = vshrn_n_s32(vandq_s32(mantissa_mask16, exp_int_x), 8);
  // Shift to convert to fixed32<1>.
  int32x4_t x_32 = vshlq_n_s32(vandq_s32(mantissa_mask32, exp_int_x), 7);
  // Compute the polynomial x(x - 1)(ax^2 + bx + c) of the fractional part.
  // Ordering these lines carefully makes it faster, as some of the multiply
  // operations can pipeline instead of waiting for the previous result.
  int32x4_t x_squared = vmull_s16(x_16, x_16);
  int16x4_t b = vdup_n_s16(FloatAsInt16(kExpQuarticFactor1));
  int32x4_t c = vdupq_n_s32(FloatAsInt32(kExpQuarticFactor0));
  int32x4_t bx_plus_c = vmlal_s16(c, b, x_16);
  int16x4_t a = vdup_n_s16(FloatAsInt16(kExpQuarticFactor2));
  // Finish the quadratic: result = ax^2 + bx + c.
  int32x4_t result = vmlal_s16(bx_plus_c, a, vshrn_n_s32(x_squared, 15));
  int32x4_t x_squared_minus_x = vsubq_s32(x_squared, x_32);

  // Multiply by x^2 - x.
  result = vqrdmulhq_s32(result, x_squared_minus_x);
  // Shift back to mantissa position. vqrdmulhq_s32 took 2x 30-mantissa bit
  // inputs, made 60-mantissa bit result, doubled it to 61 bits, then discarded
  // the bottom 32 making 29, so shift right 6 to get 23.
  result = vshrq_n_s32(result, 6);
  // Add the constant to normalize the exponent for IEEE format.
  int32x4_t exp_offset = vdupq_n_s32(kFloatExponentOffset);
  exp_int_x = vaddq_s32(exp_int_x, exp_offset);
  exp_int_x = vaddq_s32(exp_int_x, result);
  // Cast back to float, as we just computed the exponent and mantissa and
  // assembled them in IEEE format.
  return vreinterpretq_f32_s32(exp_int_x);
}

// Scaled float to float exp approximation, using a quartic refinement of
// the exponent trick. See COMMON DOCUMENTATION FOR THE FLOATING EXPONENT TRICK
// above for details. Input is a fixed32<31 - mantissa_bits> that has been
// converted to a float without any further shifting. MUST HAVE ALREADY BEEN
// CLIPPED to a suitable range for exp!
// Returns a vector of standard unscaled floats.
inline float32x4_t fixed32_exp_float_preclipped(const int mantissa_bits,
                                                float32x4_t x) {
  // Divide by log 2 to convert problem to 2^x, and scale to match the
  // mantissa bits required by IEEE floats.
  // This is the shift of the FP mantissa relative to the input mantissa.
  const int kXShift = kFloatMantissaBits - mantissa_bits;
  const float kLogFactor = static_cast<float>(1 << kXShift);
  float32x4_t factor = vdupq_n_f32(kLogFactor * kOneOverLog2);
  float32x4_t y = vmulq_f32(x, factor);
  // Now compute 2^x.
  return float32_pow2(y);
}

// Uses the trick that 2^x can be computed by shifting an integer into the
// exponent field; see the following reference for a derivation using doubles:
// goo.gl/aUVTK3
// Input x is clipped to [kMinExpInput, kMaxExpInput], even for infinity and
// NaN.
// Accurate to within 3% relative error across the entire range.
// Fully pipelined throughput is about 10 cycles per fast_exp call.
inline float32x4_t fast_exp(float32x4_t x) {
#if defined FAST_TRANSCENDENTALS && __ARM_ARCH >= 800
  // Uses vcvtnq_s32_f32, not available on ARM v7 NEON.

  // Load A and B, which are defined as integers into float registers.
  float32x4_t A = vreinterpretq_f32_u32(vdupq_n_u32(kAConstant));
  float32x4_t res = vreinterpretq_f32_u32(vdupq_n_u32(kBConstant));

  // Make sure x is within the allowed range.
  x = vminq_f32(x, vdupq_n_f32(kMaxExpInput));
  x = vmaxq_f32(x, vdupq_n_f32(kMinExpInput));

  // res = A * x + B.
  // This shifts x into the exponent field and adds the bias.
  res = vmlaq_f32(res, A, x);

  // Convert back to an integer, this is what uses the floating point
  // unit to compute 2^x.
  int32x4_t x_int = vcvtnq_s32_f32(res);

  return vreinterpretq_f32_s32(x_int);
#else
  float32x4_t return_val = vdupq_n_f32(0.f);

  float exponent = expf(vgetq_lane_f32(x, 0));
  return_val = vld1q_lane_f32(&exponent, return_val, 0);

  exponent = expf(vgetq_lane_f32(x, 1));
  return_val = vld1q_lane_f32(&exponent, return_val, 1);
  exponent = expf(vgetq_lane_f32(x, 2));
  return_val = vld1q_lane_f32(&exponent, return_val, 2);
  exponent = expf(vgetq_lane_f32(x, 3));
  return_val = vld1q_lane_f32(&exponent, return_val, 3);

  return return_val;
#endif  // FAST_TRANSCENDENTALS
}

// This version does a conversion of the input to floating point, then calls
// the floating point fast_exp function.  There is another version,
// fast_exp_fixed, which never does a conversion and is less accurate, but much
// faster.
template <int ExponentBits>
inline float32x4_t fast_exp(int32x4_t x) {
  return fast_exp(vcvtq_n_f32_s32(x, 31 - ExponentBits));
}

// Performs an exp estimate without doing any floating point operations. The
// result is a floating point number.  See scalar version for an explanation.
template <int ExponentBits>
inline float32x4_t fast_exp_fixed(int32x4_t x) {
  static_assert(ExponentBits > 8, "Must have more than 8 ExponentBits");
  constexpr int kA = 1.4426950408889634 * (1 << (ExponentBits - 8));
  constexpr int kB = (127 << 23) - 366000;

  constexpr int maxInput = 80 << (31 - ExponentBits);
  constexpr int minInput = -maxInput;

  int32x4_t A = vdupq_n_s32(kA);
  int32x4_t res = vdupq_n_s32(kB);

  // Make sure x is within the allowed range.
  x = vminq_s32(x, vdupq_n_s32(maxInput));
  x = vmaxq_s32(x, vdupq_n_s32(minInput));

  // res = A * x + B.
  // This shifts x into the exponent field and adds the bias.
  res = vmlaq_s32(res, A, x);

  return vreinterpretq_f32_s32(res);
}

// fast_exp_norange_check uses vcvtnq_s32_f32, not available on ARM v7 NEON.
#if __ARM_ARCH >= 800
namespace detail {
// tanh can do the range check once for both exp calls.
// Input x must already be clipped to a safe range, as this function performs
// no range check of its own.
inline float32x4_t fast_exp_norange_check(float32x4_t x) {
  float32x4_t A = vreinterpretq_f32_u32(vdupq_n_u32(kAConstant));
  float32x4_t res = vreinterpretq_f32_u32(vdupq_n_u32(kBConstant));

  res = vmlaq_f32(res, A, x);

  int32x4_t x_int = vcvtnq_s32_f32(res);

  return vreinterpretq_f32_s32(x_int);
}

}  // namespace detail
#endif  // __ARM_ARCH >= 800

// Clips float input to [-|kLimit|, |kLimit|].
inline float32x4_t ClipToFloatBounds(const float kLimit, const float32x4_t x) {
  // Clip to the input bounds for this approximation.
  float32x4_t clip_limit = vdupq_n_f32(kLimit);
  float32x4_t clipped_x = vminq_f32(x, clip_limit);
  clip_limit = vnegq_f32(clip_limit);
  return vmaxq_f32(clipped_x, clip_limit);
}

inline float32x4_t float_tanh_float(const float32x4_t& x) {
  float32x4_t clipped_x = ClipToFloatBounds(kMaxTanhInput, x);
  // Divide by log 2 to convert problem to 2^x, double (as we need exp(2x)) and
  // scale to the mantissa bits required by float32_pow2 all in one multiply.
  // Add one to double the input.
  const float kLogFactor = static_cast<float>(1 << (kFloatMantissaBits + 1));
  float32x4_t factor = vdupq_n_f32(kLogFactor * kOneOverLog2);
  clipped_x = vmulq_f32(clipped_x, factor);
  // Now compute 2^x.
  float32x4_t exp_result = float32_pow2(clipped_x);
  // Now compute tanh using (e^2x - 1) / (e^2x + 1).
  float32x4_t one = vdupq_n_f32(1.0f);
  float32x4_t numerator = vsubq_f32(exp_result, one);
  float32x4_t denominator = vaddq_f32(exp_result, one);
  float32x4_t recp = vrecpeq_f32(denominator);
  // Newton-Raphson iteration, accuracy is important for audio quality
  recp = vmulq_f32(recp, vrecpsq_f32(recp, denominator));
  recp = vmulq_f32(recp, numerator);
  // Compute 3rd-order Taylor tanh ~ x - x^3/3 for high accuracy and thus low
  // relative error close to 0.
  float32x4_t third = vdupq_n_f32(1.0f / 3.0f);
  float32x4_t taylor = vmulq_f32(x, x);
  taylor = vmulq_f32(taylor, x);
  taylor = vmulq_f32(taylor, third);
  taylor = vsubq_f32(x, taylor);
  // Test |x| <= 1/9, roughly where the errors cross over, without needing yet
  // another constant.
  float32x4_t ninth = vmulq_f32(third, third);
  uint32x4_t cmp_results = vcaleq_f32(x, ninth);
  return vbslq_f32(cmp_results, taylor, recp);
}

// Calculates (exp(x) - exp(-x)) / (exp(x) + exp(-x)).
// Input x is clamped to [-9, 9], even infinity and NaN.
// See test program for bounds.  Throughput of FAST is 334 Mega/sec,
// throughput of accurate is 232 Mega/sec.
inline float32x4_t fast_tanh(float32x4_t x) {
#if defined FASTER_TRANSCENDENTALS
  return float_tanh_float(x);
#elif defined ACCURATE_TRANSCENDENTAL_APPROX && defined FAST_TRANSCENDENTALS
  x = vminq_f32(x, vdupq_n_f32(kMaxTanhInput));
  x = vmaxq_f32(x, vdupq_n_f32(kMinTanhInput));

  // The monomial coefficients of the numerator polynomial (odd).
  const float32x4_t alpha_1 = vdupq_n_f32(kTanhAlpha1);
  const float32x4_t alpha_3 = vdupq_n_f32(kTanhAlpha3);
  const float32x4_t alpha_5 = vdupq_n_f32(kTanhAlpha5);
  const float32x4_t alpha_7 = vdupq_n_f32(kTanhAlpha7);
  const float32x4_t alpha_9 = vdupq_n_f32(kTanhAlpha9);
  const float32x4_t alpha_11 = vdupq_n_f32(kTanhAlpha11);
  const float32x4_t alpha_13 = vdupq_n_f32(kTanhAlpha13);

  // The monomial coefficients of the denominator polynomial (even).
  const float32x4_t beta_0 = vdupq_n_f32(kTanhBeta0);
  const float32x4_t beta_2 = vdupq_n_f32(kTanhBeta2);
  const float32x4_t beta_4 = vdupq_n_f32(kTanhBeta4);
  const float32x4_t beta_6 = vdupq_n_f32(kTanhBeta6);

  // Since the polynomials are odd/even, we need x^2.
  const float32x4_t x2 = vmulq_f32(x, x);

  // Evaluate the numerator polynomial |p|.
  float32x4_t p = vmlaq_f32(alpha_11, x2, alpha_13);
  p = vmlaq_f32(alpha_9, x2, p);
  p = vmlaq_f32(alpha_7, x2, p);
  p = vmlaq_f32(alpha_5, x2, p);
  p = vmlaq_f32(alpha_3, x2, p);
  p = vmlaq_f32(alpha_1, x2, p);
  p = vmulq_f32(x, p);

  // Evaluate the denominator polynomial q.
  float32x4_t q = vmlaq_f32(beta_4, x2, beta_6);
  q = vmlaq_f32(beta_2, x2, q);
  q = vmlaq_f32(beta_0, x2, q);

  // Divide the numerator by the denominator.
  float32x4_t recp = vrecpeq_f32(q);
  recp = vmulq_f32(recp, vrecpsq_f32(recp, q));
  return vmulq_f32(p, recp);
#elif defined FAST_TRANSCENDENTALS && __ARM_ARCH >= 800
  // Uses vcvtnq_s32_f32, not available on ARM v7 NEON.

  x = vminq_f32(x, vdupq_n_f32(kMaxTanhInput));
  x = vmaxq_f32(x, vdupq_n_f32(kMinTanhInput));
  float32x4_t exp_est = detail::fast_exp_norange_check(x);
  float32x4_t neg_exp_est = detail::fast_exp_norange_check(-x);

  // If we're in the linear region.
  // caleq = compare absolute <=
  uint32x4_t cmp_results = vcaleq_f32(x, vdupq_n_f32(kTanhLinearRegion));

  float32x4_t diff = vsubq_f32(exp_est, neg_exp_est);
  float32x4_t sum = vaddq_f32(exp_est, neg_exp_est);
  float32x4_t recp = vrecpeq_f32(sum);
  recp = vmulq_f32(recp, vrecpsq_f32(recp, sum));
  float32x4_t tanh_estimate = vmulq_f32(diff, recp);

  // Based on comparison, possibly copy x through instead of calculated value.
  // TODO(b/191497441): Is the compiler generating VBIT or VBSL ? VBIT is one
  // cycle and VBSL is two... documentation suggests it can do either.
  return vbslq_f32(cmp_results, x, tanh_estimate);
#else
  float32x4_t return_val = vdupq_n_f32(0.f);

  float tanh_value = tanhf(vgetq_lane_f32(x, 0));
  return_val = vld1q_lane_f32(&tanh_value, return_val, 0);
  tanh_value = tanhf(vgetq_lane_f32(x, 1));
  return_val = vld1q_lane_f32(&tanh_value, return_val, 1);
  tanh_value = tanhf(vgetq_lane_f32(x, 2));
  return_val = vld1q_lane_f32(&tanh_value, return_val, 2);
  tanh_value = tanhf(vgetq_lane_f32(x, 3));
  return_val = vld1q_lane_f32(&tanh_value, return_val, 3);

  return return_val;
#endif  // FAST_TRANSCENDENTALS
}

// Input x is clamped to [-18, 18], even infinity and NaN.
// See tests for error bounds.  Using SIGMOID_AS_TANH with
// ACCURATE_TRANSCENDENTAL_APPROX is both faster and more accurate.  Using
// SIGMOID_AS_TANH with just FAST is slower, but more accurate.
// SIGMOID_AS_TANH, ACCURATE is 205 Mega/sec
// SIGMOID_AS_TANH, FAST is 290 Mega/sec
// FAST is 340 Mega/sec
inline float32x4_t fast_sigmoid(float32x4_t x) {
#ifdef SIGMOID_AS_TANH
  float32x4_t half = vdupq_n_f32(0.5f);
  return vmlaq_f32(half, half, fast_tanh(vmulq_f32(half, x)));
#else  // SIGMOID_AS_TANH
#if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX
  x = vminq_f32(x, vdupq_n_f32(kMaxSigmoidInput));
  x = vmaxq_f32(x, vdupq_n_f32(kMinSigmoidInput));

  // The monomial coefficients of the numerator polynomial (odd).
  const float32x4_t alpha_1 = vdupq_n_f32(kSigmoidAlpha1);
  const float32x4_t alpha_3 = vdupq_n_f32(kSigmoidAlpha3);
  const float32x4_t alpha_5 = vdupq_n_f32(kSigmoidAlpha5);
  const float32x4_t alpha_7 = vdupq_n_f32(kSigmoidAlpha7);
  const float32x4_t alpha_9 = vdupq_n_f32(kSigmoidAlpha9);

  // The monomial coefficients of the denominator polynomial (even).
  const float32x4_t beta_0 = vdupq_n_f32(kSigmoidBeta0);
  const float32x4_t beta_2 = vdupq_n_f32(kSigmoidBeta2);
  const float32x4_t beta_4 = vdupq_n_f32(kSigmoidBeta4);
  const float32x4_t beta_6 = vdupq_n_f32(kSigmoidBeta6);
  const float32x4_t beta_8 = vdupq_n_f32(kSigmoidBeta8);
  const float32x4_t beta_10 = vdupq_n_f32(kSigmoidBeta10);

  // Since the polynomials are odd/even, we need x^2.
  const float32x4_t x2 = vmulq_f32(x, x);

  // Evaluate the numerator polynomial p.
  float32x4_t p = vmlaq_f32(alpha_7, x2, alpha_9);
  p = vmlaq_f32(alpha_5, x2, p);
  p = vmlaq_f32(alpha_3, x2, p);
  p = vmlaq_f32(alpha_1, x2, p);
  p = vmulq_f32(x, p);

  // Evaluate the denominator polynomial q.
  float32x4_t q = vmlaq_f32(beta_8, x2, beta_10);
  q = vmlaq_f32(beta_6, x2, q);
  q = vmlaq_f32(beta_4, x2, q);
  q = vmlaq_f32(beta_2, x2, q);
  q = vmlaq_f32(beta_0, x2, q);

  // Divide the numerator by the denominator.
  float32x4_t recp = vrecpeq_f32(q);
  recp = vmulq_f32(recp, vrecpsq_f32(recp, q));
  return vmlaq_f32(vdupq_n_f32(0.5f), p, recp);
#elif defined FAST_TRANSCENDENTALS
  float32x4_t denom = vaddq_f32(fast_exp(vnegq_f32(x)), vdupq_n_f32(1.f));

  float32x4_t recp = vrecpeq_f32(denom);
  // Newton-Raphson iteration, accuracy is important for audio quality.
  recp = vmulq_f32(recp, vrecpsq_f32(recp, denom));
  float32x4_t half = vdupq_n_f32(0.5f);
  float32x4_t quarter = vdupq_n_f32(0.245f);
  float32x4_t linear_approx = vmlaq_f32(half, quarter, x);
  uint32x4_t cmp_results = vcaleq_f32(x, vdupq_n_f32(kSigmoidLinearRegion));

  return vbslq_f32(cmp_results, linear_approx, recp);
#else
  float32x4_t return_val = vdupq_n_f32(0.f);

  float result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 0)));
  return_val = vld1q_lane_f32(&result, return_val, 0);
  result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 1)));
  return_val = vld1q_lane_f32(&result, return_val, 1);
  result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 2)));
  return_val = vld1q_lane_f32(&result, return_val, 2);
  result = 1.f / (1.f + expf(-vgetq_lane_f32(x, 3)));
  return_val = vld1q_lane_f32(&result, return_val, 3);

  return return_val;
#endif  // FAST_TRANSCENDENTALS
#endif  // SIGMOID_AS_TANH
}

// Scalar implementations, mainly useful for testing.
inline float fast_exp(float x) {
  return vgetq_lane_f32(fast_exp(vdupq_n_f32(x)), 0);
}

template <int ExponentBits>
inline float fast_exp(fixed32<ExponentBits> x) {
  return vgetq_lane_f32(fast_exp<ExponentBits>(vdupq_n_s32(x.raw_val())), 0);
}

// Computes the exponential of a fixed point number, returning a float, without
// ever doing any conversions.  Less accurate than the version that does
// conversions, but still accurate to within 4% relative error for x < 16.
template <int ExponentBits>
inline float fast_exp_fixed(fixed32<ExponentBits> x) {
  return vgetq_lane_f32(fast_exp_fixed<ExponentBits>(vdupq_n_s32(x.raw_val())),
                        0);
}

inline float fast_sigmoid(float x) {
  return vgetq_lane_f32(fast_sigmoid(vdupq_n_f32(x)), 0);
}

inline float fast_tanh(float x) {
  return vgetq_lane_f32(fast_tanh(vdupq_n_f32(x)), 0);
}
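
// Illustrative sketch (not part of the library API): applying the vector
// fast_tanh above to a contiguous buffer whose size is a multiple of 4. The
// function name and the multiple-of-4 assumption are specific to this example.
inline void ExampleTanhBuffer(const float* input, int size, float* output) {
  for (int i = 0; i < size; i += 4) {
    // Process 4 floats per iteration with the vector approximation.
    vst1q_f32(output + i, fast_tanh(vld1q_f32(input + i)));
  }
}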

// Clips integer input to [-|kLimit|, |kLimit|].
// Input: register containing 4x fixed32 with |mantissa_bits|.
// Output: register containing 4x fixed32 limited to
// [-|kLimit| << |mantissa_bits|, |kLimit| << |mantissa_bits|].
template <int kLimit>
inline int32x4_t ClipToBounds(const int mantissa_bits, const int32x4_t x) {
  // Clip to the input bounds for this approximation.
  int32x4_t clip_limit = vdupq_n_s32(-(kLimit << mantissa_bits));
  int32x4_t clipped_x = vmaxq_s32(x, clip_limit);
  clip_limit = vnegq_s32(clip_limit);
  return vminq_s32(clipped_x, clip_limit);
}

// Fixed32 sigmoid approximation via a quadratic refinement of the exponent
// trick.
// Input: Register containing 4x fixed32 with |mantissa_bits|.
// Output: Register containing 4x float results.
inline float32x4_t fixed32_sigmoid_float(const int mantissa_bits,
                                         const int32x4_t x) {
  int32x4_t input = vnegq_s32(x);
  float32x4_t y =
      vcvtq_f32_s32(ClipToBounds<kMaxSigmoidInputInt>(mantissa_bits, input));
  y = fixed32_exp_float_preclipped(mantissa_bits, y);
  float32x4_t one = vdupq_n_f32(1.0f);
  // The raw reciprocal estimate is not accurate enough on its own, so it is
  // refined with a Newton-Raphson iteration below.
  float32x4_t denom = vaddq_f32(y, one);
  float32x4_t recp = vrecpeq_f32(denom);
  // Newton-Raphson iteration, accuracy is important for audio quality
  recp = vmulq_f32(recp, vrecpsq_f32(recp, denom));
  return recp;
}

template <int ExponentBits>
inline float32x4_t fast_sigmoid(int32x4_t x) {
#if defined FASTER_TRANSCENDENTALS
  // Computation will fail to produce the right result if the number of input
  // mantissa bits exceeds the number in a float.
  static_assert(kFloatMantissaBits >= fixed32<ExponentBits>::kMantissaBits,
                "Mantissa bits must be at most 23!");
  return fixed32_sigmoid_float(fixed32<ExponentBits>::kMantissaBits, x);
#else
  return fast_sigmoid(vcvtq_n_f32_s32(x, fixed32<ExponentBits>::kMantissaBits));
#endif  // FASTER_TRANSCENDENTALS
}

template <int ExponentBits>
inline float fast_sigmoid(fixed32<ExponentBits> x) {
  return vgetq_lane_f32(fast_sigmoid<ExponentBits>(vdupq_n_s32(x.raw_val())),
                        0);
}

#else  // defined __ARM_NEON || defined __aarch64__

inline float fast_exp(float x) {
#ifdef FAST_TRANSCENDENTALS
  if (isnan(x)) return 0.0f;
  x = std::max(std::min(x, kMaxExpInput), kMinExpInput);
  float AConstant, BConstant;
  memcpy(&AConstant, &kAConstant, sizeof(int));
  memcpy(&BConstant, &kBConstant, sizeof(int));
  float y = x * AConstant + BConstant;
  int x_int = static_cast<int>(y);
  float ret;
  memcpy(&ret, &x_int, sizeof(float));
  return ret;
#else
  return expf(x);
#endif  // FAST_TRANSCENDENTALS
}

template <int ExponentBits>
inline float fast_exp(fixed32<ExponentBits> x) {
  return fast_exp(static_cast<float>(x));
}

template <int ExponentBits>
inline float fast_exp_fixed(fixed32<ExponentBits> x) {
  static_assert(ExponentBits > 8, "Must have more than 8 ExponentBits");
  int matched_decimal =
      std::max(std::min(x.raw_val(), (80 << (31 - ExponentBits))),
               -(80 << (31 - ExponentBits)));
  // Convert 1 / log(2) to 16-bit fixed point with 1 exponent bit
  // (1 / log(2)) * (1 << 14), but then right shift by the appropriate amount to
  // line the decimal point up with the 32-bit float representation.
  // (MantissaBits of x) + (MantissaBits of constant) = 23
  // 23 - (MantissaBits of x) = MantissaBits of constant
  // 23 - (31 - ExponentBits of x) = ...
  // (ExponentBits of x - 8) = MantissaBits of constant
  const int16_t A = (1.f / logf(2.f)) * (1 << (ExponentBits - 8));
  // Same rationale as for floating point versions, bias exponent, subtract
  // 366000 to reduce error by centering approximation, instead of being
  // one-sided.
  const int B = (127 << 23) - 366000;
  matched_decimal = A * matched_decimal + B;
  float ret_val;
  memcpy(&ret_val, &matched_decimal, sizeof(float));
  return ret_val;
}

inline float fast_tanh(float x) {
#if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX
  // Doesn't do anything fancy, just a 13/6-degree rational interpolant which
  // is accurate up to a couple of ulp in the range [-9, 9], outside of which
  // fl(tanh(x)) = +/-1.
  x = std::max(std::min(x, kMaxTanhInput), kMinTanhInput);

  // Since the polynomials are odd/even, we need x^2.
  float x2 = x * x;

  // Evaluate numerator.
  float p = kTanhAlpha11 + x2 * kTanhAlpha13;
  p = kTanhAlpha9 + x2 * p;
  p = kTanhAlpha7 + x2 * p;
  p = kTanhAlpha5 + x2 * p;
  p = kTanhAlpha3 + x2 * p;
  p = kTanhAlpha1 + x2 * p;
  p = x * p;

  // Evaluate denominator.
  float q = kTanhBeta4 + x2 * kTanhBeta6;
  q = kTanhBeta2 + x2 * q;
  q = kTanhBeta0 + x2 * q;

  return p / q;
#elif defined FAST_TRANSCENDENTALS
  if (std::abs(x) < kTanhLinearRegion) {
    return x;
  } else {
    x = std::max(std::min(x, kMaxTanhInput), kMinTanhInput);
    float positive = fast_exp(x);
    float negative = fast_exp(-x);
    return (positive - negative) / (positive + negative);
  }
#else
  return tanhf(x);
#endif  // FAST_TRANSCENDENTALS
}

inline float fast_sigmoid(float x) {
#ifdef SIGMOID_AS_TANH
  return .5f * fast_tanh(.5f * x) + .5f;
#else
#if defined FAST_TRANSCENDENTALS && defined ACCURATE_TRANSCENDENTAL_APPROX
  // Doesn't do anything fancy, just a 9/10-degree rational interpolant which
  // interpolates 1/(1+exp(-x)) - 0.5 up to a couple of ulp in the range
  // [-18, 18], outside of which the fl(sigmoid(x)) = {0|1}. The shifted
  // sigmoid is interpolated because it was easier to make the fit converge.
  // See GenericPacketMath.h* in the open source Eigen library.
  x = std::max(std::min(x, kMaxSigmoidInput), kMinSigmoidInput);

  // Since the polynomials are odd/even, we need x^2.
  float x2 = x * x;

  // Evaluate numerator.
  float p = kSigmoidAlpha7 + x2 * kSigmoidAlpha9;
  p = kSigmoidAlpha5 + x2 * p;
  p = kSigmoidAlpha3 + x2 * p;
  p = kSigmoidAlpha1 + x2 * p;
  p = x * p;

  // Evaluate denominator.
  float q = kSigmoidBeta8 + x2 * kSigmoidBeta10;
  q = kSigmoidBeta6 + x2 * q;
  q = kSigmoidBeta4 + x2 * q;
  q = kSigmoidBeta2 + x2 * q;
  q = kSigmoidBeta0 + x2 * q;

  return p / q + 0.5f;
#elif defined FAST_TRANSCENDENTALS
  if (std::abs(x) < kSigmoidLinearRegion) {
    return .245 * x + .5;
  } else {
    return 1.f / (1.f + fast_exp(-x));
  }
#else
  return 1.f / (1.f + expf(-x));
#endif  // FAST_TRANSCENDENTALS
#endif  // SIGMOID_AS_TANH
}

template <int ExponentBits>
inline float fast_sigmoid(fixed32<ExponentBits> x) {
  return fast_sigmoid(static_cast<float>(x));
}

#endif  // defined __ARM_NEON || defined __aarch64__

// Number of exponent bits to use for tanh.
static constexpr int kNumTanhExpBits = 3;
// Number of exponent bits to use for sigmoid.
static constexpr int kNumSigmoidExpBits = 4;
// Number of extra bits to shift sigmoid, due to its low gradient.
static constexpr int kNumExtraSigmoidShiftBits = 1;

// Returns (and builds if not done yet) a static data table (that is never
// deleted, as per the style guide) that implements tanh on fixed32 input,
// returning another fixed32 with the given number of mantissa bits (which is
// assumed to be less than the input mantissa bits).
// NOTE that this function is intended to be used only with fixed16 outputs that
// are sign-extended to 32 bits for convenience, and will return a nullptr
// if asked for more than |kMaxMantissaBits| of precision in the output table.
const int* TanhTable(int num_mantissa_bits_out);
// As TanhTable, but for Sigmoid.
const int* SigmoidTable(int num_mantissa_bits_out);

// Scalar/generic function to compute and return the fast approximation to exp
// via a polynomial refinement of the floating point exponent trick.
// TM_ORDER4_16BIT: Max relative error < 5e-6, absolute error < 1e-5 for x < 1.
// TM_ORDER3_16BIT: Max relative error < 1.1e-4, absolute error < 3e-4 for
// x < 1.
template <int kExponentBits, TranscendentalMode kOrder = TM_ORDER4_16BIT>
float fixed32_exp(fixed32<kExponentBits> x) {
  constexpr int kMantissaBits = MantissaBitsOf<fixed32<kExponentBits>>::value;
  // Clip x to min/max exp input to avoid infinities.
  int64_t clipped_x =
      std::max(std::min(x.raw_val(), kMaxExpInputInt << kMantissaBits),
               -(kMaxExpInputInt << kMantissaBits));
  // First convert problem from e^x to 2^x by multiplying by 1/log(2).
  // To maximize precision, log_factor is shifted left the maximum amount to
  // keep within int32, and we shift x left a further amount such that the
  // binary point of the product sits in the correct place in the top 32 bits of
  // the result to be used directly as a float. We can't do that directly, as x
  // would overflow, so we have to shift by 1 bit less and shift the result by
  // 1 bit less to match.
  constexpr int kXShift =
      kFloatMantissaBits + 31 - kMaxLog2Shift - kMantissaBits;
  static_assert(kXShift >= 0,
                "Mantissa bits > kFloatMantissaBits + 31 - kMaxLog2Shift");
  clipped_x <<= kXShift;
  int float_as_int = (kLogFactor * clipped_x >> 31) + kFloatExponentOffset;
  // Separate the resulting fixed-point into integer and fractional parts.
  int int_part = float_as_int & kFloatExponentMask;
  int float_part = float_as_int & kFloatMantissaMask;
  float fraction = static_cast<float>(float_part) / (1 << kFloatMantissaBits);
  // Compute the mantissa = 2^fraction using:
  // fraction - fraction*(1-fraction)*(polynomial of fraction)
  // This guarantees exactness at 0 and 1, providing continuity of the error at
  // integer boundaries.
  float mantissa;
  if (kOrder == TM_ORDER4_16BIT || kOrder == TM_ORDER4_FLOAT) {
    mantissa = (kExpQuarticFactor2 * fraction + kExpQuarticFactor1) * fraction +
               kExpQuarticFactor0;
  } else if (kOrder == TM_ORDER3_16BIT) {
    mantissa = kExpCubicFactor1 * fraction + kExpCubicFactor0;
  }
  mantissa = fraction - fraction * (1.0f - fraction) * mantissa;
  // Since the function above guarantees to stay within [0, 1), we could do all
  // the above in fixed point if necessary, in which case, we can just stuff
  // the bottom kFloatMantissaBits in with the exponent and we are done.
  // In the floating point world, it is simpler to just multiply them together.
  float result;
  memcpy(&result, &int_part, sizeof(float));
  return result * (1.0f + mantissa);
}

// Computes and returns tanh(x) fixed32->float using a polynomial refinement of
// the floating point exponent trick.
// kOrder=4: Absolute error < 1.8e-6. Relative error < 1.2e-4 for |x| > 0.01.
// kOrder=3: Absolute error < 6e-5. Relative error < 3e-3 for |x| > 0.01
template <int kExponentBits, TranscendentalMode kOrder = TM_ORDER4_16BIT>
float fixed32_tanh(fixed32<kExponentBits> x) {
  float float_x = static_cast<float>(x);
  if (std::abs(float_x) < 1.0f / 9.0f) {
    return float_x * (1 - float_x * float_x / 3.0f);
  }
  x = static_cast<fixed32<kExponentBits>>(x.raw_val() * 2);
  float exp_2x = fixed32_exp<kExponentBits, kOrder>(x);
  return (exp_2x - 1.0f) / (exp_2x + 1.0f);
}

// Computes and returns sigmoid(x) fixed32->float using a polynomial refinement
// of the floating point exponent trick.
// TM_ORDER4_16BIT: Absolute error < 9e-7, relative < 4e-6.
// TM_ORDER3_16BIT: Absolute error < 3e-5, relative < 1.1e-4.
template <int kExponentBits, TranscendentalMode kOrder = TM_ORDER4_16BIT>
float fixed32_sigmoid(fixed32<kExponentBits> x) {
  x = static_cast<fixed32<kExponentBits>>(-x.raw_val());
  float exp_x = fixed32_exp<kExponentBits, kOrder>(x);
  return 1.0f / (exp_x + 1.0f);
}

#if defined __AVX2__

// Inline function to access an int32 data table by shifting |x| right by
// |kNumShiftBits|, and adding |kTableOffset| to the result. |x| contains 8
// indices and 8 results are returned. The data table is of size
// |kTableOffset| * 2 + 1.
template <int kNumShiftBits, int kTableOffset>
inline __m256i index_data_table(const int32_t* data_table, const __m256i& x) {
  // Shift right with rounding to match input and output precision.
  __m256i shifted = _mm256_set1_epi32(1 << (kNumShiftBits - 1));
  shifted = _mm256_add_epi32(x, shifted);
  shifted = _mm256_srai_epi32(shifted, kNumShiftBits);
  // Add the offset.
  __m256i addend = _mm256_set1_epi32(kTableOffset);
  shifted = _mm256_add_epi32(shifted, addend);
  // And clamp to the indices of the LUT.
  addend = _mm256_add_epi32(addend, addend);
  shifted = _mm256_min_epi32(shifted, addend);
  shifted = _mm256_max_epi32(shifted, _mm256_setzero_si256());
  // Lookup the results in the table.
  return _mm256_i32gather_epi32(data_table, shifted, 4);
}

// Fixed32 to fixed16-in-an-int32 tanh LUT function.
// Input: register containing 8x fixed32 with |NumInputMantissaBits|.
// Output: a register containing 8x fixed16 with |NumOutputMantissaBits|, but
// note that they are sign-extended to 32 bits and are therefore basically the
// same as fixed32 with |NumOutputMantissaBits|.
template <int NumInputMantissaBits, int NumOutputMantissaBits>
inline __m256i fixed32_tanh_fixed16(const int* tanh_table, const __m256i& x) {
  // Lose the unnecessary input precision.
  constexpr int kNumShiftBits = NumInputMantissaBits - NumOutputMantissaBits;
  constexpr int kTableOffset = 1 << (NumOutputMantissaBits + kNumTanhExpBits);
  return index_data_table<kNumShiftBits, kTableOffset>(tanh_table, x);
}

// Fixed32 to fixed16-in-an-int32 sigmoid LUT function.
// Input: register containing 8x fixed32 with |NumInputMantissaBits|.
// Output: a register containing 8x fixed16 with |NumOutputMantissaBits|, but
// note that they are sign-extended to 32 bits and are therefore basically the
// same as fixed32 with |NumOutputMantissaBits|.
template <int NumInputMantissaBits, int NumOutputMantissaBits>
inline __m256i fixed32_sigmoid_fixed16(const int* sigmoid_table,
                                       const __m256i& x) {
  // Lose the unnecessary input precision.
  constexpr int kNumShiftBits =
      kNumExtraSigmoidShiftBits + NumInputMantissaBits - NumOutputMantissaBits;
  constexpr int kTableOffset = 1
                               << (NumOutputMantissaBits + kNumSigmoidExpBits -
                                   kNumExtraSigmoidShiftBits);
  return index_data_table<kNumShiftBits, kTableOffset>(sigmoid_table, x);
}
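
// Illustrative sketch (not part of the library API): one way the LUT accessors
// declared earlier (TanhTable) might be combined with the AVX2 LUT function
// above. The wrapper name and the choice of 11 output mantissa bits are
// assumptions made only for this example; the output precision must not exceed
// the table's supported |kMaxMantissaBits|.
template <int kInputMantissaBits>
inline __m256i ExampleTanhViaTable(const __m256i& x) {
  constexpr int kOutputMantissaBits = 11;
  // The table is built once and never deleted, per the comment on TanhTable().
  static const int* const tanh_table = TanhTable(kOutputMantissaBits);
  return fixed32_tanh_fixed16<kInputMantissaBits, kOutputMantissaBits>(
      tanh_table, x);
}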

// Convert 2x registers of 8x float32 into 1 register of 16x16 bit fixed int,
// assuming that the floats are already scaled up.
inline __m256i PackFloatsToFixed16(const __m256& x0, const __m256& x1) {
  __m256i int0 = _mm256_cvtps_epi32(x0);
  __m256i int1 = _mm256_cvtps_epi32(x1);
  int0 = _mm256_packs_epi32(int0, int1);
  // Swap the middle 64 bit elements so the results are in the right order.
  return _mm256_permute4x64_epi64(int0, 0xd8);
}

// Clips integer input to [-|kLimit|, |kLimit|].
// Input: register containing 8x fixed32 with |mantissa_bits|.
// Output: register containing 8x fixed32 limited to
// [-|kLimit| << |mantissa_bits|, |kLimit| << |mantissa_bits|].
template <int kLimit>
inline __m256i ClipToBounds(const int mantissa_bits, const __m256i& x) {
  // Clip to the input bounds for this approximation.
  __m256i clip_limit = _mm256_set1_epi32(-(kLimit << mantissa_bits));
  __m256i clipped_x = _mm256_max_epi32(x, clip_limit);
  // This quickly negates the limit without having to load another constant.
  clip_limit = _mm256_sign_epi32(clip_limit, clip_limit);
  return _mm256_min_epi32(clipped_x, clip_limit);
}

// Clips float input to [-|kLimit|, |kLimit|].
// Input: register containing 8x float.
// Output: register containing 8x float limited to [-|kLimit|, |kLimit|].
inline __m256 ClipToFloatBounds(const float kLimit, const __m256& x) {
  __m256 clip_limit = _mm256_set1_ps(kLimit);
  __m256 clipped_x = _mm256_min_ps(x, clip_limit);
  clip_limit = _mm256_set1_ps(-kLimit);
  return _mm256_max_ps(clipped_x, clip_limit);
}

// Float to float power of 2 approximation, using a quartic refinement of
// the exponent trick. For TM_ORDER4_16BIT and TM_ORDER3_16BIT, implementation
// is entirely in integer, using 16x16=16 multiplication, using AVX2, which
// enables 16 elements to be computed in parallel, hence the double register
// input/output args.
// The price paid for this speed is an increase in error over the (scalar) int32
// example implementations above by a variable factor of 4-10.
// For the TM_ORDER4_FLOAT case, the computation is all done in float, solving
// this lower precision problem.
// NOTE: The input must have already been clipped to prevent overflow, which
// sets the practical limit to +/-126 << kFloatMantissaBits.
// NOTE: The input is a scaled float, as if converted raw from int, and the
// scale factor is fixed at kFloatMantissaBits!
// Input: 2x registers containing 8x float * 1 << kFloatMantissaBits.
// Output: 2x register containing 8x float.
// TM_ORDER4_FLOAT: Max relative error < 8e-6, absolute error < 9e-6 for x < 1.
// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 6e-5 for x < 1.
// TM_ORDER3_16BIT: Max relative error < 6e-4, absolute error < 2e-3 for x < 1.
template <TranscendentalMode kOrder = TM_ORDER4_16BIT>
inline void float32_pow2(__m256& x0, __m256& x1) {
  // Convert straight to int.
  __m256i exp_int_x0 = _mm256_cvtps_epi32(x0);
  __m256i exp_int_x1 = _mm256_cvtps_epi32(x1);
  __m256i result_x0, result_x1;

  static_assert(kOrder == TM_ORDER4_FLOAT || kOrder == TM_ORDER4_16BIT ||
                    kOrder == TM_ORDER3_16BIT,
                "Invalid order.");

  if (kOrder == TM_ORDER4_FLOAT) {
    __m256i mantissa_mask = _mm256_set1_epi32(0x7fffff);
    __m256 float_factor =
        _mm256_set1_ps(1.0f / static_cast<float>(1 << kFloatMantissaBits));
    __m256i fract0 = _mm256_and_si256(mantissa_mask, exp_int_x0);
    __m256i fract1 = _mm256_and_si256(mantissa_mask, exp_int_x1);
    __m256 float0 = _mm256_mul_ps(_mm256_cvtepi32_ps(fract0), float_factor);
    __m256 float1 = _mm256_mul_ps(_mm256_cvtepi32_ps(fract1), float_factor);
    // Compute the polynomial of the fractional part.
    // Ordering these lines carefully makes it faster, as some of the multiply
    // operations can pipeline instead of waiting for the previous result.
    __m256 x_squared0 = _mm256_mul_ps(float0, float0);
    __m256 x_squared1 = _mm256_mul_ps(float1, float1);
    __m256 b = _mm256_set1_ps(kExpQuarticFactor1);
    __m256 b_x0 = _mm256_mul_ps(b, float0);
    __m256 b_x1 = _mm256_mul_ps(b, float1);
    __m256 a = _mm256_set1_ps(kExpQuarticFactor2);
    __m256 a_x_squared0 = _mm256_mul_ps(a, x_squared0);
    __m256 a_x_squared1 = _mm256_mul_ps(a, x_squared1);
    __m256 x_squared_minus_x0 = _mm256_sub_ps(x_squared0, float0);
    __m256 x_squared_minus_x1 = _mm256_sub_ps(x_squared1, float1);
    __m256 c = _mm256_set1_ps(kExpQuarticFactor0);
    b_x0 = _mm256_add_ps(b_x0, c);
    b_x1 = _mm256_add_ps(b_x1, c);
    float_factor = _mm256_set1_ps(static_cast<float>(1 << kFloatMantissaBits));
    a_x_squared0 = _mm256_add_ps(a_x_squared0, b_x0);
    a_x_squared1 = _mm256_add_ps(a_x_squared1, b_x1);
    a_x_squared0 = _mm256_mul_ps(a_x_squared0, x_squared_minus_x0);
    a_x_squared1 = _mm256_mul_ps(a_x_squared1, x_squared_minus_x1);
    result_x0 = _mm256_cvtps_epi32(_mm256_mul_ps(a_x_squared0, float_factor));
    result_x1 = _mm256_cvtps_epi32(_mm256_mul_ps(a_x_squared1, float_factor));
  } else {
    // Combine the fractional part of both inputs into a single register.
    // The representation is fixed16<0>, ie 15 mantissa bits.
    __m256i mantissa_mask = _mm256_set1_epi32(0x7fff00);
    __m256i x_01 =
        _mm256_srli_epi32(_mm256_and_si256(mantissa_mask, exp_int_x0), 8);
    x_01 = _mm256_or_si256(
        x_01,
        _mm256_slli_epi32(_mm256_and_si256(mantissa_mask, exp_int_x1), 8));
    // Compute the polynomial of the fractional part.
    // Ordering these lines carefully makes it faster, as some of the multiply
    // operations can pipeline instead of waiting for the previous result.
    __m256i x_squared = _mm256_mulhrs_epi16(x_01, x_01);
    __m256i result, x_squared_minus_x;
    if (kOrder == TM_ORDER4_16BIT) {
      __m256i b = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor1));
      __m256i b_x = _mm256_mulhrs_epi16(b, x_01);
      __m256i a = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor2));
      __m256i a_x_squared = _mm256_mulhrs_epi16(a, x_squared);
      x_squared_minus_x = _mm256_sub_epi16(x_squared, x_01);
      // LOG(INFO) << "x_squared_minus_x=" <<
      // static_cast<int16>(_mm256_extract_epi16(x_squared_minus_x, 0)) /
      // 32768.0f;
      __m256i c = _mm256_set1_epi16(FloatAsInt16(kExpQuarticFactor0));
      b_x = _mm256_add_epi16(b_x, c);
      // LOG(INFO) << "bx+c=" << static_cast<int16>(_mm256_extract_epi16(b_x,
      // 0)) / 32768.0f;
      result = _mm256_add_epi16(a_x_squared, b_x);
    } else {  // kOrder = TM_ORDER3_16BIT
      __m256i a = _mm256_set1_epi16(FloatAsInt16(kExpCubicFactor1));
      __m256i b = _mm256_set1_epi16(FloatAsInt16(kExpCubicFactor0));
      __m256i a_x = _mm256_mulhrs_epi16(a, x_01);
      x_squared_minus_x = _mm256_sub_epi16(x_squared, x_01);
      result = _mm256_add_epi16(a_x, b);
    }
    result = _mm256_mulhrs_epi16(result, x_squared_minus_x);
    // Extract 16x16-bit results back to the separate sets of 8x32.
    result_x0 = _mm256_slli_epi32(result, 16);
    result_x0 = _mm256_srai_epi32(result_x0, 8);
    result_x1 = _mm256_srai_epi32(result, 16);
    result_x1 = _mm256_slli_epi32(result_x1, 8);
  }
  // Add the constant to normalize the exponent.
  __m256i exp_offset = _mm256_set1_epi32(kFloatExponentOffset);
  exp_int_x0 = _mm256_add_epi32(exp_int_x0, exp_offset);
  exp_int_x0 = _mm256_add_epi32(exp_int_x0, result_x0);
  exp_int_x1 = _mm256_add_epi32(exp_int_x1, exp_offset);
  exp_int_x1 = _mm256_add_epi32(exp_int_x1, result_x1);
  // Cast back to float, as we just computed the exponent and mantissa and
  // assembled them in IEEE format.
  x0 = _mm256_castsi256_ps(exp_int_x0);
  x1 = _mm256_castsi256_ps(exp_int_x1);
}

// Fixed32 to float exp approximation, using a quartic/cubic refinement of
// the exponent trick. Implementation is entirely in integer, using 16x16=16
// multiplication, using AVX2, which enables 16 elements to be computed in
// parallel, hence the double register input/output args.
// The price paid for this speed is an increase in error over the (scalar) int32
// example implementations above by a variable factor of 4-10.
// The TM_ORDER4_FLOAT version uses floats and improves the precision.
// Input: 2x registers containing 8x fixed32 with |kInputMantissaBits|.
// Output: 2x registers containing 8x float32.
// TM_ORDER4_FLOAT: Max relative error < 8e-6, absolute error < 9e-6 for x < 1.
// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 6e-5 for x < 1.
// TM_ORDER3_16BIT: Max relative error < 6e-4, absolute error < 2e-3 for x < 1.
template <int kInputMantissaBits, TranscendentalMode kOrder = TM_ORDER4_16BIT>
inline void float_exp_float_preclipped(__m256& y0, __m256& y1) {
  // Divide by log 2 to convert problem to 2^x, and scale to match the
  // mantissa bits required by IEEE floats. Without a _mm256_mulhrs_epi32, it is
  // much easier to do this in float, even with the double conversion, as 16 bit
  // is not precise enough here.
  // This is the shift of the FP mantissa relative to the input mantissa.
  constexpr int kXShift = kFloatMantissaBits - kInputMantissaBits;
  constexpr float kLogFactor = static_cast<float>(1 << kXShift);
  __m256 factor = _mm256_set1_ps(kLogFactor * kOneOverLog2);
  y0 = _mm256_mul_ps(y0, factor);
  y1 = _mm256_mul_ps(y1, factor);
  // Now compute 2^x.
  float32_pow2<kOrder>(y0, y1);
}
template <int kInputMantissaBits, TranscendentalMode kOrder = TM_ORDER4_16BIT>
inline void fixed32_exp_float(const __m256i& x0, const __m256i& x1, __m256& y0,
                              __m256& y1) {
  // Clip to acceptable bounds to prevent overflow, and convert to float.
  y0 =
      _mm256_cvtepi32_ps(ClipToBounds<kMaxExpInputInt>(kInputMantissaBits, x0));
  y1 =
      _mm256_cvtepi32_ps(ClipToBounds<kMaxExpInputInt>(kInputMantissaBits, x1));
  float_exp_float_preclipped<kInputMantissaBits, kOrder>(y0, y1);
}

// Float->float tanh approximation via the exponent trick.
// Note that the input is scaled floats, as if converted raw from fixed16/32.
// Input: 2x registers containing 8x float scaled by input_mantissa_bits.
// Output: two registers containing 8x float.
// TM_ORDER4_FLOAT: Max relative error < 2.1e-5, absolute error < 2.3e-6.
// TM_ORDER4_16BIT: Max relative error < 1e-4, absolute error < 1.3e-5.
// TM_ORDER3_16BIT: Max relative error < 2.1e-3, absolute error < 3e-4.
template <int kInputMantissaBits, TranscendentalMode kOrder = TM_ORDER4_FLOAT>
inline void float_tanh_float(const __m256& x0, const __m256& x1, __m256& y0,
                             __m256& y1) {
  // Divide by log 2 to convert problem to 2^x, double (as we need exp(2x)) and
  // scale to the mantissa bits required by float32_pow2 all in one multiply.
  // This is the shift of the FP mantissa relative to the input mantissa.
  // Add one to double the input.
  const float kLogFactor =
      static_cast<float>(1 << (kFloatMantissaBits - kInputMantissaBits + 1));
  __m256 factor = _mm256_set1_ps(kLogFactor * kOneOverLog2);
  // Clip to suitable input bounds for tanh.
  __m256 clip_limit = _mm256_set1_ps(kMaxTanhInput * (1 << kInputMantissaBits));
  __m256 clip0 = _mm256_min_ps(x0, clip_limit);
  __m256 clip1 = _mm256_min_ps(x1, clip_limit);
  clip_limit = _mm256_set1_ps(-kMaxTanhInput * (1 << kInputMantissaBits));
  clip0 = _mm256_max_ps(clip0, clip_limit);
  clip1 = _mm256_max_ps(clip1, clip_limit);
  __m256 exp0 = _mm256_mul_ps(clip0, factor);
  __m256 exp1 = _mm256_mul_ps(clip1, factor);
  // Now compute 2^x.
  float32_pow2<kOrder>(exp0, exp1);
  // Now compute tanh using (e^2x - 1) / (e^2x + 1).
  __m256 one = _mm256_set1_ps(1.0f);
  __m256 numerator = _mm256_sub_ps(exp0, one);
  __m256 denominator = _mm256_add_ps(exp0, one);
  // Approximate reciprocal is not accurate enough - use full division.
  exp0 = _mm256_div_ps(numerator, denominator);
  numerator = _mm256_sub_ps(exp1, one);
  denominator = _mm256_add_ps(exp1, one);
  exp1 = _mm256_div_ps(numerator, denominator);
  // Compute 3rd-order Taylor tanh ~ x - x^3/3 for high accuracy and thus low
  // relative error close to 0.
  // Normalize the inputs back to proper floats.
  factor = _mm256_set1_ps(1.0f / (1 << kInputMantissaBits));
  clip0 = _mm256_mul_ps(clip0, factor);
  clip1 = _mm256_mul_ps(clip1, factor);
  __m256 third = _mm256_set1_ps(-1.0f / 3.0f);
  __m256 taylor0 = _mm256_mul_ps(clip0, clip0);
  __m256 taylor1 = _mm256_mul_ps(clip1, clip1);
  taylor0 = _mm256_mul_ps(taylor0, clip0);
  taylor1 = _mm256_mul_ps(taylor1, clip1);
  // TODO(b/191497441): The next two pairs of instructions could be combined to
  // _mm256_fmadd_ps, but requires -mfma compilation option, eg:
  // taylor0 = _mm256_fmadd_ps(taylor0, third, clip0);
  taylor0 = _mm256_mul_ps(taylor0, third);
  taylor1 = _mm256_mul_ps(taylor1, third);
  taylor0 = _mm256_add_ps(clip0, taylor0);
  taylor1 = _mm256_add_ps(clip1, taylor1);
  // Test |x| <= 1/9, roughly where the errors cross over, without needing yet
  // another constant.
  third = _mm256_mul_ps(third, third);
  __m256 neg_zero = _mm256_set1_ps(-0.0f);
  clip0 = _mm256_andnot_ps(neg_zero, clip0);
  clip1 = _mm256_andnot_ps(neg_zero, clip1);
  __m256 cmp_results0 = _mm256_cmp_ps(clip0, third, _CMP_LE_OQ);
  __m256 cmp_results1 = _mm256_cmp_ps(clip1, third, _CMP_LE_OQ);
  y0 = _mm256_blendv_ps(exp0, taylor0, cmp_results0);
  y1 = _mm256_blendv_ps(exp1, taylor1, cmp_results1);
}

// Fixed32 sigmoid approximation via the AVX2 implementation of the exponent
// trick.
// Input: 2x registers containing 8x float holding converted fixed32 values scaled
// with kInputMantissaBits.
// Output: 2x registers containing 8x float.
// TM_ORDER4_FLOAT: Max relative error < 4e-6, absolute error < 1e-6.
// TM_ORDER4_16BIT: Max relative error < 3e-5, absolute error < 7e-6.
// TM_ORDER3_16BIT: Max relative error < 5.4e-4, absolute error < 1.4e-4.
template <int kInputMantissaBits, TranscendentalMode kOrder = TM_ORDER4_FLOAT>
inline void float_sigmoid_float(__m256& y0, __m256& y1) {
  constexpr float kInputFactor = static_cast<float>(1 << kInputMantissaBits);
  // Negate the inputs.
  __m256 minus_zero = _mm256_set1_ps(-0.0f);
  y0 = _mm256_xor_ps(y0, minus_zero);
  y1 = _mm256_xor_ps(y1, minus_zero);
  y0 = ClipToFloatBounds(kMaxSigmoidInput * kInputFactor, y0);
  y1 = ClipToFloatBounds(kMaxSigmoidInput * kInputFactor, y1);
  float_exp_float_preclipped<kInputMantissaBits, kOrder>(y0, y1);
  __m256 one = _mm256_set1_ps(1.0f);
  // Approximate reciprocal is not accurate enough - use full division.
  y0 = _mm256_div_ps(one, _mm256_add_ps(y0, one));
  y1 = _mm256_div_ps(one, _mm256_add_ps(y1, one));
}
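
// Illustrative sketch (not part of the library API): driving the AVX2 sigmoid
// above directly from two registers of 8x fixed32 values with
// |kInputMantissaBits|. The wrapper name is chosen only for this example.
template <int kInputMantissaBits, TranscendentalMode kOrder = TM_ORDER4_FLOAT>
inline void ExampleFixed32SigmoidFloat(const __m256i& x0, const __m256i& x1,
                                       __m256& y0, __m256& y1) {
  // Convert to float but deliberately leave the values scaled by
  // 2^kInputMantissaBits, which is what float_sigmoid_float expects; it does
  // its own clipping internally.
  y0 = _mm256_cvtepi32_ps(x0);
  y1 = _mm256_cvtepi32_ps(x1);
  float_sigmoid_float<kInputMantissaBits, kOrder>(y0, y1);
}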

#endif  // defined __AVX2__

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_NUMERICS_FAST_TRANSCENDENTALS_H_