NTT123 committed d1a84ee
Parent: df1ad02

    add fast cpp wavegru

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete change set.
- BUILD +44 -0
- WORKSPACE +154 -0
- app.py +12 -1
- inference.py +7 -6
- packages.txt +2 -1
- sparse_matmul/BUILD +22 -0
- sparse_matmul/compute/BUILD +88 -0
- sparse_matmul/compute/ar_inputs.h +37 -0
- sparse_matmul/compute/gru_gates.h +214 -0
- sparse_matmul/compute/gru_gates_arm.h +288 -0
- sparse_matmul/compute/gru_gates_avx_fixed.h +348 -0
- sparse_matmul/compute/gru_gates_generic.h +97 -0
- sparse_matmul/compute/gru_gates_test.cc +164 -0
- sparse_matmul/compute/kernels_arm.h +0 -0
- sparse_matmul/compute/kernels_avx.h +601 -0
- sparse_matmul/compute/kernels_generic.h +273 -0
- sparse_matmul/compute/matmul.h +199 -0
- sparse_matmul/compute/matmul_fixed_avx2.cc +235 -0
- sparse_matmul/compute/matmul_fixed_avx2.h +49 -0
- sparse_matmul/compute/matmul_generic.cc +122 -0
- sparse_matmul/compute/matmul_generic.h +41 -0
- sparse_matmul/compute/thread_bounds.cc +106 -0
- sparse_matmul/compute/thread_bounds.h +74 -0
- sparse_matmul/layers/BUILD +146 -0
- sparse_matmul/layers/csr_blocksparse_matrix.h +835 -0
- sparse_matmul/layers/csrblocksparse_test.cc +977 -0
- sparse_matmul/layers/errno_mapping.cc +195 -0
- sparse_matmul/layers/errno_mapping.h +29 -0
- sparse_matmul/layers/masked_sparse_matrix.h +206 -0
- sparse_matmul/layers/read_array_ifstream.h +66 -0
- sparse_matmul/layers/sparse_linear_layer.h +365 -0
- sparse_matmul/layers/sparse_linear_layer_test.cc +187 -0
- sparse_matmul/layers/status_macros.h +34 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_weights.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_bias.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_mask.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_weights.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_bias.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_mask.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_weights.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_bias.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_mask.raw.gz +3 -0
- sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_weights.raw.gz +3 -0
BUILD
ADDED
@@ -0,0 +1,44 @@
# [internal] load cc_fuzz_target.bzl
# [internal] load cc_proto_library.bzl
# [internal] load android_cc_test:def.bzl

load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")

package(default_visibility = [":__subpackages__"])

licenses(["notice"])

# To run all cc_tests in this directory:
# bazel test //:all

# [internal] Command to run dsp_util_android_test.

# [internal] Command to run lyra_integration_android_test.

exports_files(
    srcs = [
        "wavegru_mod.cc",
    ],
)

pybind_extension(
    name = "wavegru_mod",  # This name is not actually created!
    srcs = ["wavegru_mod.cc"],
    deps = [
        "//sparse_matmul",
    ],
)

py_library(
    name = "wavegru_mod",
    data = [":wavegru_mod.so"],
)

py_binary(
    name = "wavegru",
    srcs = ["wavegru.py"],
    deps = [
        ":wavegru_mod"
    ],
)
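For orientation (not part of the commit): the pybind_extension target above builds a wavegru_mod.so shared object that CPython can import directly. Below is a minimal sketch of locating and importing it after running bazelisk build wavegru_mod -c opt --copt=-march=native; the bazel-bin output location is an assumption based on Bazel defaults, and the Space itself actually goes through a wavegru_cpp helper module that is not among the 50 files shown in this view.

    import sys
    from pathlib import Path

    # Bazel places built artifacts under bazel-bin/ by default (assumption:
    # default output layout, no --symlink_prefix override).
    bazel_bin = Path(__file__).resolve().parent / "bazel-bin"
    sys.path.append(str(bazel_bin))

    import wavegru_mod  # the pybind11 extension built from wavegru_mod.cc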
WORKSPACE
ADDED
@@ -0,0 +1,154 @@
########################
# Platform Independent #
########################

load("@bazel_tools//tools/build_defs/repo:git.bzl", "git_repository", "new_git_repository")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

# GoogleTest/GoogleMock framework.
git_repository(
    name = "com_google_googletest",
    remote = "https://github.com/google/googletest.git",
    tag = "release-1.10.0",
)

# Google benchmark.
http_archive(
    name = "com_github_google_benchmark",
    urls = ["https://github.com/google/benchmark/archive/bf585a2789e30585b4e3ce6baf11ef2750b54677.zip"],  # 2020-11-26T11:14:03Z
    strip_prefix = "benchmark-bf585a2789e30585b4e3ce6baf11ef2750b54677",
    sha256 = "2a778d821997df7d8646c9c59b8edb9a573a6e04c534c01892a40aa524a7b68c",
)

# proto_library, cc_proto_library, and java_proto_library rules implicitly
# depend on @com_google_protobuf for protoc and proto runtimes.
# This statement defines the @com_google_protobuf repo.
git_repository(
    name = "com_google_protobuf",
    remote = "https://github.com/protocolbuffers/protobuf.git",
    tag = "v3.15.4",
)

load("@com_google_protobuf//:protobuf_deps.bzl", "protobuf_deps")
protobuf_deps()

# Google Abseil Libs
git_repository(
    name = "com_google_absl",
    remote = "https://github.com/abseil/abseil-cpp.git",
    branch = "lts_2020_09_23",
)

# Filesystem
# The new_* prefix is used because it is not a bazel project and there is
# no BUILD file in that repo.
FILESYSTEM_BUILD = """
cc_library(
    name = "filesystem",
    hdrs = glob(["include/ghc/*"]),
    visibility = ["//visibility:public"],
)
"""

new_git_repository(
    name = "gulrak_filesystem",
    remote = "https://github.com/gulrak/filesystem.git",
    tag = "v1.3.6",
    build_file_content = FILESYSTEM_BUILD
)

# Audio DSP
git_repository(
    name = "com_google_audio_dsp",
    remote = "https://github.com/google/multichannel-audio-tools.git",
    # There are no tags for this repo, we are synced to bleeding edge.
    branch = "master",
    repo_mapping = {
        "@com_github_glog_glog" : "@com_google_glog"
    }
)


http_archive(
    name = "pybind11_bazel",
    strip_prefix = "pybind11_bazel-72cbbf1fbc830e487e3012862b7b720001b70672",
    urls = ["https://github.com/pybind/pybind11_bazel/archive/72cbbf1fbc830e487e3012862b7b720001b70672.zip"],
)
# We still require the pybind library.
http_archive(
    name = "pybind11",
    build_file = "@pybind11_bazel//:pybind11.BUILD",
    strip_prefix = "pybind11-2.9.0",
    urls = ["https://github.com/pybind/pybind11/archive/v2.9.0.tar.gz"],
)
load("@pybind11_bazel//:python_configure.bzl", "python_configure")
python_configure(name = "local_config_python")



# Transitive dependencies of Audio DSP.
http_archive(
    name = "eigen_archive",
    build_file = "eigen.BUILD",
    sha256 = "f3d69ac773ecaf3602cb940040390d4e71a501bb145ca9e01ce5464cf6d4eb68",
    strip_prefix = "eigen-eigen-049af2f56331",
    urls = [
        "http://mirror.tensorflow.org/bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
        "https://bitbucket.org/eigen/eigen/get/049af2f56331.tar.gz",
    ],
)

http_archive(
    name = "fft2d",
    build_file = "fft2d.BUILD",
    sha256 = "ada7e99087c4ed477bfdf11413f2ba8db8a840ba9bbf8ac94f4f3972e2a7cec9",
    urls = [
        "http://www.kurims.kyoto-u.ac.jp/~ooura/fft2d.tgz",
    ],
)

# Google logging
git_repository(
    name = "com_google_glog",
    remote = "https://github.com/google/glog.git",
    branch = "master"
)
# Dependency for glog
git_repository(
    name = "com_github_gflags_gflags",
    remote = "https://github.com/mchinen/gflags.git",
    branch = "android_linking_fix"
)

# Bazel/build rules

http_archive(
    name = "bazel_skylib",
    urls = [
        "https://mirror.bazel.build/github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
        "https://github.com/bazelbuild/bazel-skylib/releases/download/1.0.2/bazel-skylib-1.0.2.tar.gz",
    ],
    sha256 = "97e70364e9249702246c0e9444bccdc4b847bed1eb03c5a3ece4f83dfe6abc44",
)
load("@bazel_skylib//:workspace.bzl", "bazel_skylib_workspace")
bazel_skylib_workspace()

http_archive(
    name = "rules_android",
    sha256 = "cd06d15dd8bb59926e4d65f9003bfc20f9da4b2519985c27e190cddc8b7a7806",
    strip_prefix = "rules_android-0.1.1",
    urls = ["https://github.com/bazelbuild/rules_android/archive/v0.1.1.zip"],
)

# Google Maven Repository
GMAVEN_TAG = "20180625-1"

http_archive(
    name = "gmaven_rules",
    strip_prefix = "gmaven_rules-%s" % GMAVEN_TAG,
    url = "https://github.com/bazelbuild/gmaven_rules/archive/%s.tar.gz" % GMAVEN_TAG,
)

load("@gmaven_rules//:gmaven.bzl", "gmaven_rules")

gmaven_rules()
app.py
CHANGED
@@ -1,6 +1,14 @@
 import gradio as gr
+import os
+
+
+## build wavegru-cpp
+os.system("go get github.com/bazelbuild/bazelisk")
+os.system("bazelisk build wavegru_mod -c opt --copt=-march=native")
 
 from inference import load_tacotron_model, load_wavegru_net, text_to_mel, mel_to_wav
+from wavegru_cpp import load_wavegru_cpp, extract_weight_mask
+
 
 alphabet, tacotron_net, tacotron_config = load_tacotron_model(
     "./alphabet.txt", "./tacotron.toml", "./pretrained_model_ljs_500k.ckpt"
@@ -11,10 +19,13 @@ wavegru_config, wavegru_net = load_wavegru_net(
     "./wavegru.yaml", "./wavegru_vocoder_tpu_gta_preemphasis_pruning_v7_0040000.ckpt"
 )
 
+wave_cpp_weight_mask = extract_weight_mask(wavegru_net)
+wavecpp = load_wavegru_cpp(wave_cpp_weight_mask)
+
 
 def speak(text):
     mel = text_to_mel(tacotron_net, text, alphabet, tacotron_config)
-    y = mel_to_wav(wavegru_net, mel, wavegru_config)
+    y = mel_to_wav(wavegru_net, wavecpp, mel, wavegru_config)
     return 24_000, y
 
 
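A hedged usage sketch (not in the commit): after the module-level build and load steps above have run, speak returns a (sample_rate, waveform) pair. Outside of Gradio it could be exercised directly, for example by writing the result to disk with the soundfile package (an assumption; the Space simply returns the pair to a Gradio interface).

    import soundfile as sf

    sample_rate, waveform = speak("Hello from the fast C++ WaveGRU!")
    sf.write("output.wav", waveform, sample_rate)  # 24 kHz mono waveform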
inference.py
CHANGED
@@ -56,10 +56,10 @@ def load_wavegru_net(config_file, model_file):
     return config, net
 
 
-wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=
+wavegru_inference = pax.pure(lambda net, mel: net.inference(mel, no_gru=True))
 
 
-def mel_to_wav(net, mel, config):
+def mel_to_wav(net, netcpp, mel, config):
     """convert mel to wav"""
     if len(mel.shape) == 2:
         mel = mel[None]
@@ -69,10 +69,11 @@ def mel_to_wav(net, mel, config):
         [(0, 0), (pad, pad), (0, 0)],
         constant_values=np.log(config["mel_min"]),
     )
-
-
-
-    wav =
+    ft = wavegru_inference(net, mel)
+    ft = jax.device_get(ft[0])
+    wav = netcpp.inference(ft, 1.0)
+    wav = np.array(wav)
+    wav = librosa.mu_expand(wav - 127, mu=255)
     wav = librosa.effects.deemphasis(wav, coef=0.86)
     wav = wav * 2.0
     wav = wav / max(1.0, np.max(np.abs(wav)))
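For clarity on the new decoding path: the C++ WaveGRU apparently emits 8-bit mu-law sample codes (hence the wav - 127 recentering), and librosa.mu_expand(wav - 127, mu=255) maps them back to linear amplitudes before de-emphasis. A rough NumPy sketch of the standard inverse mu-law formula for a signal already scaled to [-1, 1] (librosa additionally handles rescaling the integer codes into that range):

    import numpy as np

    def inverse_mu_law(y, mu=255.0):
        """Inverse mu-law companding for y in [-1, 1]."""
        y = np.asarray(y, dtype=np.float32)
        return np.sign(y) * ((1.0 + mu) ** np.abs(y) - 1.0) / mu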
packages.txt
CHANGED
@@ -1 +1,2 @@
-libsndfile1-dev
+libsndfile1-dev
+golang-go
sparse_matmul/BUILD
ADDED
@@ -0,0 +1,22 @@
# [internal] load placeholder

licenses(["notice"])

cc_library(
    name = "sparse_matmul",
    hdrs = [
        "sparse_matmul.h",
    ],
    visibility = ["//visibility:public"],
    deps = [
        "//sparse_matmul/compute:gru_gates",
        "//sparse_matmul/layers:layer",
        "//sparse_matmul/layers:matrix",
        "//sparse_matmul/layers:utils",
        "//sparse_matmul/numerics:fast_transcendentals",
        "//sparse_matmul/numerics:types",
        "//sparse_matmul/os:coop_threads",
        "//sparse_matmul/vector:cache_aligned_vector",
    ],  # internal :sparse_matmul deps placeholder
)
sparse_matmul/compute/BUILD
ADDED
@@ -0,0 +1,88 @@
# Low-level computation code, including generic and architecture-specific
# variants.

licenses(["notice"])

cc_library(
    name = "gru_gates",
    srcs = [
        "ar_inputs.h",
        "gru_gates_arm.h",
        "gru_gates_avx_fixed.h",
        "gru_gates_generic.h",
    ],
    hdrs = ["gru_gates.h"],
    visibility = [
        "//visibility:public",
    ],
    deps = [
        ":matmul",
        "//sparse_matmul/numerics:fast_transcendentals",
        "//sparse_matmul/numerics:types",
        "//sparse_matmul/vector:cache_aligned_vector",
    ],
)

cc_library(
    name = "kernels",
    srcs = [
        "kernels_arm.h",
        "kernels_avx.h",
    ],
    hdrs = [
        "kernels_generic.h",
    ],
    visibility = [
        "//sparse_matmul:__subpackages__",
    ],
    deps = [
        "//sparse_matmul/numerics:fast_transcendentals",
        "//sparse_matmul/numerics:types",
    ],
)

cc_library(
    name = "matmul",
    srcs = [
        "matmul_fixed_avx2.cc",
        "matmul_fixed_avx2.h",
        "matmul_generic.cc",
        "matmul_generic.h",
    ],
    hdrs = [
        "matmul.h",
    ],
    visibility = [
        "//sparse_matmul:__subpackages__",
    ],
    deps = [
        "//sparse_matmul/numerics:types",
        "@com_google_absl//absl/time",
    ],
)

cc_library(
    name = "thread_bounds",
    srcs = ["thread_bounds.cc"],
    hdrs = ["thread_bounds.h"],
    visibility = [
        "//sparse_matmul:__subpackages__",
    ],
    deps = [
        "@com_google_glog//:glog",
    ],
)

cc_test(
    name = "gru_gates_test",
    size = "small",
    srcs = [
        "gru_gates_test.cc",
    ],
    deps = [
        ":gru_gates",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
    ],
)
sparse_matmul/compute/ar_inputs.h
ADDED
@@ -0,0 +1,37 @@
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_

namespace csrblocksparse {

// Possible numbers of Autoregressive inputs.
// TODO(b/188702959): Generalize to any non-negative integer value?
enum class ARInputsMode {
  // There are no autoregressive inputs. Inputs to the GRU gates are strictly
  // from the gate-recurrent matmul and other unrelated inputs.
  k0ARInputs,
  // Two autoregressive inputs, such as coarse and fine for WaveRNN.
  k2ARInputs,
  // Three autoregressive inputs, such as prev coarse and fine plus current
  // coarse for WaveRNN.
  k3ARInputs,
};

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_AR_INPUTS_H_
sparse_matmul/compute/gru_gates.h
ADDED
@@ -0,0 +1,214 @@
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_

#include <cstdint>
#include <vector>

// IWYU pragma: begin_exports
#include "sparse_matmul/compute/ar_inputs.h"
#include "sparse_matmul/compute/gru_gates_arm.h"
#include "sparse_matmul/compute/gru_gates_avx_fixed.h"
#include "sparse_matmul/compute/gru_gates_generic.h"
#include "sparse_matmul/compute/matmul.h"
#include "sparse_matmul/numerics/fixed_types.h"
#include "sparse_matmul/numerics/type_utils.h"
#include "sparse_matmul/vector/cache_aligned_vector.h"
// IWYU pragma: end_exports

namespace csrblocksparse {

// The master template is really a catch-all for the unimplemented cases to
// run the generics.
template <typename GRUStateType, typename InputType, typename SampleType = void>
class GruGates : public MatmulBase {
 public:
  using SampleWeightType = float;
  static constexpr int kSIMDWidth = kGenericSIMDWidth;

  // Generic GRU function covers all uses for WaveRNN-like architectures and
  // conditioning.
  // Controlled by template parameters thus:
  // - |kInputsMode| == |k0ARInputs|: There are no autoregressive inputs so
  //   |ar_sample0|, |ar_sample1|, |ar_sample2|, |ar_01_weights|,
  //   |ar_2_weights| are ignored.
  // - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied
  //   by |ar_01_weights| and added to the (conditioning) input.
  // - |kInputsMode| == |k3ARInputs|: |ar_sample2| is multiplied by
  //   |ar_2_weights| and added to the other two |ar_inputs| (and added to the
  //   conditioning input).
  // - If |kSplitGates| is true: The |*gru_recurrent_other_ptr| is secondary
  //   recurrent input that must be added to |*gru_recurrent_ptr|.
  // - |num_replicas| determines the number of duplicates of the output to be
  //   written, separated by |replica_stride|.
  // - |start|, |end| are |rows| in [0, |state_size|] to be processed by this
  //   thread.
  //
  // Previous state is read from |*gru_state_ptr| and the new state is written
  // to *(|gru_state_ptr| + i * |replica_stride| for i in [0, |num_replicas|)).
  template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
            bool kSplitGates = false>
  void GruWithARInput(int start, int end, int state_size,
                      const InputType* gru_recurrent_ptr,
                      const InputType* input_ptr, GRUStateType* gru_state_ptr,
                      const SampleType* ar_sample0 = nullptr,
                      const SampleType* ar_sample1 = nullptr,
                      const SampleWeightType* ar_01_weights = nullptr,
                      int num_replicas = 1, int replica_stride = 0,
                      const SampleType* ar_sample2 = nullptr,
                      const SampleWeightType* ar_2_weights = nullptr,
                      const InputType* gru_recurrent_other_ptr = nullptr) {
    CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica";
    GoThroughGates<GRUStateType, InputType, SampleWeightType, SampleType,
                   kInputsMode, kSplitGates>(
        start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr,
        input_ptr, gru_state_ptr, ar_2_weights, state_size, ar_sample0,
        ar_sample1, ar_sample2);
  }

  // No AR inputs, no split gates, no batching, no replicated outputs.
  // TODO(b/188702959): Redirect conditioning GRU here, removing code from
  // gru_layer.h.
  // Copy to specializations.
  void PlainGru(int start, int end, int state_size,
                const InputType* gru_recurrent_ptr, const InputType* input_ptr,
                GRUStateType* gru_state_ptr) {
    GruWithARInput<ARInputsMode::k0ARInputs>(
        start, end, state_size, gru_recurrent_ptr, input_ptr, gru_state_ptr);
  }
};

#if defined __ARM_NEON || defined __aarch64__
// Partial specialization for float.
template <>
class GruGates<float, float, float> : public MatmulBase {
 public:
  static constexpr int kSIMDWidth = kNeonSIMDWidth;

  // Generic GRU function covers all uses for WaveRNN-like architectures and
  // conditioning.
  template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
            bool kSplitGates = false>
  void GruWithARInput(int start, int end, int state_size,
                      const float* gru_recurrent_data, const float* input_data,
                      float* gru_state_data, const float* ar_sample0 = nullptr,
                      const float* ar_sample1 = nullptr,
                      const float* ar_01_weights = nullptr,
                      int num_replicas = 1, int replica_stride = 0,
                      const float* ar_sample2 = nullptr,
                      const float* ar_2_weights = nullptr,
                      const float* gru_recurrent_other_data = nullptr) {
    DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica";
    GoThroughGatesFloat<kInputsMode, kSplitGates>(
        start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data,
        input_data, gru_state_data, ar_2_weights, state_size, ar_sample0,
        ar_sample1, ar_sample2);
  }
};
#endif  // defined __ARM_NEON || defined __aarch64__

// Partial specialization for fixed types. The sample weights are always float
// whatever the fixed type of the other weights.
template <int kGRUStateBits, int kInputBits, int kSampleBits>
class GruGates<fixed16<kGRUStateBits>, fixed32<kInputBits>,
               fixed16<kSampleBits>> : public MatmulBase {
 public:
#if defined __ARM_NEON || defined __aarch64__
  static constexpr int kSIMDWidth = kNeonSIMDWidth;
#elif defined __AVX2__
  static constexpr int kSIMDWidth = kAVX2SIMDWidth * 2;
#else   // Generic case.
  static constexpr int kSIMDWidth = kGenericSIMDWidth;
#endif  // __ARM_NEON || defined __aarch64__ / __AVX2__

  using GRUStateType = fixed16<kGRUStateBits>;
  using InputType = fixed32<kInputBits>;
  using SampleType = fixed16<kSampleBits>;
  using SampleWeightType = float;
  static constexpr int kInputMantissaBits = InputType::kMantissaBits;
  static constexpr int kSampleMantissaBits = SampleType::kMantissaBits;
  static constexpr int kStateMantissaBits = GRUStateType::kMantissaBits;
  // Generic GRU function covers all uses for WaveRNN-like architectures and
  // conditioning.
  template <ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
            bool kSplitGates = false>
  void GruWithARInput(int start, int end, int state_size,
                      const InputType* gru_recurrent_data,
                      const InputType* input_data, GRUStateType* gru_state_data,
                      const SampleType* ar_sample0 = nullptr,
                      const SampleType* ar_sample1 = nullptr,
                      const SampleWeightType* ar_01_weights = nullptr,
                      int num_replicas = 1, int replica_stride = 0,
                      const SampleType* ar_sample2 = nullptr,
                      const SampleWeightType* ar_2_weights = nullptr,
                      const InputType* gru_recurrent_other_data = nullptr) {
#if defined __ARM_NEON || defined __aarch64__ || defined __AVX2__
    const int32_t* gru_recurrent_ptr =
        reinterpret_cast<const int32_t*>(gru_recurrent_data);
    const int32_t* gru_recurrent_other_ptr =
        reinterpret_cast<const int32_t*>(gru_recurrent_other_data);
    const int32_t* input_ptr = reinterpret_cast<const int32_t*>(input_data);
    int16_t* gru_state_ptr = reinterpret_cast<int16_t*>(gru_state_data);
#if defined __AVX2__
    // The samples are fixed16, but we scale them up here and convert to float
    // so that the product with the QR weights is always on the same scale as
    // InputType, so we don't have to do any more scaling inside.
    const float sample_factor = static_cast<float>(1 << kInputMantissaBits);
#else
    const float sample_factor = 1.0f;
#endif
    // AR sample 0 and 1 are packed into a pair because the QR weights are
    // formatted with the weights interleaved for sample 0 and 1.
    std::pair<float, float> ar_sample01;
    float ar_sample2_float = 0.0f;
    if (kInputsMode == ARInputsMode::k2ARInputs ||
        kInputsMode == ARInputsMode::k3ARInputs) {
      ar_sample01 = {static_cast<float>(*ar_sample0) * sample_factor,
                     static_cast<float>(*ar_sample1) * sample_factor};
      if (kInputsMode == ARInputsMode::k3ARInputs) {
        ar_sample2_float = static_cast<float>(*ar_sample2) * sample_factor;
      }
    }
#if defined __AVX2__
    CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
    GruGatesAVXFixed<kInputMantissaBits, kStateMantissaBits, kInputsMode,
                     kSplitGates>(
        start, end, state_size, gru_recurrent_ptr, input_ptr, &ar_sample01,
        ar_01_weights, num_replicas, replica_stride, &ar_sample2_float,
        ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
#else   // ARM.
    DCHECK_EQ(num_replicas, 1) << "ARM code should always have 1 replica";
    GoThroughGatesFixed<GRUStateType, InputType, kInputsMode, kSplitGates>(
        start, end, ar_01_weights, gru_recurrent_ptr, gru_recurrent_other_ptr,
        input_ptr, gru_state_ptr, ar_2_weights, state_size, &ar_sample01,
        &ar_sample2_float);
#endif  // __AVX2__ / ARM.
#else   // Generic case.
    CHECK_EQ(num_replicas, 1) << "Generic code should always have 1 replica";
    GoThroughGates<GRUStateType, InputType, SampleWeightType, SampleType,
                   kInputsMode, kSplitGates>(
        start, end, ar_01_weights, gru_recurrent_data, gru_recurrent_other_data,
        input_data, gru_state_data, ar_2_weights, state_size, ar_sample0,
        ar_sample1, ar_sample2);
#endif  // __ARM_NEON || defined __aarch64__ / __AVX2__
  }
};

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_H_
sparse_matmul/compute/gru_gates_arm.h
ADDED
@@ -0,0 +1,288 @@
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_

#if defined __ARM_NEON || defined __aarch64__
#include <arm_neon.h>
#endif
#include <cstdint>

#include "sparse_matmul/compute/ar_inputs.h"
#include "sparse_matmul/numerics/fast_transcendentals.h"

namespace csrblocksparse {

static constexpr int kNeonSIMDWidth = 4;

// ------ Scalar calculation --------
// See "Efficient Neural Audio Synthesis" for a description of the calculation.
// https://arxiv.org/abs/1802.08435
//
// NOTE:
// |sample| = (|coarse_at_sminus1|, |fine_at_sminus1|,
//             |coarse_at_sminus1|, |fine_at_sminus1|)
// |w_sample| = (|coarse_at_s|, |coarse_at_s|, |coarse_at_s|, |coarse_at_s|)
//
// CHEATSHEET:
// vld1q_f32 = load 4 32-bit floats
// vmulq_f32(a, b) : return a * b;
// vaddq_f32(a, b) : return a + b;
// vmlaq_f32(c, a, b) : return c + a * b;
// vpaddq_f32(a, b) : return (a0 + a1, a2 + a3, b0 + b1, b2 + b3)
// vsubq_f32(a, b) : return a - b;
// vst1q_f32 = store 4 32-bit floats
#if defined __ARM_NEON || defined __aarch64__

#if !defined __aarch64__
// Backport of vpaddq_f32 to ARM32.
inline float32x4_t vpaddq_f32(float32x4_t a, float32x4_t b) {
  float32x2_t a10 = vget_low_f32(a);
  float32x2_t a32 = vget_high_f32(a);
  float32x2_t b10 = vget_low_f32(b);
  float32x2_t b32 = vget_high_f32(b);
  return vcombine_f32(vpadd_f32(a10, a32), vpadd_f32(b10, b32));
}
#endif

template <ARInputsMode kInputsMode, bool SplitGates>
void GoThroughGatesFloat(int start, int end, const float* qr_ptr,
                         const float* gru_gates_ptr,
                         const float* gru_gates_other_ptr,
                         const float* conditioning_ptr, float* gru_h_ptr,
                         const float* w_hat, int proj_size,
                         const float* coarse_at_sminus1,
                         const float* fine_at_sminus1,
                         const float* coarse_at_s) {
  // Increment all the pointers to save on pointer arithmetic in the loop.
  conditioning_ptr += start;
  gru_h_ptr += start;
  gru_gates_ptr += start;
  if (SplitGates) {
    DCHECK_NE(gru_gates_other_ptr, nullptr);
    gru_gates_other_ptr += start;
  }
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    DCHECK_NE(qr_ptr, nullptr);
    qr_ptr += 2 * start;
    DCHECK_NE(coarse_at_sminus1, nullptr);
    DCHECK_NE(fine_at_sminus1, nullptr);
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      DCHECK_NE(w_hat, nullptr);
      DCHECK_NE(coarse_at_s, nullptr);
      w_hat += start;
    }
  }
  for (int i = start; i < end; i += kNeonSIMDWidth) {
    float32x4_t reset = vld1q_f32(gru_gates_ptr);
    float32x4_t update = vld1q_f32(gru_gates_ptr + proj_size);
    float32x4_t cell = vld1q_f32(gru_gates_ptr + 2 * proj_size);
    float32x4_t qr_cell;
    if (SplitGates) {
      reset = vaddq_f32(reset, vld1q_f32(gru_gates_other_ptr));
      update = vaddq_f32(update, vld1q_f32(gru_gates_other_ptr + proj_size));
      cell = vaddq_f32(cell, vld1q_f32(gru_gates_other_ptr + 2 * proj_size));
    }
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      // Setup the sample vector.
      float32x4_t sample = vdupq_n_f32(*coarse_at_sminus1);
      sample = vsetq_lane_f32(*fine_at_sminus1, sample, 1);
      sample = vsetq_lane_f32(*fine_at_sminus1, sample, 3);

      // All auto types are float32x4_t, auto used to fit statements on one line
      // for readability. Do two rows of QR at once.
      auto qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample);
      auto qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample);
      auto qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);

      auto qr_update_0 = vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample);
      auto qr_update_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample);
      auto qr_update = vpaddq_f32(qr_update_0, qr_update_1);

      auto qr_cell_0 = vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample);
      auto qr_cell_1 = vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample);
      qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);

      if (kInputsMode == ARInputsMode::k3ARInputs) {
        float32x4_t w_sample = vdupq_n_f32(*coarse_at_s);
        qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
        qr_update =
            vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
        qr_cell =
            vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
      }
      reset = vaddq_f32(reset, qr_reset);
      update = vaddq_f32(update, qr_update);
    }
    auto reset_conditioning = vld1q_f32(conditioning_ptr);
    auto update_conditioning = vld1q_f32(conditioning_ptr + proj_size);
    auto cell_conditioning = vld1q_f32(conditioning_ptr + 2 * proj_size);

    reset = fast_sigmoid(vaddq_f32(reset, reset_conditioning));
    update = fast_sigmoid(vaddq_f32(update, update_conditioning));
    if (kInputsMode == ARInputsMode::k0ARInputs) {
      cell = vmulq_f32(reset, cell);
    } else {
      cell = vmlaq_f32(qr_cell, reset, cell);
    }
    auto hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));

    auto prev_h = vld1q_f32(gru_h_ptr);
    auto diff = vsubq_f32(prev_h, hbar);
    auto new_h = vmlaq_f32(hbar, diff, update);

    vst1q_f32(gru_h_ptr, new_h);
    // Increment all the pointers.
    conditioning_ptr += kNeonSIMDWidth;
    gru_h_ptr += kNeonSIMDWidth;
    gru_gates_ptr += kNeonSIMDWidth;
    if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      qr_ptr += 2 * kNeonSIMDWidth;
      if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
    }
  }
}

// This version should only be used if all of the 32-bit fixed point
// representations have the same number of mantissa bits.
// |ar_at_sminus1| packs sample 0 and 1 into a pair because the QR weights are
// formatted with the weights interleaved for sample 0 and 1. The two samples
// represent coarse and fine for WaveRNN.
template <typename GRUStateType, typename GRUMatMulOutType,
          ARInputsMode kInputsMode, bool SplitGates>
void GoThroughGatesFixed(int start, int end, const float* qr_ptr,
                         const int32_t* gru_gates_ptr,
                         const int32_t* gru_gates_other_ptr,
                         const int32_t* conditioning_ptr, int16_t* gru_h_ptr,
                         const float* w_hat, int proj_size,
                         const std::pair<float, float>* ar_at_sminus1,
                         const float* coarse_at_s) {
  // Increment all the pointers to save on pointer arithmetic in the loop.
  conditioning_ptr += start;
  gru_h_ptr += start;
  gru_gates_ptr += start;
  if (SplitGates) {
    DCHECK_NE(gru_gates_other_ptr, nullptr);
    gru_gates_other_ptr += start;
  }
  float32x4_t sample01;
  float32x4_t w_sample;
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    DCHECK_NE(qr_ptr, nullptr);
    qr_ptr += 2 * start;
    DCHECK_NE(ar_at_sminus1, nullptr);
    sample01 = vdupq_n_f32(ar_at_sminus1->first);
    sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 1);
    sample01 = vsetq_lane_f32(ar_at_sminus1->second, sample01, 3);
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      DCHECK_NE(w_hat, nullptr);
      DCHECK_NE(coarse_at_s, nullptr);
      w_hat += start;
      w_sample = vdupq_n_f32(*coarse_at_s);
    }
  }
  for (int i = start; i < end; i += kNeonSIMDWidth) {
    auto reset = vld1q_s32(gru_gates_ptr);
    auto update = vld1q_s32(gru_gates_ptr + proj_size);
    // vcvtq_n_f32_s32 = convert 32-bit fixed point to fp32
    auto cell_int = vld1q_s32(gru_gates_ptr + 2 * proj_size);
    if (SplitGates) {
      reset = vaddq_s32(reset, vld1q_s32(gru_gates_other_ptr));
      update = vaddq_s32(update, vld1q_s32(gru_gates_other_ptr + proj_size));
      cell_int =
          vaddq_s32(cell_int, vld1q_s32(gru_gates_other_ptr + 2 * proj_size));
    }
    float32x4_t cell =
        vcvtq_n_f32_s32(cell_int, GRUMatMulOutType::kMantissaBits);
    float32x4_t qr_cell;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      // Do two rows of QR at once.
      float32x4_t qr_reset_0 = vmulq_f32(vld1q_f32(qr_ptr), sample01);
      float32x4_t qr_reset_1 = vmulq_f32(vld1q_f32(qr_ptr + 4), sample01);
      float32x4_t qr_reset = vpaddq_f32(qr_reset_0, qr_reset_1);

      float32x4_t qr_update_0 =
          vmulq_f32(vld1q_f32(qr_ptr + 2 * proj_size), sample01);
      float32x4_t qr_update_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 2 * proj_size), sample01);
      float32x4_t qr_update = vpaddq_f32(qr_update_0, qr_update_1);

      float32x4_t qr_cell_0 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 * proj_size), sample01);
      float32x4_t qr_cell_1 =
          vmulq_f32(vld1q_f32(qr_ptr + 4 + 4 * proj_size), sample01);
      qr_cell = vpaddq_f32(qr_cell_0, qr_cell_1);
      if (kInputsMode == ARInputsMode::k3ARInputs) {
        float32x4_t w_sample = vdupq_n_f32(*coarse_at_s);
        qr_reset = vmlaq_f32(qr_reset, vld1q_f32(w_hat), w_sample);
        qr_update =
            vmlaq_f32(qr_update, vld1q_f32(w_hat + proj_size), w_sample);
        qr_cell =
            vmlaq_f32(qr_cell, vld1q_f32(w_hat + 2 * proj_size), w_sample);
      }
      reset = vaddq_s32(
          reset, vcvtq_n_s32_f32(qr_reset, GRUMatMulOutType::kMantissaBits));
      update = vaddq_s32(
          update, vcvtq_n_s32_f32(qr_update, GRUMatMulOutType::kMantissaBits));
    }

    auto reset_conditioning = vld1q_s32(conditioning_ptr);
    auto update_conditioning = vld1q_s32(conditioning_ptr + proj_size);
    float32x4_t cell_conditioning =
        vcvtq_n_f32_s32(vld1q_s32(conditioning_ptr + 2 * proj_size),
                        GRUMatMulOutType::kMantissaBits);

    float32x4_t reset_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
        vaddq_s32(reset, reset_conditioning));
    float32x4_t update_f32 = fast_sigmoid<GRUMatMulOutType::kExponentBits>(
        vaddq_s32(update, update_conditioning));
    if (kInputsMode == ARInputsMode::k0ARInputs) {
      cell = vmulq_f32(reset_f32, cell);
    } else {
      cell = vmlaq_f32(qr_cell, reset_f32, cell);
    }
    float32x4_t hbar = fast_tanh(vaddq_f32(cell, cell_conditioning));

    float32x4_t prev_h = vcvtq_n_f32_s32(vmovl_s16(vld1_s16(gru_h_ptr)),
                                         GRUStateType::kMantissaBits);
    float32x4_t diff = vsubq_f32(prev_h, hbar);
    float32x4_t new_h = vmlaq_f32(hbar, diff, update_f32);

    // vcvtq_n_s32_f32 = convert fp32 to signed 32-bit fixed point
    // vqrshrn_n_s32 = saturating, rounding, narrowing right shift - used to
    // convert a 32-bit fixed point value to a 16-bit fixed point value
    vst1_s16(gru_h_ptr,
             vqrshrn_n_s32(
                 vcvtq_n_s32_f32(new_h, GRUStateType::kMantissaBits + 16), 16));
    // Increment all the pointers.
    conditioning_ptr += kNeonSIMDWidth;
    gru_h_ptr += kNeonSIMDWidth;
    gru_gates_ptr += kNeonSIMDWidth;
    if (SplitGates) gru_gates_other_ptr += kNeonSIMDWidth;
    if (kInputsMode != ARInputsMode::k0ARInputs) {
      qr_ptr += 2 * kNeonSIMDWidth;
      if (kInputsMode == ARInputsMode::k3ARInputs) w_hat += kNeonSIMDWidth;
    }
  }
}
#endif  // defined __ARM_NEON || defined __aarch64__

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_ARM_H_
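For reference, the per-element recurrence that GoThroughGatesFloat and GoThroughGatesFixed vectorize (read off the code above; r, u, c denote the reset, update and cell contributions from the recurrent matmul plus the optional autoregressive QR/w-hat terms, and the three conditioning slices come from conditioning_ptr):

    r = \sigma(r_{matmul} + r_{cond})
    u = \sigma(u_{matmul} + u_{cond})
    \bar{h} = \tanh(r \odot c_{matmul} + c_{AR} + c_{cond})
    h_{new} = \bar{h} + u \odot (h_{prev} - \bar{h})

Here \sigma is the logistic sigmoid (fast_sigmoid) and \odot is element-wise multiplication; in the k0ARInputs case the c_{AR} term is absent and the cell reduces to r \odot c_{matmul}.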
sparse_matmul/compute/gru_gates_avx_fixed.h
ADDED
@@ -0,0 +1,348 @@
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_

#include <cstdint>
#if defined __AVX2__
#include <immintrin.h>
#endif
#include <vector>

#include "sparse_matmul/compute/ar_inputs.h"
#include "sparse_matmul/numerics/fast_transcendentals.h"

namespace csrblocksparse {

#if defined __AVX2__

constexpr int kAVX2SIMDWidth = 8;

// Loads 8x fixed32 from |ptr0| and adds to |input|.
// If |kTwoInputs|, also loads from |ptr1| and adds that as well.
// Returns the 2 or 3-way sum.
template <bool kTwoInputs>
inline __m256i LoadAndAddFixed32(const int32_t* ptr0, const int32_t* ptr1,
                                 const __m256i& input) {
  __m256i data0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr0));
  if (kTwoInputs) {
    __m256i data1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr1));
    data0 = _mm256_add_epi32(data0, data1);
  }
  return _mm256_add_epi32(data0, input);
}

// Loads 8x fixed32 from ptr0.
// If |kTwoInputs|, also loads from |ptr1| and adds.
// Multiplies the loaded values by the factor and adds to |input|, which also
// is converted to float.
// Returns the sum.
template <bool kTwoInputs>
inline __m256 LoadMultiplyAddToFloat(const int32_t* ptr0, const int32_t* ptr1,
                                     const __m256& float_factor,
                                     const __m256& input) {
  __m256i data0 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr0));
  if (kTwoInputs) {
    __m256i data1 = _mm256_load_si256(reinterpret_cast<const __m256i*>(ptr1));
    data0 = _mm256_add_epi32(data0, data1);
  }
  __m256 float_result = _mm256_cvtepi32_ps(data0);
  float_result = _mm256_mul_ps(float_result, float_factor);
  return _mm256_add_ps(float_result, input);
}

// Loads 16x float in 2x 8x registers from |ptr0_1| and multiplies by
// |input_pairs|, likewise formatted as 8x floats, alternating between the two
// AR inputs and sums each pair of results, making 8x float results.
// If |kThreeInputs|, also loads 8x float from |ptr2| and multiplies by
// |third_input|, which must be formatted as 8x float. The second product is
// added to the previous result.
// Returns the sum added to |accumulator|.
template <bool kThreeInputs>
inline __m256 MultiplyAddFloat(const __m256& input_pairs,
                               const __m256& third_input, const float* ptr0_1,
                               const float* ptr2, const __m256& accumulator) {
  __m256 data_pair0 = _mm256_load_ps(ptr0_1);
  __m256 data_pair1 = _mm256_load_ps(ptr0_1 + 8);
  data_pair0 = _mm256_mul_ps(data_pair0, input_pairs);
  data_pair1 = _mm256_mul_ps(data_pair1, input_pairs);
  data_pair0 = _mm256_hadd_ps(data_pair0, data_pair1);
  // Swap the middle 2 64 bit pairs to correct the hadd result.
  data_pair0 = _mm256_permute4x64_pd((__m256d)data_pair0, 0xd8);
  if (kThreeInputs) {
    // Load 256 bits (8 x float) of data, then multiply-accumulate.
    data_pair1 = _mm256_load_ps(ptr2);
    data_pair1 = _mm256_mul_ps(data_pair1, third_input);
    data_pair0 = _mm256_add_ps(data_pair0, data_pair1);
  }
  // Add conditioning.
  return _mm256_add_ps(data_pair0, accumulator);
}

// Processes the tanh and the final combination, returns the new GRU state.
template <int kInputMantissaBits, int kStateMantissaBits, bool kSplitGates>
inline __m256i GRUComputeState(const __m256& cell0, const __m256& cell1,
                               const __m256& reset0, const __m256& reset1,
                               const __m256& update0, const __m256& update1,
                               const int32_t* gate_ptr,
                               const int32_t* gate_other_ptr,
                               const void* gru_h_ptr) {
  // Multiply the cell gru output and the reset.
  __m256 float_gru0 = LoadMultiplyAddToFloat<kSplitGates>(
      gate_ptr, gate_other_ptr, reset0, cell0);
  __m256 float_gru1 = LoadMultiplyAddToFloat<kSplitGates>(
      gate_ptr + kAVX2SIMDWidth, gate_other_ptr + kAVX2SIMDWidth, reset1,
      cell1);
  // Compute tanh on the result.
  __m256 hbar0, hbar1;
  float_tanh_float<kInputMantissaBits, TM_ORDER4_FLOAT>(float_gru0, float_gru1,
                                                        hbar0, hbar1);
  // Load the 16-bit previous gru state and update.
  __m256i gru = _mm256_load_si256(reinterpret_cast<__m256i const*>(gru_h_ptr));
  __m256 state_factor =
      _mm256_set1_ps(1.0f / (static_cast<float>(1 << kStateMantissaBits)));
  float_gru0 =
      _mm256_cvtepi32_ps(_mm256_cvtepi16_epi32(_mm256_castsi256_si128(gru)));
  float_gru1 = _mm256_cvtepi32_ps(
      _mm256_cvtepi16_epi32(_mm256_extractf128_si256(gru, 1)));
  float_gru0 = _mm256_mul_ps(float_gru0, state_factor);
  float_gru1 = _mm256_mul_ps(float_gru1, state_factor);
  float_gru0 = _mm256_sub_ps(float_gru0, hbar0);
  float_gru1 = _mm256_sub_ps(float_gru1, hbar1);
  float_gru0 = _mm256_mul_ps(float_gru0, update0);
  float_gru1 = _mm256_mul_ps(float_gru1, update1);
  state_factor = _mm256_set1_ps(static_cast<float>(1 << kStateMantissaBits));
  float_gru0 = _mm256_add_ps(float_gru0, hbar0);
  float_gru1 = _mm256_add_ps(float_gru1, hbar1);
  float_gru0 = _mm256_mul_ps(float_gru0, state_factor);
  float_gru1 = _mm256_mul_ps(float_gru1, state_factor);
  return PackFloatsToFixed16(float_gru0, float_gru1);
}

// According to |kInputsMode|, processes 0, 2 or 3 autoregressive inputs and
// combines with |input| and |gates*|.
// With 2 AR inputs, loads 8x pairs of float from |pair_weights| and multiplies
// by |paired_ar|, likewise formatted as 8x float, but scaled such that the
// product with pair_weights is on the same scale as |*input| and |*gates0|,
// and sums each pair result, making 8x float results.
// If 3 AR inputs, also loads 8x float from |third_weights| and multiplies by
// |third_ar|, which must be formatted as 8x scaled floats. The second product
// is added to the previous result.
// Inputs, 8x fixed32 are loaded from |input|, and added to the total.
// Finally 8x fixed32 from |gates0| (and |gates1| if |kTwoGates|) are added as
// well.
// Returns the total sum as a float, but on the scale of |*input|.
template <bool kTwoGates, ARInputsMode kInputsMode>
inline __m256 GruInput32ToFloat(const __m256& paired_ar,
                                const __m256& third_ar,
                                const float* pair_weights,
                                const float* third_weights,
                                const int32_t* gates0, const int32_t* gates1,
                                const int32_t* input) {
  __m256i data32 = _mm256_load_si256(reinterpret_cast<__m256i const*>(input));
  data32 = LoadAndAddFixed32<kTwoGates>(gates0, gates1, data32);
  __m256 float_data = _mm256_cvtepi32_ps(data32);
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    float_data = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
        paired_ar, third_ar, pair_weights, third_weights, float_data);
  }
  return float_data;
}

// Generic GRU gates function controlled by template parameters thus:
// - |kInputBits|: the mantissa bits in |*input_ptr|, |*gru_recurrent_ptr|.
// - |kStateBits|: the mantissa_bits in |*gru_state_ptr|.
// - |kInputsMode == |k0ARInputs|: There are no autoregressive inputs so
//   |ar_sample, |ar_sample1|, |ar_sample2|, |ar_01_weights|, |ar_2_weights| are
//   ignored.
// - |kInputsMode| == |k2ARInputs|: |ar_sample0|, |ar_sample1| are multiplied by
//   |ar_01_weights| and added to the (conditioning) input.
// - |kInputsMode| == |k3ARInputs|: |ar_sample2| is multiplied by |ar_2_weights|
//   and added to the other two AR inputs (and added to the conditioning input).
// - |kReplicas| determines the number of duplicates of the output to be
//   written, separated by |replica_stride|. If zero, then the number of
//   replicas is variable and taken from the |replicas| argument.
// - If |kSplitGates| is true: The |*gru_recurrent_other_ptr| is secondary
//   recurrent input that must be added to |*gru_recurrent_ptr|.
// - |start|, |end| are |rows| in [0, |state_size|] to be processed by this
//   thread.
//
// Previous state is read from |*gru_state_ptr| and the new state is written to
// *(|gru_state_ptr| + i * |replica_stride| for i in [0, |kReplicas|]).
template <int kInputBits, int kStateBits,
          ARInputsMode kInputsMode = ARInputsMode::k0ARInputs,
          int kReplicas = 1, bool kSplitGates = false>
inline void GruGatesTemplate(
    int start, int end, int state_size, int replicas, int replica_stride,
    const int32_t* gru_recurrent_ptr, const int32_t* input_ptr,
    const std::pair<float, float>* ar_sample01, const float* ar_01_weights,
    const float* ar_sample2, const float* ar_2_weights,
    const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) {
  constexpr int kQRIncrement = kAVX2SIMDWidth;
  // Increment all the pointers to save on pointer arithmetic in the loop.
  input_ptr += start;
  gru_state_ptr += start;
  gru_recurrent_ptr += start;
  if (kSplitGates) gru_recurrent_other_ptr += start;
  __m256 ar_2_inputs, ar_3rd_input;
  if (kInputsMode != ARInputsMode::k0ARInputs) {
    ar_01_weights += 2 * start;
    ar_2_inputs = _mm256_castsi256_ps(
        _mm256_set1_epi64x(*reinterpret_cast<const int64_t*>(ar_sample01)));
    if (kInputsMode == ARInputsMode::k3ARInputs) {
      ar_2_weights += start;
      ar_3rd_input = _mm256_set1_ps(*ar_sample2);
    } else {
      ar_3rd_input = {};
    }
  } else {
    ar_2_inputs = {};
    ar_3rd_input = {};
  }
  // The transcendentals handle 2x registers of data at once, so we have to do
  // everything in duplicate.
  for (int i = start; i < end; i += kQRIncrement * 2) {
    // Load 8 pairs of fixed16s for each of reset, update and cell.
    __m256 reset0 = GruInput32ToFloat<kSplitGates, kInputsMode>(
        ar_2_inputs, ar_3rd_input, ar_01_weights, ar_2_weights,
        gru_recurrent_ptr, gru_recurrent_other_ptr, input_ptr);
    __m256 reset1 = GruInput32ToFloat<kSplitGates, kInputsMode>(
        ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * kQRIncrement,
        ar_2_weights + kQRIncrement, gru_recurrent_ptr + kAVX2SIMDWidth,
        gru_recurrent_other_ptr + kAVX2SIMDWidth, input_ptr + kAVX2SIMDWidth);
    float_sigmoid_float<kInputBits>(reset0, reset1);
    __m256 update0 = GruInput32ToFloat<kSplitGates, kInputsMode>(
        ar_2_inputs, ar_3rd_input, ar_01_weights + 2 * state_size,
        ar_2_weights + state_size, gru_recurrent_ptr + state_size,
        gru_recurrent_other_ptr + state_size, input_ptr + state_size);
    __m256 update1 = GruInput32ToFloat<kSplitGates, kInputsMode>(
        ar_2_inputs, ar_3rd_input,
        ar_01_weights + 2 * state_size + 2 * kQRIncrement,
        ar_2_weights + state_size + kQRIncrement,
        gru_recurrent_ptr + state_size + kAVX2SIMDWidth,
        gru_recurrent_other_ptr + state_size + kAVX2SIMDWidth,
        input_ptr + state_size + kAVX2SIMDWidth);
    float_sigmoid_float<kInputBits>(update0, update1);
    __m256 cell0 = _mm256_cvtepi32_ps(_mm256_load_si256(
        reinterpret_cast<__m256i const*>(input_ptr + 2 * state_size)));
    __m256 cell1 =
        _mm256_cvtepi32_ps(_mm256_load_si256(reinterpret_cast<__m256i const*>(
            input_ptr + 2 * state_size + kAVX2SIMDWidth)));
|
245 |
+
if (kInputsMode != ARInputsMode::k0ARInputs) {
|
246 |
+
cell0 = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
|
247 |
+
ar_2_inputs, ar_3rd_input, ar_01_weights + 4 * state_size,
|
248 |
+
ar_2_weights + 2 * state_size, cell0);
|
249 |
+
cell1 = MultiplyAddFloat<kInputsMode == ARInputsMode::k3ARInputs>(
|
250 |
+
ar_2_inputs, ar_3rd_input,
|
251 |
+
ar_01_weights + 4 * state_size + 2 * kQRIncrement,
|
252 |
+
ar_2_weights + 2 * state_size + kQRIncrement, cell1);
|
253 |
+
}
|
254 |
+
__m256i gru_state = GRUComputeState<kInputBits, kStateBits, kSplitGates>(
|
255 |
+
cell0, cell1, reset0, reset1, update0, update1,
|
256 |
+
gru_recurrent_ptr + 2 * state_size,
|
257 |
+
gru_recurrent_other_ptr + 2 * state_size, gru_state_ptr);
|
258 |
+
if (kReplicas > 0) {
|
259 |
+
// With |kReplicas| a template parameter, the compiler will unroll the
|
260 |
+
// loop.
|
261 |
+
for (int j = 0; j < kReplicas; ++j) {
|
262 |
+
_mm256_store_si256(
|
263 |
+
reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride),
|
264 |
+
gru_state);
|
265 |
+
}
|
266 |
+
} else {
|
267 |
+
// This loop will not unroll as replicas is variable.
|
268 |
+
for (int j = 0; j < replicas; ++j) {
|
269 |
+
_mm256_store_si256(
|
270 |
+
reinterpret_cast<__m256i*>(gru_state_ptr + j * replica_stride),
|
271 |
+
gru_state);
|
272 |
+
}
|
273 |
+
}
|
274 |
+
// Increment all the pointers.
|
275 |
+
input_ptr += 2 * kAVX2SIMDWidth;
|
276 |
+
gru_state_ptr += 2 * kAVX2SIMDWidth;
|
277 |
+
gru_recurrent_ptr += 2 * kAVX2SIMDWidth;
|
278 |
+
if (kSplitGates) gru_recurrent_other_ptr += 2 * kAVX2SIMDWidth;
|
279 |
+
if (kInputsMode != ARInputsMode::k0ARInputs) {
|
280 |
+
ar_01_weights += 4 * kQRIncrement;
|
281 |
+
if (kInputsMode == ARInputsMode::k3ARInputs)
|
282 |
+
ar_2_weights += 2 * kQRIncrement;
|
283 |
+
}
|
284 |
+
}
|
285 |
+
}
|
286 |
+
|
287 |
+
// Dispatches calls to the GruGatesTemplate function above converting the
|
288 |
+
// replicas variable argument to a template parameter to allow the compiler to
|
289 |
+
// unroll the write loop.
|
290 |
+
// |ar_sample01| packs sample 0 and 1 into a pair because the QR weights are
|
291 |
+
// formatted with the weights interleaved for sample 0 and 1. The two samples
|
292 |
+
// represent coarse and fine for WaveRNN.
|
293 |
+
template <int kInputBits, int kStateBits,
|
294 |
+
ARInputsMode kInputsMode = ARInputsMode::k2ARInputs,
|
295 |
+
bool kSplitGates = false>
|
296 |
+
inline void GruGatesAVXFixed(
|
297 |
+
int start, int end, int state_size, const int32_t* gru_recurrent_ptr,
|
298 |
+
const int32_t* input_ptr, const std::pair<float, float>* ar_sample01,
|
299 |
+
const float* ar_01_weights, int num_replicas, int replica_stride,
|
300 |
+
const float* ar_sample2, const float* ar_2_weights,
|
301 |
+
const int32_t* gru_recurrent_other_ptr, int16_t* gru_state_ptr) {
|
302 |
+
// Convert the number of replicas from a variable to a template parameter
|
303 |
+
// with a switch. This enables the compiler to unroll the loop for
|
304 |
+
// the write, making it faster for common numbers of threads.
|
305 |
+
switch (num_replicas) {
|
306 |
+
case 1:
|
307 |
+
GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/1,
|
308 |
+
kSplitGates>(
|
309 |
+
start, end, state_size, num_replicas, replica_stride,
|
310 |
+
gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
|
311 |
+
ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
|
312 |
+
break;
|
313 |
+
case 2:
|
314 |
+
GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/2,
|
315 |
+
kSplitGates>(
|
316 |
+
start, end, state_size, num_replicas, replica_stride,
|
317 |
+
gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
|
318 |
+
ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
|
319 |
+
break;
|
320 |
+
case 4:
|
321 |
+
GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/4,
|
322 |
+
kSplitGates>(
|
323 |
+
start, end, state_size, num_replicas, replica_stride,
|
324 |
+
gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
|
325 |
+
ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
|
326 |
+
break;
|
327 |
+
case 6:
|
328 |
+
GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/6,
|
329 |
+
kSplitGates>(
|
330 |
+
start, end, state_size, num_replicas, replica_stride,
|
331 |
+
gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
|
332 |
+
ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
|
333 |
+
break;
|
334 |
+
default:
|
335 |
+
// Zero |kReplicas| tells the function to use the |num_replicas| variable.
|
336 |
+
GruGatesTemplate<kInputBits, kStateBits, kInputsMode, /*kReplicas=*/0,
|
337 |
+
kSplitGates>(
|
338 |
+
start, end, state_size, num_replicas, replica_stride,
|
339 |
+
gru_recurrent_ptr, input_ptr, ar_sample01, ar_01_weights, ar_sample2,
|
340 |
+
ar_2_weights, gru_recurrent_other_ptr, gru_state_ptr);
|
341 |
+
}
|
342 |
+
}
|
343 |
+
|
344 |
+
#endif // __AVX2__
|
345 |
+
|
346 |
+
} // namespace csrblocksparse
|
347 |
+
|
348 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_AVX_FIXED_H_
|
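The dispatch in GruGatesAVXFixed turns a runtime replica count into a compile-time constant so the state-store loop can be fully unrolled for common values. A minimal standalone sketch of that pattern follows; the names (StoreReplicas, StoreReplicasFixed) are illustrative and not part of the library.

#include <cstdio>

// Compile-time count: the loop unrolls because kCount is a template parameter.
template <int kCount>
void StoreReplicasFixed(int* dst, int stride, int value) {
  for (int j = 0; j < kCount; ++j) dst[j * stride] = value;
}

// Runtime count: a switch promotes common values to the unrolled version and
// falls back to a plain loop otherwise, mirroring the |num_replicas| dispatch.
void StoreReplicas(int* dst, int stride, int value, int count) {
  switch (count) {
    case 1: StoreReplicasFixed<1>(dst, stride, value); break;
    case 2: StoreReplicasFixed<2>(dst, stride, value); break;
    case 4: StoreReplicasFixed<4>(dst, stride, value); break;
    default:
      for (int j = 0; j < count; ++j) dst[j * stride] = value;
  }
}

int main() {
  int buf[8] = {0};
  StoreReplicas(buf, 2, 7, 4);
  std::printf("%d %d %d %d\n", buf[0], buf[2], buf[4], buf[6]);
  return 0;
}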
sparse_matmul/compute/gru_gates_generic.h
ADDED
@@ -0,0 +1,97 @@
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_

#include "sparse_matmul/compute/ar_inputs.h"
#include "sparse_matmul/numerics/fast_transcendentals.h"

namespace csrblocksparse {

constexpr int kGenericSIMDWidth = 4;

// TODO(b/188702959): Rename arguments to match gru_gates.h.
template <typename GRUStateType, typename GRUMatMulOutType, typename QR_W_Type,
          typename SampleType, ARInputsMode kInputsMode,
          bool SplitGates = false>
void GoThroughGates(int start, int end, const QR_W_Type* qr_ptr,
                    const GRUMatMulOutType* gru_gates_ptr,
                    const GRUMatMulOutType* gru_gates_other_ptr,
                    const GRUMatMulOutType* conditioning_ptr,
                    GRUStateType* gru_h_ptr, const QR_W_Type* w_hat,
                    int proj_size, const SampleType* coarse_at_sminus1,
                    const SampleType* fine_at_sminus1,
                    const SampleType* coarse_at_s = nullptr) {
  float qr_cell = 0.0f, reset, update, cell;
  for (int i = start; i < end; ++i) {
    if (kInputsMode == ARInputsMode::k0ARInputs) {
      reset = static_cast<float>(gru_gates_ptr[i]);
      update = static_cast<float>(gru_gates_ptr[proj_size + i]);
    } else {
      float qr_c_reset = static_cast<float>(qr_ptr[2 * i + 0]);
      float qr_f_reset = static_cast<float>(qr_ptr[2 * i + 1]);
      float qr_c_update = static_cast<float>(qr_ptr[2 * proj_size + 2 * i + 0]);
      float qr_f_update = static_cast<float>(qr_ptr[2 * proj_size + 2 * i + 1]);
      float qr_c_cell = static_cast<float>(qr_ptr[4 * proj_size + 2 * i + 0]);
      float qr_f_cell = static_cast<float>(qr_ptr[4 * proj_size + 2 * i + 1]);
      float w_hat_i_reset = 0.0f;
      float w_hat_i_update = 0.0f;
      float w_hat_i_cell = 0.0f;
      if (kInputsMode == ARInputsMode::k3ARInputs) {
        w_hat_i_reset = static_cast<float>(w_hat[i]);
        w_hat_i_update = static_cast<float>(w_hat[proj_size + i]);
        w_hat_i_cell = static_cast<float>(w_hat[2 * proj_size + i]);
      }
      float coarse = static_cast<float>(coarse_at_sminus1[0]);
      float fine = static_cast<float>(fine_at_sminus1[0]);
      reset = qr_c_reset * coarse + qr_f_reset * fine;
      update = qr_c_update * coarse + qr_f_update * fine;
      qr_cell = qr_c_cell * coarse + qr_f_cell * fine;
      if (kInputsMode == ARInputsMode::k3ARInputs) {
        float coarse = static_cast<float>(coarse_at_s[0]);
        reset += w_hat_i_reset * coarse;
        update += w_hat_i_update * coarse;
        qr_cell += w_hat_i_cell * coarse;
      }
      reset += static_cast<float>(gru_gates_ptr[i]);
      update += static_cast<float>(gru_gates_ptr[proj_size + i]);
    }
    cell = static_cast<float>(gru_gates_ptr[2 * proj_size + i]);
    if (SplitGates) {
      reset += static_cast<float>(gru_gates_other_ptr[i]);
      update += static_cast<float>(gru_gates_other_ptr[proj_size + i]);
      cell += static_cast<float>(gru_gates_other_ptr[2 * proj_size + i]);
    }
    float reset_conditioning = static_cast<float>(conditioning_ptr[i]);
    float update_conditioning =
        static_cast<float>(conditioning_ptr[proj_size + i]);
    float cell_conditioning =
        static_cast<float>(conditioning_ptr[2 * proj_size + i]);
    reset = fast_sigmoid(reset + reset_conditioning);
    update = fast_sigmoid(update + update_conditioning);
    float hbar = fast_tanh(qr_cell + reset * cell + cell_conditioning);
    int h_index = i;
    float prev_h = static_cast<float>(gru_h_ptr[h_index]);
    float diff = prev_h - hbar;
    float new_h = hbar + diff * update;
    gru_h_ptr[h_index] = static_cast<GRUStateType>(new_h);
  }
}

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_GRU_GATES_GENERIC_H_
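For reference, the per-element update computed by GoThroughGates is the standard GRU-with-conditioning recurrence. The sketch below restates it in scalar float form, with std::exp/std::tanh standing in for the library's fast_sigmoid/fast_tanh approximations; it is illustrative only and not part of the library.

#include <cmath>

// One element of the GRU state update: gate pre-activations plus their
// conditioning terms go through sigmoid, the candidate state through tanh,
// and the new state interpolates between the candidate and the previous state.
float GruElementUpdate(float reset_in, float update_in, float cell_in,
                       float qr_cell, float reset_cond, float update_cond,
                       float cell_cond, float prev_h) {
  auto sigmoid = [](float x) { return 1.0f / (1.0f + std::exp(-x)); };
  float reset = sigmoid(reset_in + reset_cond);
  float update = sigmoid(update_in + update_cond);
  float hbar = std::tanh(qr_cell + reset * cell_in + cell_cond);
  return hbar + (prev_h - hbar) * update;
}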
sparse_matmul/compute/gru_gates_test.cc
ADDED
@@ -0,0 +1,164 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "sparse_matmul/compute/gru_gates.h"

#include <cstdint>
#include <cstring>
#include <numeric>

#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"

namespace {

using csrblocksparse::ARInputsMode;

template <typename GRUStateType, typename InputType, typename SampleType = void,
          csrblocksparse::ARInputsMode kInputsMode, bool kSplitGates>
csrblocksparse::CacheAlignedVector<GRUStateType> TestGruGates() {
  using SampleWeightType = float;
  constexpr int kStateSize = 16;
  csrblocksparse::CacheAlignedVector<SampleWeightType> qr(6 * kStateSize);
  csrblocksparse::CacheAlignedVector<SampleWeightType> w(3 * kStateSize);
  csrblocksparse::CacheAlignedVector<InputType> gru_gates(3 * kStateSize);
  csrblocksparse::CacheAlignedVector<InputType> gru_other_gates(3 * kStateSize);
  csrblocksparse::CacheAlignedVector<InputType> conditioning(3 * kStateSize);
  csrblocksparse::CacheAlignedVector<GRUStateType> gru_h(kStateSize);
  csrblocksparse::GruGates<GRUStateType, InputType, SampleType> gru_gates_impl;
  const SampleType kCoarseAtSMinus1(0.03f);
  const SampleType kFineAtSMinus1(0.07f);
  const SampleType kCoarseAtS(-0.02f);

  qr.FillOnes();
  w.FillOnes();
  gru_gates.FillRandom();
  gru_other_gates.FillRandom();
  conditioning.FillRandom();
  gru_h.FillZero();

  gru_gates_impl.template GruWithARInput<kInputsMode, kSplitGates>(
      /*start=*/0, /*end=*/kStateSize, kStateSize, gru_gates.data(),
      conditioning.data(), gru_h.data(), &kCoarseAtSMinus1, &kFineAtSMinus1,
      qr.data(),
      /*num_replicas=*/1, /*replica_stride=*/0, &kCoarseAtS, w.data(),
      gru_other_gates.data());
  return gru_h;
}

TEST(GruGates, FloatWaveRNNCoarseMatchesGolden) {
  // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers
  // will also need to change.
  const std::vector<float> kGoldenValues = {
      0.0f, 0.0f, 0.0f,   0.0f, 1.0f, 0.746f, 0.0f, 0.0f,
      0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f,   0.0f, -0.993f};
  csrblocksparse::CacheAlignedVector<float> gru_h =
      TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
                   /*kSplitGates=*/true>();

  ASSERT_EQ(kGoldenValues.size(), gru_h.size());
  for (int i = 0; i < gru_h.size(); ++i) {
    EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i;
  }
}

TEST(GruGates, FloatWaveRNNFineMatchesGolden) {
  // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers
  // will also need to change.
  const std::vector<float> kGoldenValues = {
      0.0f, 0.0f, 0.0f,   0.0f, 1.0f, 0.737f, 0.0f, 0.0f,
      0.0f, 0.0f, 0.969f, 0.0f, 0.0f, 1.0f,   0.0f, -0.994f};
  csrblocksparse::CacheAlignedVector<float> gru_h =
      TestGruGates<float, float, float, ARInputsMode::k3ARInputs,
                   /*kSplitGates=*/true>();

  ASSERT_EQ(kGoldenValues.size(), gru_h.size());
  for (int i = 0; i < gru_h.size(); ++i) {
    EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i;
  }
}

TEST(GruGates, FloatTwoArInputsNonSplitGateMatchesGolden) {
  // If the RNG in csrblocksparse::CacheAlignedVector changes, these numbers
  // will also need to change.
  const std::vector<float> kGoldenValues = {
      0.0f, 0.0f, 0.0f,   0.0f, 1.0f, 0.714f, 0.0f, -0.002f,
      0.0f, 0.0f, 0.970f, 0.0f, 0.0f, 1.0f,   0.0f, -0.965f};
  csrblocksparse::CacheAlignedVector<float> gru_h =
      TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
                   /*kSplitGates=*/false>();

  ASSERT_EQ(kGoldenValues.size(), gru_h.size());
  for (int i = 0; i < gru_h.size(); ++i) {
    EXPECT_NEAR(kGoldenValues[i], gru_h[i], 1e-3) << "i=" << i;
  }
}

TEST(GruGates, FixedWaveRNNCoarseMatchesFloat) {
  using GRUMatMulOutType = csrblocksparse::fixed32<11>;
  using GRUStateType = csrblocksparse::fixed16<2>;
  using SampleType = csrblocksparse::fixed16<0>;
  csrblocksparse::CacheAlignedVector<float> float_gru_h =
      TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
                   /*kSplitGates=*/true>();
  csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h =
      TestGruGates<GRUStateType, GRUMatMulOutType, SampleType,
                   ARInputsMode::k2ARInputs, /*kSplitGates=*/true>();

  ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size());
  for (int i = 0; i < fixed_gru_h.size(); ++i) {
    EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3)
        << "i=" << i;
  }
}

TEST(GruGates, FixedWaveRNNFineMatchesFloat) {
  using GRUMatMulOutType = csrblocksparse::fixed32<11>;
  using GRUStateType = csrblocksparse::fixed16<2>;
  using SampleType = csrblocksparse::fixed16<0>;
  csrblocksparse::CacheAlignedVector<float> float_gru_h =
      TestGruGates<float, float, float, ARInputsMode::k3ARInputs,
                   /*kSplitGates=*/true>();
  csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h =
      TestGruGates<GRUStateType, GRUMatMulOutType, SampleType,
                   ARInputsMode::k3ARInputs, /*kSplitGates=*/true>();

  ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size());
  for (int i = 0; i < fixed_gru_h.size(); ++i) {
    EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3)
        << "i=" << i;
  }
}

TEST(GruGates, FixedTwoArInputsNonSplitGateMatchesFloat) {
  using GRUMatMulOutType = csrblocksparse::fixed32<11>;
  using GRUStateType = csrblocksparse::fixed16<2>;
  using SampleType = csrblocksparse::fixed16<0>;
  csrblocksparse::CacheAlignedVector<float> float_gru_h =
      TestGruGates<float, float, float, ARInputsMode::k2ARInputs,
                   /*kSplitGates=*/false>();
  csrblocksparse::CacheAlignedVector<GRUStateType> fixed_gru_h =
      TestGruGates<GRUStateType, GRUMatMulOutType, SampleType,
                   ARInputsMode::k2ARInputs, /*kSplitGates=*/false>();

  ASSERT_EQ(float_gru_h.size(), fixed_gru_h.size());
  for (int i = 0; i < fixed_gru_h.size(); ++i) {
    EXPECT_NEAR(float_gru_h[i], static_cast<float>(fixed_gru_h[i]), 1e-3)
        << "i=" << i;
  }
}

}  // namespace
sparse_matmul/compute/kernels_arm.h
ADDED
The diff for this file is too large to render; see the raw diff.
sparse_matmul/compute/kernels_avx.h
ADDED
@@ -0,0 +1,601 @@
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_

#if defined __AVX__
#include <immintrin.h>

#include <algorithm>
#include <type_traits>
// TODO(b/188702959): Remove fast_transcendentals with GRU refactor.
#include "sparse_matmul/numerics/fast_transcendentals.h"
#include "sparse_matmul/numerics/fixed_types.h"
#include "sparse_matmul/numerics/float16_types.h"
#include "sparse_matmul/numerics/type_utils.h"

namespace csrblocksparse {
namespace detail {

template <typename WeightType, typename RhsType, typename OutType>
struct IsAllowableFloatTypes
    : std::integral_constant<bool, std::is_same<WeightType, float>::value &&
                                       std::is_same<RhsType, float>::value &&
                                       std::is_same<OutType, float>::value> {};

#if defined __AVX2__
// 16-bit inputs, 32-bit output exponent matches sum of input exponents
// OR
// 16-bit inputs, 16-bit output - will shift to match exponent
template <typename WeightType, typename RhsType, typename OutType>
struct IsAllowableFixedTypes
    : std::integral_constant<bool, (IsFixed16Type<WeightType>::value &&
                                    IsFixed16Type<RhsType>::value) &&
                                       (IsFixed32Type<OutType>::value ||
                                        IsFixed16Type<OutType>::value)> {};

template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericKernel
    : std::integral_constant<
          bool,
          !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value &&
              !IsAllowableFixedTypes<WeightType, RhsType, OutType>::value> {};

template <typename Type>
struct IsAddableFixedTypes
    : std::integral_constant<bool, IsFixed32Type<Type>::value ||
                                       IsFixed16Type<Type>::value> {};
template <typename Type>
struct ShouldEnableGenericAdd
    : std::integral_constant<bool, !IsAddableFixedTypes<Type>::value> {};

#else  // No AVX2.

template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericKernel
    : std::integral_constant<
          bool, !IsAllowableFloatTypes<WeightType, RhsType, OutType>::value> {};

template <typename Type>
struct ShouldEnableGenericAdd : std::true_type {};
#endif  // __AVX2__

template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMV_4x4
    : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMM5_4x4
    : ShouldEnableGenericKernel<WeightType, RhsType, OutType> {};
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMV_1x1 : std::true_type {};
template <typename WeightType, typename RhsType, typename OutType>
struct ShouldEnableGenericSpMM5_1x1 : std::true_type {};

// The computational routines do NO error checking for speed. It is assumed
// that this has been handled by CSRBlockSparseMatrix.

// In-line function to extract results from a pair of registers and store in
// memory. Note that the non-const references are registers, and are modified
// by this function!
inline void Extract4Results(bool relu, __m256& sum1, __m256& sum2,
                            float** out_ptr) {
  // Horizontally add the results. We have 2 registers, |sum1| and |sum2| that
  // each contain 2 sets of 4 values that need to be added.
  sum1 = _mm256_hadd_ps(sum1, sum2);
  sum1 = _mm256_hadd_ps(sum1, sum1);
  // Now |sum1| contains [|res0|, |res2|, |res0|, |res2|, |res1|, |res3|,
  // |res1|, |res3|]
  if (relu) {
    sum1 = _mm256_max_ps(sum1, _mm256_setzero_ps());
  }
  // It is really hard in AVX to cross the 128 bit 'lanes' and this is the
  // *only* way to do it.
  // Get the top half of |sum1| in to bottom of |sum2|.
  sum2 = _mm256_permute2f128_ps(sum1, sum1, 1);
  // Interleave the values between the two registers.
  sum1 = _mm256_unpacklo_ps(sum1, sum2);
  // Save the lower 128 bits (4 floats).
  __m128 result = _mm256_extractf128_ps(sum1, 0);
  _mm_store_ps(*out_ptr, result);
  *out_ptr += 4;
}

// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a vector and b is vector. Weights are stored for this
// routine by making each 4x4 block contiguous. Blocks are ordered in standard
// row-major format. column indices are converted to deltas and then multiplied
// by 2 to convert to bytes, so that the value can be used directly to offset
// the pointer into the rhs vector.
//
// NOTE: The bias is expected to have be multiplied by .25f prior to calling
// this function. This is automatically taken care of in SparseLinearLayer.
// The bias is reconstructed through horizontal additions, leads to a small
// speedup by reducing latencies at the end of the loop.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<std::is_same<WeightType, float>::value &&
                        std::is_same<RhsType, float>::value &&
                        std::is_same<OutType, float>::value>::type
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
         const int32_t* nnz_per_row, const RhsType* rhs_ptr,
         const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
         OutType* out_ptr, int64_t assigned_rows,
         int64_t rows /* only used in SpMM variants */,
         int64_t cols /* only used in SpMM variants */, int relu) {
  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // Broadcast the biases by 4 to undo the division by 4 in the input biases.
    __m256 sum1 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                  _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;
    __m256 sum2 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                  _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;

    int reduced_col_count = *nnz_per_row++;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      rhs_ptr += col_delta;
      // Multiply this 4x4 block.
      __m256 rhs =
          _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptr));
      __m256 weights1 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum1 = _mm256_add_ps(sum1, _mm256_mul_ps(weights1, rhs));
      __m256 weights2 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum2 = _mm256_add_ps(sum2, _mm256_mul_ps(weights2, rhs));
    }
    Extract4Results(relu, sum1, sum2, &out_ptr);
  }
}

// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a fat vector with 5 columns and b is vector. b is
// broadcast. Weights are stored for this routine by making each 4x4 block
// contiguous. Blocks are ordered in standard row-major format. column indices
// are converted to deltas and then multiplied by 2 to convert to bytes, so
// that the value can be used directly to offset the pointer into the rhs
// vector.
//
// NOTE: The bias is expected to have be multiplied by .25f prior to calling
// this function. This is automatically taken care of in SparseLinearLayer.
// The bias is reconstructed through horizontal additions, leads to a small
// speedup by reducing latencies at the end of the loop.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<std::is_same<WeightType, float>::value &&
                        std::is_same<RhsType, float>::value &&
                        std::is_same<OutType, float>::value>::type
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
          const int32_t* nnz_per_row, const RhsType* rhs_ptr,
          const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
          OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
          int relu) {
  const RhsType* rhs_ptrs[5];
  for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;

  OutType* out_ptrs[5];
  for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;

  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // We will acumulate the results in 10 registers, |sum1_0| to |sum2_4|.
    // Broadcast the biases by 4 to undo the division by 4 in the input biases.
    __m256 sum1_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                    _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;
    __m256 sum2_0 = _mm256_set_m128(_mm_broadcast_ss(bias_ptr + 1),
                                    _mm_broadcast_ss(bias_ptr));
    bias_ptr += 2;
    __m256 sum1_1 = sum1_0;
    __m256 sum2_1 = sum2_0;
    __m256 sum1_2 = sum1_0;
    __m256 sum2_2 = sum2_0;
    __m256 sum1_3 = sum1_0;
    __m256 sum2_3 = sum2_0;
    __m256 sum1_4 = sum1_0;
    __m256 sum2_4 = sum2_0;

    int reduced_col_count = *nnz_per_row++;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;

      // Multiply this 4x4 block.
      __m256 rhs =
          _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[0]));
      __m256 weights1 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum1_0 = _mm256_add_ps(sum1_0, _mm256_mul_ps(weights1, rhs));
      __m256 weights2 = _mm256_load_ps(weights_ptr);
      weights_ptr += 8;
      sum2_0 = _mm256_add_ps(sum2_0, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[1]));
      sum1_1 = _mm256_add_ps(sum1_1, _mm256_mul_ps(weights1, rhs));
      sum2_1 = _mm256_add_ps(sum2_1, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[2]));
      sum1_2 = _mm256_add_ps(sum1_2, _mm256_mul_ps(weights1, rhs));
      sum2_2 = _mm256_add_ps(sum2_2, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[3]));
      sum1_3 = _mm256_add_ps(sum1_3, _mm256_mul_ps(weights1, rhs));
      sum2_3 = _mm256_add_ps(sum2_3, _mm256_mul_ps(weights2, rhs));
      rhs = _mm256_broadcast_ps(reinterpret_cast<const __m128*>(rhs_ptrs[4]));
      sum1_4 = _mm256_add_ps(sum1_4, _mm256_mul_ps(weights1, rhs));
      sum2_4 = _mm256_add_ps(sum2_4, _mm256_mul_ps(weights2, rhs));
    }

    Extract4Results(relu, sum1_0, sum2_0, &out_ptrs[0]);
    Extract4Results(relu, sum1_1, sum2_1, &out_ptrs[1]);
    Extract4Results(relu, sum1_2, sum2_2, &out_ptrs[2]);
    Extract4Results(relu, sum1_3, sum2_3, &out_ptrs[3]);
    Extract4Results(relu, sum1_4, sum2_4, &out_ptrs[4]);
  }
}

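A plain scalar reference can make the 4x4-blocked layout used by the float kernels above easier to follow. The sketch below is illustrative only and not part of the library: types are fixed to float, the column deltas are byte offsets divided by the element size as in the kernels, and the bias is assumed to be the quarter-scaled bias the kernels expect, so it is scaled back up by 4 here.

#include <cstdint>

// Scalar y = A * x + b for a matrix stored as contiguous 4x4 blocks, with one
// block count per block row in |nnz_per_row| and byte deltas into |rhs| in
// |col_deltas_bytes|. Illustrative names; not the library API.
void ScalarSpMV4x4(const float* weights, const int16_t* col_deltas_bytes,
                   const int32_t* nnz_per_row, const float* rhs,
                   const float* quarter_bias, float* out,
                   int64_t assigned_rows, int relu) {
  const float* rhs_ptr = rhs;
  for (int64_t block_row = 0; block_row < assigned_rows; ++block_row) {
    float acc[4];
    for (int r = 0; r < 4; ++r) {
      // Undo the 0.25f pre-scaling of the bias described in the comments.
      acc[r] = 4.0f * quarter_bias[4 * block_row + r];
    }
    int block_count = nnz_per_row[block_row];
    for (int c = 0; c < block_count; ++c) {
      rhs_ptr += *col_deltas_bytes++ / static_cast<int>(sizeof(float));
      // One contiguous 4x4 weight block times 4 rhs values; the rows of the
      // block are stored consecutively, 4 weights per row.
      for (int r = 0; r < 4; ++r) {
        for (int k = 0; k < 4; ++k) {
          acc[r] += weights[4 * r + k] * rhs_ptr[k];
        }
      }
      weights += 16;
    }
    for (int r = 0; r < 4; ++r) {
      float value = acc[r];
      out[4 * block_row + r] = (relu && value < 0.0f) ? 0.0f : value;
    }
  }
}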
#ifdef __AVX2__

// In-line function to finish the computation of the result as 4x int32 in
// |sum|.
inline void Compute4Results(bool relu, int kShiftAmount, __m256i& sum) {
  // Horizontally add the results. We have 1 register that contains results
  // [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not
  // cross lanes, so we end up with [0 1 0 1 2 3 2 3]
  sum = _mm256_hadd_epi32(sum, sum);
  // Permutes the middle two pairs to get the answers together.
  sum = _mm256_permute4x64_epi64(sum, 0xd8);
  if (kShiftAmount > 0) {
    // Shift right with rounding to get the right number of mantissa bits.
    __m256i rounding = _mm256_set1_epi32(1 << (kShiftAmount - 1));
    sum = _mm256_add_epi32(sum, rounding);
    sum = _mm256_srai_epi32(sum, kShiftAmount);
  }
  // Now |sum| contains [|res0|, |res1|, |res2|, |res3|, |res0|, |res1|,
  // |res2|, |res3|]
  if (relu) {
    sum = _mm256_max_epi32(sum, _mm256_setzero_si256());
  }
}

// In-line function to extract the 4x int32 results from |sum| to memory.
// Non-const reference for |sum| as it is a register.
inline void Extract4xint32(bool relu, int kShiftAmount, __m256i& sum,
                           int32_t** out_ptr) {
  Compute4Results(relu, kShiftAmount, sum);
  // Save the lower 128 bits (4x int32).
  __m128i result = _mm256_extractf128_si256(sum, 0);
  _mm_store_si128(reinterpret_cast<__m128i*>(*out_ptr), result);
  *out_ptr += 4;
}

// In-line function to extract the 4x int32 results from sum to 4x int16 in
// memory.
// Non-const reference for |sum| as it is a register.
inline void Extract4xint16(bool relu, int kShiftAmount, __m256i& sum,
                           int16_t** out_ptr) {
  Compute4Results(relu, kShiftAmount, sum);
  // Clip to 16 bit range (with saturation) and pack in the bottom 64 bits.
  // Converts the lower 4x int32 in bottom 128 bits to 4x int16 in bottom 64
  // bits, replicated in the next 64 bits.
  sum = _mm256_packs_epi32(sum, sum);
  // Save 4x int 16 from the bottom 64 bits.
  *reinterpret_cast<int64_t*>(*out_ptr) = _mm256_extract_epi64(sum, 0);
  *out_ptr += 4;
}

// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a vector and b is vector. Weights are stored for this
// routine by making each 4x4 block contiguous. Blocks are ordered in standard
// row-major format. column indices are converted to deltas and then multiplied
// by 2 to convert to bytes, so that the value can be used directly to offset
// the pointer into the rhs vector.
//
// NOTE: The bias is expected to have be multiplied by .25f prior to calling
// this function. This is automatically taken care of in SparseLinearLayer.
// The bias is reconstructed through horizontal additions, leads to a small
// speedup by reducing latencies at the end of the loop.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<
    IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
    (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
         const int32_t* nnz_per_row, const RhsType* rhs_ptr,
         const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
         OutType* out_ptr, int64_t assigned_rows,
         int64_t rows /* only used in SpMM variants */,
         int64_t cols /* only used in SpMM variants */, int relu) {
  constexpr int kShiftAmount =
      TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
      OutType::kMantissaBits;
  static_assert(kShiftAmount >= 0,
                "Result must have fewer mantissa bits than product");
  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
    __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
    __m256i biases = _mm256_set_m128i(bias, bias);
    bias_ptr += 4;
    // Swap the top two pairs: [0 1 2 3 2 3 0 1]
    // TODO(b/188702959): consider |_mm256_permutevar8x32|, and set the index
    // register outside the row loop.
    biases = _mm256_permute4x64_epi64(biases, 0xb4);
    // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
    biases = _mm256_unpacklo_epi32(biases, biases);
    // Double the results to make up for the division by 4.
    // TODO(b/188702959): consider moving this to where the biases are computed.
    __m256i sum = _mm256_add_epi32(biases, biases);

    // TODO(b/188702959): People don't like the old-fashioned, close-to-the-
    // metal notation of *|nnz_per_row|++, so measure the effect of putting the
    // increment in the for loop.
    int reduced_col_count = *nnz_per_row;
    ++nnz_per_row;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      rhs_ptr += col_delta;
      // Multiply this 4x4 block.
      // Get the 4x int16 into the bottom of rhs_64.
      __m128i rhs_64 =
          _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptr));
      // Load all 16 weights.
      __m256i weights =
          _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
      // Broadcast the rhs, pretending that each is a 64-bit unit:
      // [0123 0123 0123 0123].
      __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
      weights_ptr += 16;
      // |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally
      // adds adjacent pairs to make 8x32 bit results. Add these to the sum.
      sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs));
    }
    static_assert(
        IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
        "AVX2 kernel only supports fixed16 and fixed32 types");
    // The only significant difference between fixed16 and fixed32 is the size
    // of the storage unit. The registers have to be repacked accordingly.
    if (IsFixed32Type<OutType>::value) {
      Extract4xint32(relu, kShiftAmount, sum,
                     reinterpret_cast<int32_t**>(&out_ptr));
    } else {
      Extract4xint16(relu, kShiftAmount, sum,
                     reinterpret_cast<int16_t**>(&out_ptr));
    }
  }
}

// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
// blocked pattern, x is a fat vector with 5 columns and b is vector. b is
// broadcast. Weights are stored for this routine by making each 4x4 block
// contiguous. Blocks are ordered in standard row-major format. column indices
// are converted to deltas and then multiplied by 2 to convert to bytes, so
// that the value can be used directly to offset the pointer into the rhs
// vector.
//
// NOTE: The bias is expected to have be multiplied by .25f prior to calling
// this function. This is automatically taken care of in SparseLinearLayer.
// The bias is reconstructed through horizontal additions, leads to a small
// speedup by reducing latencies at the end of the loop.
template <typename WeightType, typename RhsType, typename OutType>
typename std::enable_if<
    IsFixed16Type<WeightType>::value && IsFixed16Type<RhsType>::value &&
    (IsFixed32Type<OutType>::value || IsFixed16Type<OutType>::value)>::type
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
          const int32_t* nnz_per_row, const RhsType* rhs_ptr,
          const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
          OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
          int relu) {
  constexpr int kShiftAmount =
      TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
      OutType::kMantissaBits;
  static_assert(kShiftAmount >= 0,
                "Result must have fewer mantissa bits than product");
  const RhsType* rhs_ptrs[5];
  for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;

  OutType* out_ptrs[5];
  for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;

  for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
    // We will acumulate the results in 5 registers, sum_0 to sum_4.
    // Load the biases duplicated into a 256 bit register [0 1 2 3 0 1 2 3].
    __m128i bias = _mm_load_si128(reinterpret_cast<__m128i const*>(bias_ptr));
    __m256i biases = _mm256_set_m128i(bias, bias);
    bias_ptr += 4;
    // Swap the top two pairs: [0 1 2 3 2 3 0 1]
    biases = _mm256_permute4x64_epi64(biases, 0xb4);
    // Duplicate the low pairs in each lane: [0 0 1 1 2 2 3 3].
    biases = _mm256_unpacklo_epi32(biases, biases);
    // Double the results to make up for the division by 4.
    __m256i sum_0 = _mm256_add_epi32(biases, biases);
    __m256i sum_1 = sum_0;
    __m256i sum_2 = sum_0;
    __m256i sum_3 = sum_0;
    __m256i sum_4 = sum_0;

    int reduced_col_count = *nnz_per_row;
    ++nnz_per_row;
    for (int c = 0; c < reduced_col_count; ++c) {
      int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
      for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
      // Multiply this 4x4 block.
      // Get the 4x int16 into the bottom of |rhs_64|.
      __m128i rhs_64 =
          _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[0]));
      // Load all 16 weights.
      __m256i weights =
          _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
      // Broadcast the rhs, pretending that each is a 64-bit unit:
      // [0123 0123 0123 0123].
      __m256i rhs = _mm256_broadcastq_epi64(rhs_64);
      weights_ptr += 16;
      // |_mm256_madd_epi16| does 16x16x16=16x32 bit multiply and horizontally
      // adds adjacent pairs to make 8x32 bit results. Add these to the sum.
      sum_0 = _mm256_add_epi32(sum_0, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[1]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_1 = _mm256_add_epi32(sum_1, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[2]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_2 = _mm256_add_epi32(sum_2, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[3]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_3 = _mm256_add_epi32(sum_3, _mm256_madd_epi16(weights, rhs));
      rhs_64 = _mm_loadl_epi64(reinterpret_cast<__m128i const*>(rhs_ptrs[4]));
      rhs = _mm256_broadcastq_epi64(rhs_64);
      sum_4 = _mm256_add_epi32(sum_4, _mm256_madd_epi16(weights, rhs));
    }
    static_assert(
        IsFixed16Type<OutType>::value || IsFixed32Type<OutType>::value,
        "AVX2 kernel only supports fixed16 and fixed32 types");
    // The only significant difference between fixed16 and fixed32 is the size
    // of the storage unit. The registers have to be repacked accordingly.
    if (IsFixed32Type<OutType>::value) {
      Extract4xint32(relu, kShiftAmount, sum_0,
                     reinterpret_cast<int32_t**>(&out_ptrs[0]));
      Extract4xint32(relu, kShiftAmount, sum_1,
                     reinterpret_cast<int32_t**>(&out_ptrs[1]));
      Extract4xint32(relu, kShiftAmount, sum_2,
                     reinterpret_cast<int32_t**>(&out_ptrs[2]));
      Extract4xint32(relu, kShiftAmount, sum_3,
                     reinterpret_cast<int32_t**>(&out_ptrs[3]));
      Extract4xint32(relu, kShiftAmount, sum_4,
                     reinterpret_cast<int32_t**>(&out_ptrs[4]));
    } else {
      Extract4xint16(relu, kShiftAmount, sum_0,
                     reinterpret_cast<int16_t**>(&out_ptrs[0]));
      Extract4xint16(relu, kShiftAmount, sum_1,
                     reinterpret_cast<int16_t**>(&out_ptrs[1]));
      Extract4xint16(relu, kShiftAmount, sum_2,
                     reinterpret_cast<int16_t**>(&out_ptrs[2]));
      Extract4xint16(relu, kShiftAmount, sum_3,
                     reinterpret_cast<int16_t**>(&out_ptrs[3]));
      Extract4xint16(relu, kShiftAmount, sum_4,
                     reinterpret_cast<int16_t**>(&out_ptrs[4]));
    }
  }
}

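The fixed-point kernels above drop their surplus mantissa bits with a shift-right that rounds to nearest (see Compute4Results). A scalar sketch of that single operation, illustrative only:

#include <cstdint>

// Round-to-nearest right shift: add half of the final least-significant bit
// before the arithmetic shift, so the result is rounded rather than truncated.
inline int32_t RoundingShiftRight(int32_t value, int shift_amount) {
  if (shift_amount <= 0) return value;
  return (value + (1 << (shift_amount - 1))) >> shift_amount;
}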
// Processes one GRU gate input with sigmoid.
template <int InputMantissaBits, int StateMantissaBits, bool SplitGates>
inline __m256i GRUGateSigmoid(const void* gate_ptr, const void* gate_other_ptr,
                              const __m256i& input,
                              const int32_t* sigmoid_table) {
  __m256i gate = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_ptr));
  if (SplitGates) {
    __m256i other =
        _mm256_loadu_si256(reinterpret_cast<const __m256i*>(gate_other_ptr));
    gate = _mm256_add_epi32(gate, other);
  }
  gate = _mm256_add_epi32(gate, input);
  // Compute sigmoids on reset and update.
  return csrblocksparse::fixed32_sigmoid_fixed16<InputMantissaBits,
                                                 StateMantissaBits>(
      sigmoid_table, gate);
}

// Processes the tanh and the final combination, returning the new GRU state.
template <int InputMantissaBits, int StateMantissaBits, bool SplitGates = false>
inline __m256i GRUGateState(const __m256i& cell, const __m256i& reset,
                            const __m256i& update,
                            const __m256i& rounding_offset,
                            const void* gate_ptr, const void* gate_other_ptr,
                            const void* gru_h_ptr, const int32_t* tanh_table) {
  // Multiply the cell GRU output and the reset. There is a slight danger of
  // loss of precision here, so use 32x32=64 bit and shift back after.
  __m256i gru = _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_ptr));
  if (SplitGates) {
    __m256i other_gru =
        _mm256_loadu_si256(reinterpret_cast<__m256i const*>(gate_other_ptr));
    gru = _mm256_add_epi32(gru, other_gru);
  }
  // This only computes the products of the low-order 32 bits of each pair.
  __m256i gru_lo = _mm256_mul_epi32(gru, reset);
  // Swap odd and even 32-bit units and do it again to get the high products.
  gru = _mm256_shuffle_epi32(gru, 0xb1);
  __m256i gru_hi = _mm256_mul_epi32(gru, _mm256_shuffle_epi32(reset, 0xb1));
  // Now shift right to compensate for the multiply and re-interleave the
  // 32-bit results.
  // NOTE: There is no shift right arithmetic for 64 bit values until AVX512!
  // Fortunately it doesn't matter, as the results are being truncated to 32
  // bits and we aren't shifting right by more than 32 bits here.
  gru_lo = _mm256_srli_epi64(gru_lo, StateMantissaBits);
  // The upper results are shifted LEFT, so we can use blend to recombine in
  // a single instruction.
  gru_hi = _mm256_slli_epi64(gru_hi, 32 - StateMantissaBits);
  // Recombine the 32 bit results from lo and hi, alternating.
  gru = _mm256_blend_epi32(gru_lo, gru_hi, 0xaa);
  gru = _mm256_add_epi32(cell, gru);
  // Compute tanh on the result. Although this instantly discards a bunch of
  // bits, there were only 7 surplus bits for the multiply, which isn't enough
  // to do it as 16x16=32.
  __m256i hbar =
      csrblocksparse::fixed32_tanh_fixed16<InputMantissaBits,
                                           StateMantissaBits>(tanh_table, gru);
  // Load the 16-bit previous GRU state and sign-extend to 32 bits.
  gru = _mm256_cvtepi16_epi32(
      _mm_load_si128(reinterpret_cast<__m128i const*>(gru_h_ptr)));
  gru = _mm256_sub_epi32(gru, hbar);
  // Since |gru| is 16 bit sign-extended to 32, and |update| is the output of
  // sigmoid, it is always contained within 16 bits and never negative, we can
  // use |madd_epi16| to do 16x16=32 multiply with horizontal adding as the
  // addend will always be zero, and this is twice as fast as full blown
  // 32x32=32. The only possible problem is if the subtract above caused
  // overflow.
  gru = _mm256_madd_epi16(gru, update);
  // Renormalize to fixed16. This time rounding is critical, as this is the
  // output GRU state.
  gru = _mm256_add_epi32(gru, rounding_offset);
  gru = _mm256_srai_epi32(gru, StateMantissaBits);
  return _mm256_add_epi32(gru, hbar);
}

template <typename Type>
typename std::enable_if<IsFixed32Type<Type>::value>::type SumVectors(
    int start, int end, const Type* add1, const Type* add2, Type* result) {
  constexpr int kSIMDWidth = 8;
  for (int i = start; i < end; i += kSIMDWidth) {
    __m256i data1 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
    __m256i data2 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
    data1 = _mm256_add_epi32(data1, data2);
    _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
  }
}

template <typename Type>
typename std::enable_if<IsFixed16Type<Type>::value>::type SumVectors(
    int start, int end, const Type* add1, const Type* add2, Type* result) {
  constexpr int kSIMDWidth = 16;
  for (int i = start; i < end; i += kSIMDWidth) {
    __m256i data1 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add1 + i));
    __m256i data2 =
        _mm256_load_si256(reinterpret_cast<__m256i const*>(add2 + i));
    data1 = _mm256_add_epi16(data1, data2);
    _mm256_store_si256(reinterpret_cast<__m256i*>(result + i), data1);
  }
}

#endif  // __AVX2__

}  // namespace detail
}  // namespace csrblocksparse

#undef LABEL_COL_LOOP
#undef LABEL_ROW_LOOP
#undef LABEL_SKIP_COL_LOOP
#undef LABEL_TOP_LOOP

#endif  // __AVX__

#endif  // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_AVX_H_
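The tail of GRUGateState interpolates between the previous state and the tanh candidate in fixed point. A scalar sketch of just that final step follows; it is illustrative only, assumes |update| and |hbar| are already on the state scale, and uses a plain 64-bit multiply where the vector code uses madd_epi16.

#include <cstdint>

// new_h = hbar + (prev_h - hbar) * update, with the product renormalized back
// to the state scale using round-to-nearest, as in the vector code above.
int16_t InterpolateGruState(int32_t update, int32_t hbar, int16_t prev_state,
                            int state_mantissa_bits) {
  int64_t diff = static_cast<int64_t>(prev_state) - hbar;
  int64_t scaled = diff * update;
  int64_t rounding = int64_t{1} << (state_mantissa_bits - 1);
  return static_cast<int16_t>(((scaled + rounding) >> state_mantissa_bits) +
                              hbar);
}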
sparse_matmul/compute/kernels_generic.h
ADDED
@@ -0,0 +1,273 @@
1 |
+
/*
|
2 |
+
* Copyright 2021 Google LLC
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_
|
18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_
|
19 |
+
|
20 |
+
#include <algorithm>
|
21 |
+
#include <type_traits>
|
22 |
+
|
23 |
+
#include "sparse_matmul/numerics/fixed_types.h"
|
24 |
+
#include "sparse_matmul/numerics/float16_types.h"
|
25 |
+
#include "sparse_matmul/numerics/type_utils.h"
|
26 |
+
|
27 |
+
// Separate out the assembly kernels for readability. Eventually this will
|
28 |
+
// become an ifdef switch on the architecture type.
|
29 |
+
#if defined __aarch64__
|
30 |
+
#include "sparse_matmul/compute/kernels_arm.h"
|
31 |
+
#elif defined __AVX__
|
32 |
+
#include "sparse_matmul/compute/kernels_avx.h"
|
33 |
+
#else // defined __AVX__
|
34 |
+
// If there is no architecture-specific implementation, then always use generic.
|
35 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
36 |
+
struct ShouldEnableGenericSpMV_4x4 : std::true_type {};
|
37 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
38 |
+
struct ShouldEnableGenericSpMM5_4x4 : std::true_type {};
|
39 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
40 |
+
struct ShouldEnableGenericSpMV_1x1 : std::true_type {};
|
41 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
42 |
+
struct ShouldEnableGenericSpMM5_1x1 : std::true_type {};
|
43 |
+
template <typename Type>
|
44 |
+
struct ShouldEnableGenericAdd : std::true_type {};
|
45 |
+
#endif  // defined __aarch64__
|
46 |
+
|
47 |
+
namespace csrblocksparse {
|
48 |
+
namespace detail {
|
49 |
+
|
50 |
+
// The computational routines do NO error checking for speed. It is assumed
|
51 |
+
// that this has been handled by CSRBlockSparseMatrix.
|
52 |
+
|
53 |
+
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
|
54 |
+
// blocked pattern, x is a vector and b is a vector. Weights are stored for this
|
55 |
+
// routine by making each 4x4 block contiguous. Blocks are ordered in standard
|
56 |
+
// row-major format. column indices are converted to deltas and then multiplied
|
57 |
+
// by 2 to convert to bytes, so that the value can be used directly to offset
|
58 |
+
// the pointer into the rhs vector.
|
59 |
+
//
|
60 |
+
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
|
61 |
+
// this function. This is automatically taken care of in SparseLinearLayer.
|
62 |
+
// The bias is reconstructed through horizontal additions, which leads to a small
|
63 |
+
// speedup by reducing latencies at the end of the loop.
|
64 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
65 |
+
typename std::enable_if<
|
66 |
+
ShouldEnableGenericSpMV_4x4<WeightType, RhsType, OutType>::value>::type
|
67 |
+
SpMV_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
|
68 |
+
const int32_t* nnz_per_row, const RhsType* rhs_ptr,
|
69 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
|
70 |
+
OutType* out_ptr, int64_t assigned_rows,
|
71 |
+
int64_t rows /* only used in SpMM variants */,
|
72 |
+
int64_t cols /* only used in SpMM variants */, int relu) {
|
73 |
+
for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
|
74 |
+
float accumulators[4];
|
75 |
+
// Undo the division by 4 that happens for the assembly version.
|
76 |
+
for (int i = 0; i < 4; ++i)
|
77 |
+
accumulators[i] = 4.f * static_cast<float>(*bias_ptr++);
|
78 |
+
|
79 |
+
int reduced_col_count = *nnz_per_row++;
|
80 |
+
for (int c = 0; c < reduced_col_count; ++c) {
|
81 |
+
int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
|
82 |
+
rhs_ptr += col_delta;
|
83 |
+
|
84 |
+
// Multiply this 4x4 block.
|
85 |
+
for (int i = 0; i < 4; ++i) {
|
86 |
+
for (int j = 0; j < 4; ++j) {
|
87 |
+
accumulators[i] += static_cast<float>(*weights_ptr++) *
|
88 |
+
static_cast<float>(rhs_ptr[j]);
|
89 |
+
}
|
90 |
+
}
|
91 |
+
}
|
92 |
+
|
93 |
+
for (int i = 0; i < 4; ++i)
|
94 |
+
*out_ptr++ = static_cast<OutType>(relu ? std::max(accumulators[i], 0.f)
|
95 |
+
: accumulators[i]);
|
96 |
+
}
|
97 |
+
}
|
98 |
+
|
99 |
+
// Performs the calculation y = A * x + b where A is a sparse matrix with a 4x4
|
100 |
+
// blocked pattern, x is a fat vector with 5 columns and b is a vector. b is
|
101 |
+
// broadcast. Weights are stored for this routine by making each 4x4 block
|
102 |
+
// contiguous. Blocks are ordered in standard row-major format. column indices
|
103 |
+
// are converted to deltas and then multiplied by 2 to convert to bytes, so
|
104 |
+
// that the value can be used directly to offset the pointer into the rhs
|
105 |
+
// vector.
|
106 |
+
//
|
107 |
+
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
|
108 |
+
// this function. This is automatically taken care of in SparseLinearLayer.
|
109 |
+
// The bias is reconstructed through horizontal additions, which leads to a small
|
110 |
+
// speedup by reducing latencies at the end of the loop.
|
111 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
112 |
+
typename std::enable_if<
|
113 |
+
ShouldEnableGenericSpMM5_4x4<WeightType, RhsType, OutType>::value>::type
|
114 |
+
SpMM5_4x4(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
|
115 |
+
const int32_t* nnz_per_row, const RhsType* rhs_ptr,
|
116 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
|
117 |
+
OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
|
118 |
+
int relu) {
|
119 |
+
const RhsType* rhs_ptrs[5];
|
120 |
+
for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
|
121 |
+
|
122 |
+
OutType* out_ptrs[5];
|
123 |
+
for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
|
124 |
+
|
125 |
+
for (int reduced_row = 0; reduced_row < assigned_rows; ++reduced_row) {
|
126 |
+
float accumulators[4][5];
|
127 |
+
// Undo the division by 4 that happens for the assembly version.
|
128 |
+
for (int i = 0; i < 4; ++i) {
|
129 |
+
for (int k = 0; k < 5; ++k) {
|
130 |
+
accumulators[i][k] = 4.f * static_cast<float>(*bias_ptr);
|
131 |
+
}
|
132 |
+
++bias_ptr;
|
133 |
+
}
|
134 |
+
|
135 |
+
int reduced_col_count = *nnz_per_row++;
|
136 |
+
for (int c = 0; c < reduced_col_count; ++c) {
|
137 |
+
int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
|
138 |
+
for (int k = 0; k < 5; ++k) rhs_ptrs[k] += col_delta;
|
139 |
+
|
140 |
+
// multiply this 4x4 block
|
141 |
+
for (int i = 0; i < 4; ++i) {
|
142 |
+
for (int j = 0; j < 4; ++j) {
|
143 |
+
for (int k = 0; k < 5; ++k) {
|
144 |
+
accumulators[i][k] += static_cast<float>(*weights_ptr) *
|
145 |
+
static_cast<float>(rhs_ptrs[k][j]);
|
146 |
+
}
|
147 |
+
weights_ptr++;
|
148 |
+
}
|
149 |
+
}
|
150 |
+
}
|
151 |
+
|
152 |
+
for (int k = 0; k < 5; ++k) {
|
153 |
+
for (int i = 0; i < 4; ++i) {
|
154 |
+
out_ptrs[k][0] = static_cast<OutType>(
|
155 |
+
relu ? std::max(accumulators[i][k], 0.f) : accumulators[i][k]);
|
156 |
+
out_ptrs[k]++;
|
157 |
+
}
|
158 |
+
}
|
159 |
+
}
|
160 |
+
}
|
161 |
+
|
162 |
+
// Performs the calculation y = A * x + b where A is a sparse matrix with
|
163 |
+
// a 1x1 blocked pattern (ie unstructured), x is a
|
164 |
+
// vector and b is a vector.
|
165 |
+
// Weights are stored for this routine in standard CSR format. Each row must
|
166 |
+
// have a multiple of 8 columns.
|
167 |
+
// column indices are converted to deltas and then multiplied by 2 to convert
|
168 |
+
// to bytes, so that the value can be used directly to offset the pointer
|
169 |
+
// into the rhs vector.
|
170 |
+
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
|
171 |
+
// this function. This is automatically taken care of in SparseLinearLayer.
|
172 |
+
// The bias is reconstructed through horizontal additions, which leads to a small
|
173 |
+
// speedup by reducing latencies at the end of the loop.
|
174 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
175 |
+
typename std::enable_if<
|
176 |
+
ShouldEnableGenericSpMV_1x1<WeightType, RhsType, OutType>::value>::type
|
177 |
+
SpMV_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
|
178 |
+
const int32_t* nnz_per_row, const RhsType* rhs_ptr,
|
179 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
|
180 |
+
OutType* out_ptr, int64_t assigned_rows,
|
181 |
+
int64_t rows /* only used in SpMM variants */,
|
182 |
+
int64_t cols /* only used in SpMM variants */, int relu) {
|
183 |
+
for (int row = 0; row < assigned_rows; ++row) {
|
184 |
+
// Undo the division by 4 that happens for the assembly version.
|
185 |
+
float accumulator = 4.f * static_cast<float>(*bias_ptr++);
|
186 |
+
|
187 |
+
int col_count = *nnz_per_row++;
|
188 |
+
for (int c = 0; c < col_count; ++c) {
|
189 |
+
int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
|
190 |
+
rhs_ptr += col_delta;
|
191 |
+
|
192 |
+
accumulator +=
|
193 |
+
static_cast<float>(*weights_ptr++) * static_cast<float>(*rhs_ptr);
|
194 |
+
}
|
195 |
+
|
196 |
+
*out_ptr++ =
|
197 |
+
static_cast<OutType>(relu ? std::max(accumulator, 0.f) : accumulator);
|
198 |
+
}
|
199 |
+
}
|
200 |
+
|
201 |
+
// Performs the calculation y = A * x + b where A is a sparse matrix with
|
202 |
+
// a 1x1 blocked pattern (ie unstructured), x is a
|
203 |
+
// vector and b is a vector.
|
204 |
+
// Weights are stored for this routine in standard CSR format. Each row must
|
205 |
+
// have a multiple of 8 columns.
|
206 |
+
// column indices are converted to deltas and then multiplied by 2 to convert
|
207 |
+
// to bytes, so that the value can be used directly to offset the pointer
|
208 |
+
// into the rhs vector.
|
209 |
+
// NOTE: The bias is expected to have been multiplied by .25f prior to calling
|
210 |
+
// this function. This is automatically taken care of in SparseLinearLayer.
|
211 |
+
// The bias is reconstructed through horizontal additions, which leads to a small
|
212 |
+
// speedup by reducing latencies at the end of the loop.
|
213 |
+
template <typename WeightType, typename RhsType, typename OutType>
|
214 |
+
typename std::enable_if<
|
215 |
+
ShouldEnableGenericSpMM5_1x1<WeightType, RhsType, OutType>::value>::type
|
216 |
+
SpMM5_1x1(const WeightType* weights_ptr, const int16_t* col_deltas_bytes,
|
217 |
+
const int32_t* nnz_per_row, const RhsType* rhs_ptr,
|
218 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias_ptr,
|
219 |
+
OutType* out_ptr, int64_t assigned_rows, int64_t rows, int64_t cols,
|
220 |
+
int relu) {
|
221 |
+
const RhsType* rhs_ptrs[5];
|
222 |
+
for (int i = 0; i < 5; ++i) rhs_ptrs[i] = rhs_ptr + i * cols;
|
223 |
+
|
224 |
+
OutType* out_ptrs[5];
|
225 |
+
for (int i = 0; i < 5; ++i) out_ptrs[i] = out_ptr + i * rows;
|
226 |
+
|
227 |
+
for (int row = 0; row < assigned_rows; ++row) {
|
228 |
+
// Undo the division by 4 that happens for the assembly version.
|
229 |
+
float accumulator[5];
|
230 |
+
for (int i = 0; i < 5; ++i)
|
231 |
+
accumulator[i] = 4.f * static_cast<float>(*bias_ptr);
|
232 |
+
|
233 |
+
++bias_ptr;
|
234 |
+
|
235 |
+
int col_count = *nnz_per_row++;
|
236 |
+
for (int c = 0; c < col_count; ++c) {
|
237 |
+
int col_delta = *col_deltas_bytes++ / sizeof(RhsType);
|
238 |
+
for (int i = 0; i < 5; ++i) {
|
239 |
+
rhs_ptrs[i] += col_delta;
|
240 |
+
accumulator[i] += static_cast<float>(*weights_ptr) *
|
241 |
+
static_cast<float>(rhs_ptrs[i][0]);
|
242 |
+
}
|
243 |
+
weights_ptr++;
|
244 |
+
}
|
245 |
+
|
246 |
+
for (int i = 0; i < 5; ++i) {
|
247 |
+
out_ptrs[i][0] = static_cast<OutType>(relu ? std::max(accumulator[i], 0.f)
|
248 |
+
: accumulator[i]);
|
249 |
+
out_ptrs[i]++;
|
250 |
+
}
|
251 |
+
}
|
252 |
+
}
|
253 |
+
|
254 |
+
template <typename Type>
|
255 |
+
typename std::enable_if<ShouldEnableGenericAdd<Type>::value>::type SumVectors(
|
256 |
+
int start, int end, const Type* add1, const Type* add2, Type* result) {
|
257 |
+
LOG_FIRST_N(WARNING, 1) << "SumVectors: using generic kernel!";
|
258 |
+
for (int i = start; i < end; ++i) {
|
259 |
+
Type sum = static_cast<Type>(static_cast<float>(add1[i]) +
|
260 |
+
static_cast<float>(add2[i]));
|
261 |
+
result[i] = sum;
|
262 |
+
}
|
263 |
+
}
|
264 |
+
|
265 |
+
} // namespace detail
|
266 |
+
} // namespace csrblocksparse
|
267 |
+
|
268 |
+
#undef LABEL_COL_LOOP
|
269 |
+
#undef LABEL_ROW_LOOP
|
270 |
+
#undef LABEL_SKIP_COL_LOOP
|
271 |
+
#undef LABEL_TOP_LOOP
|
272 |
+
|
273 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_KERNELS_GENERIC_H_
|
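As a rough illustration of the 4x4 block layout described above (contiguous 16-weight blocks, per-block column deltas stored in bytes, bias pre-scaled by .25f), here is a minimal float-only sketch of one block row. The function name is hypothetical and it does not call the library's SpMV_4x4:

#include <cstdint>

// Sketch: accumulate one 4x4-blocked block row of y = A * x + b for floats.
// |weights| holds 16 contiguous floats per block; |col_deltas_bytes| holds
// per-block byte offsets into the rhs vector; |bias| was stored * 0.25f.
void BlockRowTimesVector(const float* weights,
                         const int16_t* col_deltas_bytes, int num_blocks,
                         const float* rhs, const float* bias, float* out) {
  float acc[4];
  // Undo the pre-division of the bias, as the generic kernels above do.
  for (int i = 0; i < 4; ++i) acc[i] = 4.f * bias[i];
  for (int b = 0; b < num_blocks; ++b) {
    // Deltas are stored in bytes; convert back to element units.
    rhs += col_deltas_bytes[b] / static_cast<int>(sizeof(float));
    for (int i = 0; i < 4; ++i)
      for (int j = 0; j < 4; ++j)
        acc[i] += weights[16 * b + 4 * i + j] * rhs[j];
  }
  for (int i = 0; i < 4; ++i) out[i] = acc[i];
}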
sparse_matmul/compute/matmul.h
ADDED
@@ -0,0 +1,199 @@
|
1 |
+
/*
|
2 |
+
* Copyright 2021 Google LLC
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
|
18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
|
19 |
+
|
20 |
+
#include <cstdint>
|
21 |
+
#include <vector>
|
22 |
+
|
23 |
+
#include "absl/time/time.h"
|
24 |
+
#include "sparse_matmul/compute/matmul_fixed_avx2.h"
|
25 |
+
#include "sparse_matmul/compute/matmul_generic.h"
|
26 |
+
#include "sparse_matmul/numerics/fixed_types.h"
|
27 |
+
#include "sparse_matmul/numerics/type_utils.h"
|
28 |
+
#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32)
|
29 |
+
#include <cpuid.h>
|
30 |
+
#endif
|
31 |
+
|
32 |
+
namespace csrblocksparse {
|
33 |
+
|
34 |
+
// The number of elements in a block.
|
35 |
+
constexpr int kBlockSize = 4;
|
36 |
+
|
37 |
+
// Base class for Matmul containing the members that are non type-specific.
|
38 |
+
class MatmulBase {
|
39 |
+
public:
|
40 |
+
// Constructor initializes the flags that determine which implementation to
|
41 |
+
// use at run-time, constrained by both compiler flags and cpuid.
|
42 |
+
MatmulBase() {
|
43 |
+
#if defined(__x86_64__) || defined(__i386__) || defined(_WIN32)
|
44 |
+
// Code tested to work on Linux systems and multiple Android emulators.
|
45 |
+
unsigned int eax, ebx, ecx, edx;
|
46 |
+
if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) != 0) {
|
47 |
+
using_avx_ = (ecx & bit_AVX) != 0;
|
48 |
+
if (using_avx_) {
|
49 |
+
__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx);
|
50 |
+
using_avx2_ = (ebx & bit_AVX2) != 0;
|
51 |
+
using_avx512_ = (ebx & bit_AVX512F) != 0 && (ebx & bit_AVX512DQ) &&
|
52 |
+
(ebx & bit_AVX512BW) != 0;
|
53 |
+
VLOG(2) << "avx2 flag=" << using_avx2_ << " 512=" << using_avx512_;
|
54 |
+
} else {
|
55 |
+
LOG(ERROR) << "AVX not found at all!";
|
56 |
+
}
|
57 |
+
}
|
58 |
+
#else
|
59 |
+
using_aarch64_ = true;
|
60 |
+
#endif
|
61 |
+
}
|
62 |
+
|
63 |
+
protected:
|
64 |
+
// Flags that define what (runtime) architectures are available. Flags that
|
65 |
+
// are set are limited by both the compiler flags and runtime environment.
|
66 |
+
bool using_avx512_ = false;
|
67 |
+
bool using_avx2_ = false;
|
68 |
+
bool using_avx_ = false;
|
69 |
+
bool using_aarch64_ = false;
|
70 |
+
};
|
71 |
+
|
72 |
+
// The master template is really a catch-all for the unimplemented cases to
|
73 |
+
// report an error.
|
74 |
+
template <typename WeightType, typename RhsType>
|
75 |
+
class Matmul : public MatmulBase {
|
76 |
+
public:
|
77 |
+
// Sparse inputs, outputs replicated strided for each thread.
|
78 |
+
template <typename OutType>
|
79 |
+
void MatVec4x4(const WeightType* weights, const RhsType* rhs,
|
80 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias,
|
81 |
+
const int32_t* nnz_per_row, const int16_t* rhs_indices,
|
82 |
+
int start_row, int end_row, bool relu, int replicas,
|
83 |
+
int stride, OutType* output) {
|
84 |
+
// The specializations should take care of every real case.
|
85 |
+
CHECK(false) << "Unsupported combination of types used!";
|
86 |
+
}
|
87 |
+
template <typename OutType>
|
88 |
+
void MatVec8x4(const WeightType* weights, const RhsType* rhs,
|
89 |
+
const typename TypeOfProduct<WeightType, RhsType>::type* bias,
|
90 |
+
const int32_t* nnz_per_row, const int16_t* rhs_indices,
|
91 |
+
int start_row, int end_row, bool relu, int replicas,
|
92 |
+
int stride, OutType* output) {
|
93 |
+
// The specializations should take care of every real case.
|
94 |
+
CHECK(false) << "Unsupported combination of types used!";
|
95 |
+
}
|
96 |
+
};
|
97 |
+
|
98 |
+
// Full specialization for float.
|
99 |
+
template <>
|
100 |
+
class Matmul<float, float> : public MatmulBase {
|
101 |
+
public:
|
102 |
+
void MatVec4x4(const float* weights, const float* rhs, const float* bias,
|
103 |
+
const int32_t* nnz_per_row, const int16_t* rhs_indices,
|
104 |
+
int start_row, int end_row, bool relu, int replicas,
|
105 |
+
int stride, float* output) {
|
106 |
+
detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
|
107 |
+
start_row, end_row, /*block_height=*/4,
|
108 |
+
/*block_width=*/4, relu, replicas, stride,
|
109 |
+
output);
|
110 |
+
}
|
111 |
+
void MatVec8x4(const float* weights, const float* rhs, const float* bias,
|
112 |
+
const int32_t* nnz_per_row, const int16_t* rhs_indices,
|
113 |
+
int start_row, int end_row, bool relu, int replicas,
|
114 |
+
int stride, float* output) {
|
115 |
+
detail::MatVecFloatGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
|
116 |
+
start_row, end_row, /*block_height=*/8,
|
117 |
+
/*block_width=*/4, relu, replicas, stride,
|
118 |
+
output);
|
119 |
+
}
|
120 |
+
};
|
121 |
+
|
122 |
+
// Partial specialization for fixed types. Covers fixed16xfixed16 = OutType,
|
123 |
+
// where OutType should be fixed16 or fixed32. The mantissa bits don't have
|
124 |
+
// to match.
|
125 |
+
template <int WeightBits, int RhsBits>
|
126 |
+
class Matmul<fixed16<WeightBits>, fixed16<RhsBits>> : public MatmulBase {
|
127 |
+
public:
|
128 |
+
using WeightType = fixed16<WeightBits>;
|
129 |
+
using RhsType = fixed16<RhsBits>;
|
130 |
+
|
131 |
+
template <typename OutType>
|
132 |
+
void MatVec4x4(const int16_t* weights, const int16_t* rhs,
|
133 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
134 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
135 |
+
bool relu, int replicas, int stride, OutType* output) {
|
136 |
+
constexpr int kShiftAmount =
|
137 |
+
TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
|
138 |
+
OutType::kMantissaBits;
|
139 |
+
static_assert(kShiftAmount >= 0,
|
140 |
+
"OutType must not have more mantissa bits than inputs");
|
141 |
+
#if defined __AVX2__
|
142 |
+
CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
|
143 |
+
if (sizeof(*output) == 4) {
|
144 |
+
int32_t* out32 = reinterpret_cast<int32_t*>(output);
|
145 |
+
detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
|
146 |
+
start_row, end_row, relu, kShiftAmount,
|
147 |
+
replicas, stride, out32);
|
148 |
+
} else {
|
149 |
+
int16_t* out16 = reinterpret_cast<int16_t*>(output);
|
150 |
+
detail::MatVec4x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
|
151 |
+
start_row, end_row, relu, kShiftAmount,
|
152 |
+
replicas, stride, out16);
|
153 |
+
}
|
154 |
+
#elif defined __aarch64__
|
155 |
+
if (using_aarch64_) {
|
156 |
+
LOG(FATAL) << "Fixed16 MatVec4x4 not yet implemented!";
|
157 |
+
}
|
158 |
+
|
159 |
+
#else
|
160 |
+
detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
|
161 |
+
start_row, end_row, /*block_height=*/4,
|
162 |
+
/*block_width=*/4, relu, sizeof(*output),
|
163 |
+
kShiftAmount, replicas, stride, output);
|
164 |
+
#endif // __AVX2__
|
165 |
+
}
|
166 |
+
|
167 |
+
template <typename OutType>
|
168 |
+
void MatVec8x4(const int16_t* weights, const int16_t* rhs,
|
169 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
170 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
171 |
+
bool relu, int replicas, int stride, OutType* output) {
|
172 |
+
constexpr int kShiftAmount =
|
173 |
+
TypeOfProduct<WeightType, RhsType>::type::kMantissaBits -
|
174 |
+
OutType::kMantissaBits;
|
175 |
+
static_assert(kShiftAmount >= 0,
|
176 |
+
"OutType must not have more mantissa bits than inputs");
|
177 |
+
#if defined __AVX2__
|
178 |
+
CHECK(replicas == 1 && sizeof(*output) == 4)
|
179 |
+
<< "Only replicas == 1 and fixed32 output are implemented for AVX2!";
|
180 |
+
CHECK(using_avx2_) << "Compiled for AVX2, but cpu flag not set!";
|
181 |
+
int32_t* out32 = reinterpret_cast<int32_t*>(output);
|
182 |
+
detail::MatVec8x4FixedAVX2(weights, rhs, bias, nnz_per_row, rhs_indices,
|
183 |
+
start_row, end_row, relu, kShiftAmount, out32);
|
184 |
+
#elif defined __aarch64__
|
185 |
+
if (using_aarch64_) {
|
186 |
+
LOG(FATAL) << "Fixed16 MatVec8x4 not yet implemented!";
|
187 |
+
}
|
188 |
+
#else
|
189 |
+
detail::MatVecFixedGeneric(weights, rhs, bias, nnz_per_row, rhs_indices,
|
190 |
+
start_row, end_row, /*block_height=*/8,
|
191 |
+
/*block_width=*/4, relu, sizeof(*output),
|
192 |
+
kShiftAmount, replicas, stride, output);
|
193 |
+
#endif // __AVX2__
|
194 |
+
}
|
195 |
+
};
|
196 |
+
|
197 |
+
} // namespace csrblocksparse
|
198 |
+
|
199 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_H_
|
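A standalone sketch of the runtime feature detection done in the MatmulBase constructor above, assuming GCC/Clang's <cpuid.h> on x86; the real class additionally checks the AVX-512F/DQ/BW bits:

#include <cpuid.h>
#include <cstdio>

// Query leaf 1 for AVX and leaf 7/subleaf 0 for AVX2, as the constructor does.
int main() {
  unsigned int eax, ebx, ecx, edx;
  bool has_avx = false, has_avx2 = false;
  if (__get_cpuid(1, &eax, &ebx, &ecx, &edx) != 0) {
    has_avx = (ecx & bit_AVX) != 0;
    if (has_avx && __get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx) != 0) {
      has_avx2 = (ebx & bit_AVX2) != 0;
    }
  }
  std::printf("avx=%d avx2=%d\n", has_avx, has_avx2);
  return 0;
}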
sparse_matmul/compute/matmul_fixed_avx2.cc
ADDED
@@ -0,0 +1,235 @@
|
1 |
+
// Copyright 2021 Google LLC
|
2 |
+
//
|
3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
// you may not use this file except in compliance with the License.
|
5 |
+
// You may obtain a copy of the License at
|
6 |
+
//
|
7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
//
|
9 |
+
// Unless required by applicable law or agreed to in writing, software
|
10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
// See the License for the specific language governing permissions and
|
13 |
+
// limitations under the License.
|
14 |
+
|
15 |
+
#include "sparse_matmul/compute/matmul_fixed_avx2.h"
|
16 |
+
|
17 |
+
#include <cstdint>
|
18 |
+
|
19 |
+
#if defined __AVX__
|
20 |
+
#include <immintrin.h>
|
21 |
+
#endif
|
22 |
+
|
23 |
+
#include "sparse_matmul/compute/matmul.h"
|
24 |
+
|
25 |
+
namespace csrblocksparse {
|
26 |
+
namespace detail {
|
27 |
+
|
28 |
+
static const int32_t kint32min = static_cast<int32_t>(~0x7FFFFFFF);
|
29 |
+
static const int32_t kint32max = static_cast<int32_t>(0x7FFFFFFF);
|
30 |
+
|
31 |
+
#if defined __AVX2__
|
32 |
+
// In-line function computes and returns the result of one row (of blocks) as
|
33 |
+
// 4x int32_t. |weights_ptr| is a non-const reference so it can easily be
|
34 |
+
// interpreted as belonging to the caller.
|
35 |
+
inline __m256i ComputeRowResults(const __m128i& bias128, const int16_t* rhs,
|
36 |
+
const int16_t* rhs_indices, int nnz,
|
37 |
+
int16_t const*& weights_ptr) {
|
38 |
+
// Expand bias to 64 bits in a 256 bit register [0 z 1 z 2 z 3 z], where z is
|
39 |
+
// Zero and 0-3 are the 4x32 bit bias values.
|
40 |
+
__m256i sum = _mm256_cvtepu32_epi64(bias128);
|
41 |
+
|
42 |
+
for (int c = 0; c < nnz; ++c) {
|
43 |
+
int rhs_index = rhs_indices[c];
|
44 |
+
// Load all 16 weights.
|
45 |
+
__m256i weights =
|
46 |
+
_mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
|
47 |
+
// Get the 4x int16_t into the bottom of |rhs_64|.
|
48 |
+
__m128i rhs_64 = _mm_loadl_epi64(
|
49 |
+
reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize));
|
50 |
+
// Broadcast the rhs, pretending that each is a 64-bit unit:
|
51 |
+
// [0123 0123 0123 0123].
|
52 |
+
__m256i rhs_value = _mm256_broadcastq_epi64(rhs_64);
|
53 |
+
weights_ptr += 16;
|
54 |
+
sum = _mm256_add_epi32(sum, _mm256_madd_epi16(weights, rhs_value));
|
55 |
+
}
|
56 |
+
// Horizontally add the results. We have 1 register that contains results
|
57 |
+
// [0 0 1 1 2 2 3 3], but hadd (and almost no other AVX instruction) will not
|
58 |
+
// cross lanes, so we end up with [0 1 0 1 2 3 2 3]
|
59 |
+
sum = _mm256_hadd_epi32(sum, sum);
|
60 |
+
// Permutes the middle two pairs to get the answers together.
|
61 |
+
return _mm256_permute4x64_epi64(sum, 0xd8);
|
62 |
+
}
|
63 |
+
|
64 |
+
// Template that allows any fixed combination of OutType and replicas, plus
|
65 |
+
// variable |relu|, |shift_out|. Note that |kReplicas| is a template arg as
|
66 |
+
// well as a function arg so we can hard-code a limited amount of unrolling.
|
67 |
+
template <typename OutType, int kReplicas>
|
68 |
+
void MatVec4x4FixedAVX2Template(const int16_t* weights_ptr, const int16_t* rhs,
|
69 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
70 |
+
const int16_t* rhs_indices, int start_row,
|
71 |
+
int end_row, bool relu, int shift_out,
|
72 |
+
int replicas, int stride, OutType* output) {
|
73 |
+
int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0;
|
74 |
+
__m256i rounding = _mm256_set1_epi32(rounding_addon);
|
75 |
+
__m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min);
|
76 |
+
for (int row_block = start_row; row_block < end_row; ++row_block) {
|
77 |
+
// Load 4 biases [0 1 2 3].
|
78 |
+
__m128i bias128 = _mm_load_si128(reinterpret_cast<__m128i const*>(bias));
|
79 |
+
bias += kBlockSize;
|
80 |
+
int nnz = nnz_per_row[row_block];
|
81 |
+
__m256i sum =
|
82 |
+
ComputeRowResults(bias128, rhs, rhs_indices, nnz, weights_ptr);
|
83 |
+
rhs_indices += nnz;
|
84 |
+
// Shift right with rounding to get the right number of mantissa bits.
|
85 |
+
sum = _mm256_add_epi32(sum, rounding);
|
86 |
+
sum = _mm256_srai_epi32(sum, shift_out);
|
87 |
+
// Now sum contains [res0, res1, res2, res3, res0, res1, res2, res3]
|
88 |
+
sum = _mm256_max_epi32(sum, zero);
|
89 |
+
if (sizeof(OutType) == 2) {
|
90 |
+
// Clip to 16 bit range (with saturation) and pack in the bottom 64
|
91 |
+
// bits. The 64 bit result is replicated across the whole 256 bit
|
92 |
+
// register. [0123 0123 0123 0123]
|
93 |
+
sum = _mm256_packs_epi32(sum, sum);
|
94 |
+
int64_t result = _mm256_extract_epi64(sum, 0);
|
95 |
+
*reinterpret_cast<int64_t*>(output) = result;
|
96 |
+
if (kReplicas > 1) {
|
97 |
+
*reinterpret_cast<int64_t*>(output + stride) = result;
|
98 |
+
if (kReplicas > 2) {
|
99 |
+
for (int r = 2; r < replicas; ++r) {
|
100 |
+
*reinterpret_cast<int64_t*>(output + r * stride) = result;
|
101 |
+
}
|
102 |
+
}
|
103 |
+
}
|
104 |
+
} else {
|
105 |
+
// Save the lower 128 bits (4x int32_t).
|
106 |
+
__m128i result = _mm256_extractf128_si256(sum, 0);
|
107 |
+
_mm_store_si128(reinterpret_cast<__m128i*>(output), result);
|
108 |
+
if (kReplicas > 1) {
|
109 |
+
_mm_store_si128(reinterpret_cast<__m128i*>(output + stride), result);
|
110 |
+
if (kReplicas > 2) {
|
111 |
+
for (int r = 2; r < replicas; ++r) {
|
112 |
+
_mm_store_si128(reinterpret_cast<__m128i*>(output + r * stride),
|
113 |
+
result);
|
114 |
+
}
|
115 |
+
}
|
116 |
+
}
|
117 |
+
}
|
118 |
+
output += kBlockSize;
|
119 |
+
}
|
120 |
+
}
|
121 |
+
|
122 |
+
// Version that covers all possible combinations of the variable conditions:
|
123 |
+
// |relu|, |shift_out|, |replicas|, with int16_t |output|.
|
124 |
+
void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
|
125 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
126 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
127 |
+
bool relu, int shift_out, int replicas, int stride,
|
128 |
+
int16_t* output) {
|
129 |
+
if (replicas <= 1) {
|
130 |
+
MatVec4x4FixedAVX2Template<int16_t, 1>(weights_ptr, rhs, bias, nnz_per_row,
|
131 |
+
rhs_indices, start_row, end_row,
|
132 |
+
relu, shift_out, 1, stride, output);
|
133 |
+
} else if (replicas == 2) {
|
134 |
+
MatVec4x4FixedAVX2Template<int16_t, 2>(weights_ptr, rhs, bias, nnz_per_row,
|
135 |
+
rhs_indices, start_row, end_row,
|
136 |
+
relu, shift_out, 2, stride, output);
|
137 |
+
} else {
|
138 |
+
MatVec4x4FixedAVX2Template<int16_t, 3>(
|
139 |
+
weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row,
|
140 |
+
relu, shift_out, replicas, stride, output);
|
141 |
+
}
|
142 |
+
}
|
143 |
+
|
144 |
+
// Version that covers all possible combinations of the variable conditions:
|
145 |
+
// |relu|, |shift_out|, |replicas|, with int32_t |output|.
|
146 |
+
void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
|
147 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
148 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
149 |
+
bool relu, int shift_out, int replicas, int stride,
|
150 |
+
int32_t* output) {
|
151 |
+
if (replicas <= 1) {
|
152 |
+
MatVec4x4FixedAVX2Template<int32_t, 1>(weights_ptr, rhs, bias, nnz_per_row,
|
153 |
+
rhs_indices, start_row, end_row,
|
154 |
+
relu, shift_out, 1, stride, output);
|
155 |
+
} else if (replicas == 2) {
|
156 |
+
MatVec4x4FixedAVX2Template<int32_t, 2>(weights_ptr, rhs, bias, nnz_per_row,
|
157 |
+
rhs_indices, start_row, end_row,
|
158 |
+
relu, shift_out, 2, stride, output);
|
159 |
+
} else {
|
160 |
+
MatVec4x4FixedAVX2Template<int32_t, 3>(
|
161 |
+
weights_ptr, rhs, bias, nnz_per_row, rhs_indices, start_row, end_row,
|
162 |
+
relu, shift_out, replicas, stride, output);
|
163 |
+
}
|
164 |
+
}
|
165 |
+
|
166 |
+
// In-line function computes and returns the result of one row (of blocks) as
|
167 |
+
// 8x int32_t. weights_ptr is a non-const reference so it can easily be
|
168 |
+
// interpreted as belonging to the caller.
|
169 |
+
inline __m256i Compute8RowResults(const __m256i& bias256, const int16_t* rhs,
|
170 |
+
const int16_t* rhs_indices, int nnz,
|
171 |
+
int16_t const*& weights_ptr) {
|
172 |
+
// Expand bias to 64 bits in a 256 bit register [0 z 1 z 2 z 3 z], where z is
|
173 |
+
// Zero and 0-3 are the 4x32 bit bias values from 128 bit half of the input.
|
174 |
+
__m256i sum1 = _mm256_cvtepu32_epi64(_mm256_castsi256_si128(bias256));
|
175 |
+
// Plus 4 more in another sum register from the upper 128 bit half.
|
176 |
+
__m256i sum2 = _mm256_cvtepu32_epi64(_mm256_extractf128_si256(bias256, 1));
|
177 |
+
|
178 |
+
for (int c = 0; c < nnz; ++c) {
|
179 |
+
int rhs_index = rhs_indices[c];
|
180 |
+
// Load all 16 weights.
|
181 |
+
__m256i weights =
|
182 |
+
_mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
|
183 |
+
// Get the 4x int16_t into the bottom of |rhs_64|.
|
184 |
+
__m128i rhs_64 = _mm_loadl_epi64(
|
185 |
+
reinterpret_cast<__m128i const*>(rhs + rhs_index * kBlockSize));
|
186 |
+
// Broadcast the rhs, pretending that each is a 64-bit unit:
|
187 |
+
// [0123 0123 0123 0123].
|
188 |
+
__m256i rhs_value = _mm256_broadcastq_epi64(rhs_64);
|
189 |
+
weights_ptr += 16;
|
190 |
+
sum1 = _mm256_add_epi32(sum1, _mm256_madd_epi16(weights, rhs_value));
|
191 |
+
// Same again for the other 4 results, re-using the same rhs value.
|
192 |
+
weights = _mm256_load_si256(reinterpret_cast<__m256i const*>(weights_ptr));
|
193 |
+
weights_ptr += 16;
|
194 |
+
sum2 = _mm256_add_epi32(sum2, _mm256_madd_epi16(weights, rhs_value));
|
195 |
+
}
|
196 |
+
// Horizontally add the results. We have 2 registers that contain results
|
197 |
+
// [0 0 1 1 2 2 3 3], and [4 4 5 5 6 6 7 7] but hadd (and almost no other AVX
|
198 |
+
// instruction) will not cross lanes, so we end up with [0 1 4 5 2 3 6 7]
|
199 |
+
sum1 = _mm256_hadd_epi32(sum1, sum2);
|
200 |
+
// Permutes the middle two pairs to get the answers in the right order.
|
201 |
+
return _mm256_permute4x64_epi64(sum1, 0xd8);
|
202 |
+
}
|
203 |
+
|
204 |
+
// Version that covers the main conditions used with 8x4:
|
205 |
+
// |relu|, |shift_out|, with int32_t |output|.
|
206 |
+
void MatVec8x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
|
207 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
208 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
209 |
+
bool relu, int shift_out, int32_t* output) {
|
210 |
+
int rounding_addon = shift_out > 0 ? (1 << (shift_out - 1)) : 0;
|
211 |
+
__m256i rounding = _mm256_set1_epi32(rounding_addon);
|
212 |
+
__m256i zero = relu ? _mm256_setzero_si256() : _mm256_set1_epi32(kint32min);
|
213 |
+
for (int row_block = start_row; row_block < end_row; ++row_block) {
|
214 |
+
// Load 4 biases [0 1 2 3 4 5 6 7].
|
215 |
+
__m256i bias256 = _mm256_load_si256(reinterpret_cast<__m256i const*>(bias));
|
216 |
+
bias += kBlockSize * 2;
|
217 |
+
int nnz = nnz_per_row[row_block];
|
218 |
+
__m256i sum =
|
219 |
+
Compute8RowResults(bias256, rhs, rhs_indices, nnz, weights_ptr);
|
220 |
+
rhs_indices += nnz;
|
221 |
+
// Shift right with rounding to get the right number of mantissa bits.
|
222 |
+
sum = _mm256_add_epi32(sum, rounding);
|
223 |
+
sum = _mm256_srai_epi32(sum, shift_out);
|
224 |
+
// Now sum contains [res0, res1, ..., res7].
|
225 |
+
sum = _mm256_max_epi32(sum, zero);
|
226 |
+
// Save all 256 bits (8x int32_t).
|
227 |
+
_mm256_store_si256(reinterpret_cast<__m256i*>(output), sum);
|
228 |
+
output += kBlockSize * 2;
|
229 |
+
}
|
230 |
+
}
|
231 |
+
|
232 |
+
#endif
|
233 |
+
|
234 |
+
} // namespace detail
|
235 |
+
} // namespace csrblocksparse
|
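The hadd-then-permute idiom above exists because _mm256_hadd_epi32 only adds within 128-bit lanes, so pairwise sums come out interleaved. A small demo (assumes an AVX2 build) showing how _mm256_permute4x64_epi64 with 0xd8 puts them back in order:

#include <cstdint>
#include <cstdio>
#include <immintrin.h>

int main() {
  __m256i x = _mm256_setr_epi32(0, 1, 2, 3, 4, 5, 6, 7);
  // In-lane horizontal add: [1 5 1 5 | 9 13 9 13].
  __m256i h = _mm256_hadd_epi32(x, x);
  // Reorder 64-bit lanes (0, 2, 1, 3): [1 5 9 13 1 5 9 13].
  __m256i fixed = _mm256_permute4x64_epi64(h, 0xd8);
  int32_t out[8];
  _mm256_storeu_si256(reinterpret_cast<__m256i*>(out), fixed);
  std::printf("%d %d %d %d\n", out[0], out[1], out[2], out[3]);  // 1 5 9 13
  return 0;
}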
sparse_matmul/compute/matmul_fixed_avx2.h
ADDED
@@ -0,0 +1,49 @@
|
1 |
+
/*
|
2 |
+
* Copyright 2021 Google LLC
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_
|
18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_
|
19 |
+
|
20 |
+
#include <cstdint>
|
21 |
+
|
22 |
+
namespace csrblocksparse {
|
23 |
+
namespace detail {
|
24 |
+
|
25 |
+
// Version that covers all possible combinations of the variable conditions:
|
26 |
+
// |relu|, |shift_out|, |replicas|, with int16 output.
|
27 |
+
void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
|
28 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
29 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
30 |
+
bool relu, int shift_out, int replicas, int stride,
|
31 |
+
int16_t* output);
|
32 |
+
// Version that covers all possible combinations of the variable conditions:
|
33 |
+
// |relu|, |shift_out|, |replicas|, with int32 output.
|
34 |
+
void MatVec4x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
|
35 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
36 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
37 |
+
bool relu, int shift_out, int replicas, int stride,
|
38 |
+
int32_t* output);
|
39 |
+
// Version that covers the main conditions used with 8x4:
|
40 |
+
// |relu|, |shift_out|, with int32 output.
|
41 |
+
void MatVec8x4FixedAVX2(const int16_t* weights_ptr, const int16_t* rhs,
|
42 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
43 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
44 |
+
bool relu, int shift_out, int32_t* output);
|
45 |
+
|
46 |
+
} // namespace detail
|
47 |
+
} // namespace csrblocksparse
|
48 |
+
|
49 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_FIXED_AVX2_H_
|
sparse_matmul/compute/matmul_generic.cc
ADDED
@@ -0,0 +1,122 @@
|
1 |
+
// Copyright 2021 Google LLC
|
2 |
+
//
|
3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
// you may not use this file except in compliance with the License.
|
5 |
+
// You may obtain a copy of the License at
|
6 |
+
//
|
7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
//
|
9 |
+
// Unless required by applicable law or agreed to in writing, software
|
10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
// See the License for the specific language governing permissions and
|
13 |
+
// limitations under the License.
|
14 |
+
|
15 |
+
#include "sparse_matmul/compute/matmul_generic.h"
|
16 |
+
|
17 |
+
#include <cstdint>
|
18 |
+
#include <vector>
|
19 |
+
|
20 |
+
#include "sparse_matmul/compute/matmul.h"
|
21 |
+
|
22 |
+
namespace csrblocksparse {
|
23 |
+
namespace detail {
|
24 |
+
|
25 |
+
void MatVecFloatGeneric(const float* weights, const float* rhs,
|
26 |
+
const float* bias, const int32_t* nnz_per_row,
|
27 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
28 |
+
int block_height, int block_width, bool relu,
|
29 |
+
int replicas, int stride, float* output) {
|
30 |
+
int weight_index = 0;
|
31 |
+
int bias_index = 0;
|
32 |
+
std::vector<float> accumulators(block_height);
|
33 |
+
for (int row_block = start_row; row_block < end_row;
|
34 |
+
++row_block, output += block_height) {
|
35 |
+
int nnz = nnz_per_row[row_block];
|
36 |
+
// Biases are now stored and used directly without pre-division.
|
37 |
+
for (int i = 0; i < block_height; ++i) accumulators[i] = bias[bias_index++];
|
38 |
+
|
39 |
+
for (int c = 0; c < nnz; ++c) {
|
40 |
+
int rhs_index = rhs_indices[c];
|
41 |
+
const float* block_rhs = rhs + rhs_index * block_width;
|
42 |
+
// Multiply this |block_height| x |block_width| block.
|
43 |
+
for (int i = 0; i < block_height; ++i) {
|
44 |
+
for (int j = 0; j < block_width; ++j) {
|
45 |
+
accumulators[i] += weights[weight_index++] * block_rhs[j];
|
46 |
+
}
|
47 |
+
}
|
48 |
+
}
|
49 |
+
rhs_indices += nnz;
|
50 |
+
// Apply relu if desired.
|
51 |
+
if (relu) {
|
52 |
+
for (int i = 0; i < block_height; ++i) {
|
53 |
+
if (accumulators[i] < 0) accumulators[i] = 0;
|
54 |
+
}
|
55 |
+
}
|
56 |
+
for (int r = 0; r < replicas; ++r) {
|
57 |
+
for (int i = 0; i < block_height; ++i) {
|
58 |
+
output[i + r * stride] = accumulators[i];
|
59 |
+
}
|
60 |
+
}
|
61 |
+
}
|
62 |
+
}
|
63 |
+
|
64 |
+
void MatVecFixedGeneric(const int16_t* weights, const int16_t* rhs,
|
65 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
66 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
67 |
+
int block_height, int block_width, bool relu,
|
68 |
+
int bytes_out, int shift_out, int replicas, int stride,
|
69 |
+
void* output) {
|
70 |
+
int weight_index = 0;
|
71 |
+
int bias_index = 0;
|
72 |
+
std::vector<int32_t> accumulators(block_height);
|
73 |
+
for (int row_block = start_row; row_block < end_row; ++row_block) {
|
74 |
+
int nnz = nnz_per_row[row_block];
|
75 |
+
// Biases are now stored and used directly without pre-division.
|
76 |
+
for (int i = 0; i < block_height; ++i) accumulators[i] = bias[bias_index++];
|
77 |
+
|
78 |
+
for (int c = 0; c < nnz; ++c) {
|
79 |
+
int rhs_index = rhs_indices[c];
|
80 |
+
const int16_t* block_rhs = rhs + rhs_index * block_width;
|
81 |
+
// Multiply this |block_height| x |block_width| block.
|
82 |
+
for (int i = 0; i < block_height; ++i) {
|
83 |
+
for (int j = 0; j < block_width; ++j) {
|
84 |
+
accumulators[i] += weights[weight_index++] * block_rhs[j];
|
85 |
+
}
|
86 |
+
}
|
87 |
+
}
|
88 |
+
rhs_indices += nnz;
|
89 |
+
// Apply relu if desired.
|
90 |
+
if (relu) {
|
91 |
+
for (int i = 0; i < block_height; ++i) {
|
92 |
+
if (accumulators[i] < 0) accumulators[i] = 0;
|
93 |
+
}
|
94 |
+
}
|
95 |
+
// Output shift.
|
96 |
+
if (shift_out > 0) {
|
97 |
+
for (int i = 0; i < block_height; ++i) {
|
98 |
+
accumulators[i] >>= shift_out;
|
99 |
+
}
|
100 |
+
}
|
101 |
+
if (bytes_out == 2) {
|
102 |
+
int16_t* out16 = reinterpret_cast<int16_t*>(output);
|
103 |
+
output = out16 + block_height;
|
104 |
+
for (int r = 0; r < replicas; ++r, out16 += stride) {
|
105 |
+
for (int i = 0; i < block_height; ++i) {
|
106 |
+
out16[i] = accumulators[i];
|
107 |
+
}
|
108 |
+
}
|
109 |
+
} else {
|
110 |
+
int32_t* out32 = reinterpret_cast<int32_t*>(output);
|
111 |
+
output = out32 + block_height;
|
112 |
+
for (int r = 0; r < replicas; ++r, out32 += stride) {
|
113 |
+
for (int i = 0; i < block_height; ++i) {
|
114 |
+
out32[i] = accumulators[i];
|
115 |
+
}
|
116 |
+
}
|
117 |
+
}
|
118 |
+
}
|
119 |
+
}
|
120 |
+
|
121 |
+
} // namespace detail
|
122 |
+
} // namespace csrblocksparse
|
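Note that the generic fixed path above truncates when shifting the accumulator down by |shift_out|, whereas the AVX2 kernels add a rounding offset first. A tiny sketch of the two renormalizations with an illustrative value:

#include <cstdint>
#include <cstdio>

// Truncation (as in MatVecFixedGeneric) vs round-to-nearest (as in the AVX2
// kernels) when dropping |shift| mantissa bits from a fixed32 accumulator.
int32_t ShiftTruncate(int32_t acc, int shift) { return acc >> shift; }
int32_t ShiftRounded(int32_t acc, int shift) {
  int32_t offset = shift > 0 ? (1 << (shift - 1)) : 0;
  return (acc + offset) >> shift;
}

int main() {
  // 1536 / 1024 = 1.5: truncation gives 1, rounding gives 2.
  std::printf("%d %d\n", ShiftTruncate(1536, 10), ShiftRounded(1536, 10));
  return 0;
}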
sparse_matmul/compute/matmul_generic.h
ADDED
@@ -0,0 +1,41 @@
|
1 |
+
/*
|
2 |
+
* Copyright 2021 Google LLC
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_
|
18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_
|
19 |
+
|
20 |
+
#include <cstdint>
|
21 |
+
|
22 |
+
namespace csrblocksparse {
|
23 |
+
namespace detail {
|
24 |
+
|
25 |
+
// Generic version uses plain C++ code.
|
26 |
+
void MatVecFloatGeneric(const float* weights, const float* rhs,
|
27 |
+
const float* bias, const int32_t* nnz_per_row,
|
28 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
29 |
+
int block_height, int block_width, bool relu,
|
30 |
+
int replicas, int stride, float* output);
|
31 |
+
void MatVecFixedGeneric(const int16_t* weights, const int16_t* rhs,
|
32 |
+
const int32_t* bias, const int32_t* nnz_per_row,
|
33 |
+
const int16_t* rhs_indices, int start_row, int end_row,
|
34 |
+
int block_height, int block_width, bool relu,
|
35 |
+
int bytes_out, int shift_out, int replicas, int stride,
|
36 |
+
void* output);
|
37 |
+
|
38 |
+
} // namespace detail
|
39 |
+
} // namespace csrblocksparse
|
40 |
+
|
41 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_MATMUL_GENERIC_H_
|
sparse_matmul/compute/thread_bounds.cc
ADDED
@@ -0,0 +1,106 @@
|
1 |
+
// Copyright 2021 Google LLC
|
2 |
+
//
|
3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
// you may not use this file except in compliance with the License.
|
5 |
+
// You may obtain a copy of the License at
|
6 |
+
//
|
7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
//
|
9 |
+
// Unless required by applicable law or agreed to in writing, software
|
10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
// See the License for the specific language governing permissions and
|
13 |
+
// limitations under the License.
|
14 |
+
|
15 |
+
#include "sparse_matmul/compute/thread_bounds.h"
|
16 |
+
|
17 |
+
#include <vector>
|
18 |
+
|
19 |
+
#include "glog/logging.h"
|
20 |
+
|
21 |
+
namespace csrblocksparse {
|
22 |
+
|
23 |
+
void ThreadBounds::PrepareForThreads(int block_width, int block_height,
|
24 |
+
int num_threads,
|
25 |
+
int reduced_rows_per_cache_row,
|
26 |
+
int reduced_rows, const int* nnz_per_row) {
|
27 |
+
CHECK_GT(num_threads, 0);
|
28 |
+
block_width_ = block_width;
|
29 |
+
block_height_ = block_height;
|
30 |
+
ComputeThreadSplitPoints(num_threads, reduced_rows_per_cache_row,
|
31 |
+
reduced_rows, nnz_per_row);
|
32 |
+
weight_starts_.clear();
|
33 |
+
rhs_indices_starts_.clear();
|
34 |
+
bias_starts_.clear();
|
35 |
+
weight_starts_.reserve(row_starts_.size());
|
36 |
+
rhs_indices_starts_.reserve(row_starts_.size());
|
37 |
+
bias_starts_.reserve(row_starts_.size());
|
38 |
+
|
39 |
+
// Compute the start indices of each of the types, given what we know about
|
40 |
+
// padding, and number of |nnz_per_row|.
|
41 |
+
int weight_index = 0;
|
42 |
+
int rhs_indices_index = 0;
|
43 |
+
int bias_index = 0;
|
44 |
+
int row = 0;
|
45 |
+
for (int start : row_starts_) {
|
46 |
+
while (row < start) {
|
47 |
+
weight_index += nnz_per_row[row] * block_width_ * block_height_;
|
48 |
+
rhs_indices_index += nnz_per_row[row];
|
49 |
+
bias_index += block_height_;
|
50 |
+
++row;
|
51 |
+
}
|
52 |
+
weight_starts_.push_back(weight_index);
|
53 |
+
rhs_indices_starts_.push_back(rhs_indices_index);
|
54 |
+
bias_starts_.push_back(bias_index);
|
55 |
+
}
|
56 |
+
}
|
57 |
+
|
58 |
+
// Computes the block row (reduced) index of the start of each thread.
|
59 |
+
void ThreadBounds::ComputeThreadSplitPoints(int num_threads,
|
60 |
+
int reduced_rows_per_cache_row,
|
61 |
+
int reduced_rows,
|
62 |
+
const int* nnz_per_row) {
|
63 |
+
row_starts_.assign(/*n=*/1, /*val=*/0);
|
64 |
+
// Break the rule if the matrix is too small to allow one per thread, which
|
65 |
+
// occurs only during tests.
|
66 |
+
if (reduced_rows_per_cache_row * num_threads > reduced_rows)
|
67 |
+
reduced_rows_per_cache_row = std::max(reduced_rows / num_threads, 1);
|
68 |
+
int cache_rows = (reduced_rows + reduced_rows_per_cache_row - 1) /
|
69 |
+
reduced_rows_per_cache_row;
|
70 |
+
|
71 |
+
// Compute exclusive prefix sum of the amount of work per row.
|
72 |
+
std::vector<int> work_upto_row(cache_rows + 1, 0);
|
73 |
+
int extra_row_work = 2 * reduced_rows_per_cache_row;
|
74 |
+
for (int i = 0; i < cache_rows; ++i) {
|
75 |
+
int new_nnz = 0;
|
76 |
+
for (int j = 0; j < reduced_rows_per_cache_row; ++j) {
|
77 |
+
// if |reduced_rows_per_cache_row| isn't an exact multiple of the
|
78 |
+
// matrix size, then we need to be careful here.
|
79 |
+
int index = i * reduced_rows_per_cache_row + j;
|
80 |
+
if (index < reduced_rows) new_nnz += nnz_per_row[index];
|
81 |
+
}
|
82 |
+
work_upto_row[i + 1] = new_nnz + extra_row_work + work_upto_row[i];
|
83 |
+
}
|
84 |
+
int total_work = work_upto_row.back();
|
85 |
+
// Find the split points based on assigning an approximately equal amount
|
86 |
+
// of work for each thread.
|
87 |
+
int prev_split = 0;
|
88 |
+
for (int i = 1; i <= num_threads; ++i) {
|
89 |
+
int split = std::distance(
|
90 |
+
work_upto_row.begin(),
|
91 |
+
std::lower_bound(work_upto_row.begin(), work_upto_row.end(),
|
92 |
+
i * total_work / num_threads));
|
93 |
+
int split_row = split * reduced_rows_per_cache_row;
|
94 |
+
if (i == num_threads) {
|
95 |
+
split_row = reduced_rows;
|
96 |
+
}
|
97 |
+
|
98 |
+
VLOG(2) << "tid=" << i - 1 << " num rows=" << split_row - row_starts_.back()
|
99 |
+
<< " work=" << work_upto_row[split] - work_upto_row[prev_split];
|
100 |
+
row_starts_.push_back(split_row);
|
101 |
+
prev_split = split;
|
102 |
+
}
|
103 |
+
VLOG(2) << "total rows=" << reduced_rows << " total work=" << total_work;
|
104 |
+
}
|
105 |
+
|
106 |
+
} // namespace csrblocksparse
|
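The split-point search above boils down to an exclusive prefix sum of per-row work followed by a binary search per thread. A self-contained sketch under simplified assumptions (one row per cache row, a made-up constant standing in for |extra_row_work|):

#include <algorithm>
#include <cstdio>
#include <vector>

// Returns num_threads + 1 row boundaries chosen so each thread gets roughly
// the same total work (nnz plus a fixed per-row overhead).
std::vector<int> SplitRows(const std::vector<int>& nnz_per_row,
                           int num_threads) {
  std::vector<int> prefix(nnz_per_row.size() + 1, 0);
  for (size_t i = 0; i < nnz_per_row.size(); ++i)
    prefix[i + 1] = prefix[i] + nnz_per_row[i] + 2;  // +2: per-row overhead
  std::vector<int> starts(1, 0);
  int total = prefix.back();
  for (int t = 1; t <= num_threads; ++t) {
    auto it = std::lower_bound(prefix.begin(), prefix.end(),
                               t * total / num_threads);
    int row = static_cast<int>(std::distance(prefix.begin(), it));
    if (t == num_threads) row = static_cast<int>(nnz_per_row.size());
    starts.push_back(row);
  }
  return starts;
}

int main() {
  std::vector<int> nnz = {8, 8, 1, 1, 1, 1, 8, 8};
  for (int s : SplitRows(nnz, 2)) std::printf("%d ", s);  // prints: 0 4 8
  std::printf("\n");
  return 0;
}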
sparse_matmul/compute/thread_bounds.h
ADDED
@@ -0,0 +1,74 @@
|
1 |
+
/*
|
2 |
+
* Copyright 2021 Google LLC
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_
|
18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_
|
19 |
+
|
20 |
+
#include <vector>
|
21 |
+
|
22 |
+
namespace csrblocksparse {
|
23 |
+
|
24 |
+
// Class to compute and store the bounds of each thread used in a computation,
|
25 |
+
// and to provide corresponding spans of vectors.
|
26 |
+
class ThreadBounds {
|
27 |
+
public:
|
28 |
+
ThreadBounds() : block_width_(0), block_height_(0) {}
|
29 |
+
|
30 |
+
void PrepareForThreads(int block_width, int block_height, int num_threads,
|
31 |
+
int reduced_rows_per_cache_row, int reduced_rows,
|
32 |
+
const int* nnz_per_row);
|
33 |
+
|
34 |
+
// Functions that offset the appropriate type to the start of the data
|
35 |
+
// needed by the given thread id (|tid|).
|
36 |
+
template <typename WeightType>
|
37 |
+
const WeightType* OffsetWeights(const WeightType* weights, int tid) const {
|
38 |
+
return weights + weight_starts_[tid];
|
39 |
+
}
|
40 |
+
template <typename RhsIndType>
|
41 |
+
const RhsIndType* OffsetRhsIndices(const RhsIndType* rhs_indices,
|
42 |
+
int tid) const {
|
43 |
+
return rhs_indices + rhs_indices_starts_[tid];
|
44 |
+
}
|
45 |
+
template <typename BiasType>
|
46 |
+
const BiasType* OffsetBias(const BiasType* bias, int tid) const {
|
47 |
+
return bias + bias_starts_[tid];
|
48 |
+
}
|
49 |
+
template <typename OutType>
|
50 |
+
OutType* OffsetOutput(OutType* output, int tid) const {
|
51 |
+
return output + block_height_ * row_starts_[tid];
|
52 |
+
}
|
53 |
+
int StartRow(int tid) const { return row_starts_[tid]; }
|
54 |
+
const std::vector<int>& row_starts() const { return row_starts_; }
|
55 |
+
|
56 |
+
private:
|
57 |
+
// Computes the block row (reduced) index of the start of each thread.
|
58 |
+
void ComputeThreadSplitPoints(int num_threads, int reduced_rows_per_cache_row,
|
59 |
+
int reduced_rows, const int* nnz_per_row);
|
60 |
+
|
61 |
+
// Sizes of a sparse block.
|
62 |
+
int block_width_;
|
63 |
+
int block_height_;
|
64 |
+
// Start indices of each data type by thread-id with an extra value at the
|
65 |
+
// end.
|
66 |
+
std::vector<int> row_starts_;
|
67 |
+
std::vector<int> weight_starts_;
|
68 |
+
std::vector<int> rhs_indices_starts_;
|
69 |
+
std::vector<int> bias_starts_;
|
70 |
+
};
|
71 |
+
|
72 |
+
} // namespace csrblocksparse
|
73 |
+
|
74 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_COMPUTE_THREAD_BOUNDS_H_
|
sparse_matmul/layers/BUILD
ADDED
@@ -0,0 +1,146 @@
|
# Sparse/Masked Matrix and Layer.

# [internal] load android_library_selector
# [internal] load android_cc_test:def.bzl

licenses(["notice"])

cc_library(
    name = "layer",
    hdrs = [
        "sparse_linear_layer.h",
    ],
    visibility = [
        "//sparse_matmul:__subpackages__",
    ],
    deps = [
        ":matrix",
        "//sparse_matmul/numerics:types",
        "//sparse_matmul/os:coop_threads",
        "//sparse_matmul/vector:cache_aligned_vector",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_glog//:glog",
    ],
)

cc_library(
    name = "matrix",
    hdrs = [
        "csr_blocksparse_matrix.h",
        "masked_sparse_matrix.h",
    ],
    visibility = [
        "//sparse_matmul:__subpackages__",
    ],
    deps = [
        "//sparse_matmul/compute:kernels",
        "//sparse_matmul/compute:matmul",
        "//sparse_matmul/compute:thread_bounds",
        "//sparse_matmul/numerics:types",
        "//sparse_matmul/os:coop_threads",
        "//sparse_matmul/vector:cache_aligned_vector",
        "@com_google_absl//absl/memory",
        "@com_google_absl//absl/strings:str_format",
        "@com_google_glog//:glog",
    ],
)

cc_library(
    name = "utils",
    srcs = [
        "utils.cc",
    ],
    hdrs = [
        "read_array_ifstream.h",
        "utils.h",
    ],
    visibility = [
        "//sparse_matmul:__subpackages__",
    ],
    deps = [
        ":layer",
        ":matrix",
        ":status",
        "//sparse_matmul/numerics:types",
        "//sparse_matmul/vector:cache_aligned_vector",
        "//sparse_matmul/zlib_wrapper",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
        "@gulrak_filesystem//:filesystem",
    ],
)

cc_library(
    name = "status",
    srcs = [
        "errno_mapping.cc",
    ],
    hdrs = [
        "errno_mapping.h",
        "status_macros.h",
    ],
    deps = [
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/strings:cord",
    ],
)

cc_test(
    name = "csrblocksparse_test",
    size = "small",
    srcs = [
        "csrblocksparse_test.cc",
    ],
    data = glob(["testdata/*"]),
    linkopts = select({
        "@bazel_tools//platforms:android": ["-landroid"],
        "//conditions:default": [],
    }),
    shard_count = 10,
    deps = [
        ":status",
        ":utils",
        "//sparse_matmul/compute:matmul",
        "//sparse_matmul/numerics:test_utils",
        "//sparse_matmul/os:coop_threads",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/strings",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
        "@gulrak_filesystem//:filesystem",
    ],
)

cc_test(
    name = "sparse_linear_layer_test",
    srcs = [
        "sparse_linear_layer_test.cc",
    ],
    deps = [
        ":layer",
        "//sparse_matmul/numerics:test_utils",
        "@com_google_googletest//:gtest_main",
    ],
)

cc_test(
    name = "utils_test",
    srcs = ["utils_test.cc"],
    deps = [
        ":layer",
        ":matrix",
        ":status",
        ":utils",
        "//sparse_matmul/numerics:fast_transcendentals",
        "//sparse_matmul/numerics:test_utils",
        "//sparse_matmul/numerics:types",
        "//sparse_matmul/vector:cache_aligned_vector",
        "@com_google_absl//absl/flags:flag",
        "@com_google_googletest//:gtest_main",
        "@gulrak_filesystem//:filesystem",
    ],
)
sparse_matmul/layers/csr_blocksparse_matrix.h
ADDED
@@ -0,0 +1,835 @@
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_
#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <memory>
#include <tuple>
#include <vector>

#include "glog/logging.h"
// IWYU pragma: begin_exports
#include "sparse_matmul/compute/kernels_generic.h"
#include "sparse_matmul/compute/matmul.h"
#include "sparse_matmul/compute/thread_bounds.h"
#include "sparse_matmul/layers/masked_sparse_matrix.h"
#include "sparse_matmul/numerics/fixed_types.h"
#include "sparse_matmul/numerics/float16_types.h"
#include "sparse_matmul/os/coop_threads.h"
#include "sparse_matmul/vector/cache_aligned_vector.h"
// IWYU pragma: end_exports
#include "absl/memory/memory.h"

namespace csrblocksparse {
// CsrBlockSparseMatrix stores a modified block compressed sparse row
// representation of a sparse matrix. The ordering of the weights is modified
// in the 16x1 and 1x1 cases so that a certain number (4 and 8 respectively)
// of columns of weights are stored contiguously before moving on to the next
// row. The 4x4 case stores each block contiguously.
//
// Currently it is constructed from a MaskedSparseMatrix which uses a dense
// binary mask representation. The construction generates the compressed
// representation. Further iterations will support a direct serialization
// of the compressed representation.
//
// MaskedSparseMatrix masked_matrix(rows, cols, existing_mask, existing_values)
// CsrBlockSparseMatrix matrix(masked_matrix)
//
// matrix.SpMV_bias(rhs, bias, &out);
//
// This class is thread compatible.
58 |
+
template <typename WeightType, typename RhsType, typename DeltaType = int16_t>
|
59 |
+
class CsrBlockSparseMatrix {
|
60 |
+
public:
|
61 |
+
CsrBlockSparseMatrix() {}
|
62 |
+
|
63 |
+
// Reference used to indicate that this is an input and not an output.
|
64 |
+
CsrBlockSparseMatrix(const uint8_t* const& buffer, const std::size_t& len) {
|
65 |
+
ReadFromFlatBuffer(buffer, len);
|
66 |
+
ComputeRHSIndices();
|
67 |
+
}
|
68 |
+
|
69 |
+
template <typename InputType>
|
70 |
+
CsrBlockSparseMatrix(const MaskedSparseMatrix<InputType>& masked_matrix) {
|
71 |
+
sparsity_ = masked_matrix.sparsity();
|
72 |
+
rows_ = masked_matrix.rows();
|
73 |
+
cols_ = masked_matrix.cols();
|
74 |
+
|
75 |
+
DetermineBlockSize(masked_matrix);
|
76 |
+
|
77 |
+
if (block_width_ == 1 && block_height_ == 1)
|
78 |
+
col_multiple_ = 8;
|
79 |
+
else
|
80 |
+
col_multiple_ = 1;
|
81 |
+
|
82 |
+
std::vector<InputType> weights(masked_matrix.values().begin(),
|
83 |
+
masked_matrix.values().end());
|
84 |
+
|
85 |
+
reduced_rows_ = (rows_ + block_height_ - 1) / block_height_;
|
86 |
+
rows_ = reduced_rows_ * block_height_;
|
87 |
+
reduced_cols_ = cols_ / block_width_;
|
88 |
+
|
89 |
+
// Calculate the reduced CSR representation of the matrix.
|
90 |
+
std::vector<int> reduced_mask(reduced_rows_ * reduced_cols_);
|
91 |
+
std::vector<int> row_offsets = {0};
|
92 |
+
int nnz = 0;
|
93 |
+
const auto& mask = masked_matrix.mask();
|
94 |
+
for (int r = 0; r < reduced_rows_; ++r) {
|
95 |
+
for (int c = 0; c < reduced_cols_; ++c) {
|
96 |
+
int mask_val = mask[r * block_height_ * cols_ + c * block_width_];
|
97 |
+
reduced_mask[r * reduced_cols_ + c] = mask_val;
|
98 |
+
nnz += mask_val;
|
99 |
+
}
|
100 |
+
row_offsets.push_back(nnz);
|
101 |
+
}
|
102 |
+
|
103 |
+
// Make sure the reduced representation has the correct number of columns.
|
104 |
+
MakeColumnsMultiple(row_offsets, &reduced_mask, &weights);
|
105 |
+
|
106 |
+
std::vector<int> col_indices;
|
107 |
+
std::vector<WeightType> weights_csr;
|
108 |
+
std::vector<int> nnz_per_row;
|
109 |
+
MaskAndWeightsToCsr(reduced_mask, weights, &nnz_per_row, &col_indices,
|
110 |
+
&weights_csr);
|
111 |
+
|
112 |
+
// Generate column deltas from |col_indices|.
|
113 |
+
std::vector<DeltaType> col_deltas;
|
114 |
+
for (int i = 0; i < col_indices.size(); ++i) {
|
115 |
+
// |col_indices| are used to index the RHS vector which is always float.
|
116 |
+
int64_t diff = sizeof(RhsType);
|
117 |
+
if (i == 0)
|
118 |
+
diff *= block_width_ * (col_indices[i]);
|
119 |
+
else
|
120 |
+
diff *= block_width_ * (col_indices[i] - col_indices[i - 1]);
|
121 |
+
|
122 |
+
CHECK(diff < std::numeric_limits<DeltaType>::max())
|
123 |
+
<< "delta between column indices in bytes " << diff
|
124 |
+
<< " exceeded the maximum size of the DeltaType "
|
125 |
+
<< std::numeric_limits<DeltaType>::max();
|
126 |
+
col_deltas.push_back(static_cast<DeltaType>(diff));
|
127 |
+
}
|
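// Worked example for the delta computation above (illustrative; not part of
// the original source): with block_width_ == 4, RhsType == float (4 bytes)
// and col_indices == {0, 2, 5}, the loop produces byte deltas
//   {4 * 4 * 0, 4 * 4 * (2 - 0), 4 * 4 * (5 - 2)} == {0, 32, 48},
// i.e. each delta is how far the rhs pointer must advance to reach the next
// nonzero block column.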
128 |
+
|
129 |
+
// Because of pre-fetching we need some extra values at the end.
|
130 |
+
col_deltas.insert(col_deltas.end(), std::max(2, col_multiple_ + 1), 0);
|
131 |
+
nnz_per_row.insert(nnz_per_row.end(), 2, nnz_per_row.back());
|
132 |
+
|
133 |
+
weights_ = CacheAlignedVector<WeightType>(weights_csr);
|
134 |
+
col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas);
|
135 |
+
nnz_per_row_ = CacheAlignedVector<int>(nnz_per_row);
|
136 |
+
ComputeRHSIndices();
|
137 |
+
|
138 |
+
num_threads_ = 0;
|
139 |
+
PrepareForThreads(1);
|
140 |
+
}
|
141 |
+
|
142 |
+
// Constructor makes a matrix from the given weights, deltas and nnz, taking
|
143 |
+
// the other parameters from |src_matrix|. |cols| is the number of raw columns
|
144 |
+
// (NOT blocks) of the new matrix.
|
145 |
+
CsrBlockSparseMatrix(
|
146 |
+
const CsrBlockSparseMatrix<WeightType, RhsType, DeltaType>& src_matrix,
|
147 |
+
const std::vector<WeightType>& new_weights,
|
148 |
+
const std::vector<DeltaType>& new_deltas, const std::vector<int>& new_nnz,
|
149 |
+
int cols) {
|
150 |
+
num_threads_ = 0;
|
151 |
+
col_multiple_ = src_matrix.col_multiple_;
|
152 |
+
block_width_ = src_matrix.block_width_;
|
153 |
+
block_height_ = src_matrix.block_height_;
|
154 |
+
reduced_rows_ = new_nnz.size();
|
155 |
+
rows_ = reduced_rows_ * block_height_;
|
156 |
+
cols_ = cols;
|
157 |
+
reduced_cols_ = cols_ / block_width_;
|
158 |
+
weights_ = CacheAlignedVector<WeightType>(new_weights);
|
159 |
+
col_deltas_ = CacheAlignedVector<DeltaType>(new_deltas);
|
160 |
+
nnz_per_row_ = CacheAlignedVector<int>(new_nnz);
|
161 |
+
sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_);
|
162 |
+
ComputeRHSIndices();
|
163 |
+
name_ = src_matrix.name_;
|
164 |
+
PrepareForThreads(1);
|
165 |
+
}
|
166 |
+
|
167 |
+
// Factory method takes a column slice out of *this and returns a sparse
|
168 |
+
// matrix that takes as inputs [|start_col|, |end_col|) of *this, and
|
169 |
+
// returns the same number of outputs, but only a partial result.
|
170 |
+
// If |keep_rhs_size|, then the new matrix takes the same rhs as the current
|
171 |
+
// matrix, but uses a subset of it, instead of expecting just the reduced rhs.
|
172 |
+
// If |start_col| > |end_col|, then we slice out the complement of the defined
|
173 |
+
// interval, ie [0, |end_col|) + [|start_col|, current end).
|
174 |
+
// NOTE That |start_col| and |end_col| are in raw column coordinates, NOT
|
175 |
+
// block units.
|
176 |
+
CsrBlockSparseMatrix SplitByColumn(int start_col, int end_col,
|
177 |
+
bool keep_rhs_size = false) const {
|
178 |
+
int weight_index = 0;
|
179 |
+
int delta_index = 0;
|
180 |
+
std::vector<DeltaType> new_deltas;
|
181 |
+
std::vector<WeightType> new_weights;
|
182 |
+
std::vector<int> new_nnz(reduced_rows_);
|
183 |
+
int col = 0;
|
184 |
+
int prev_col = keep_rhs_size ? 0 : start_col;
|
185 |
+
for (int r = 0; r < reduced_rows_; ++r) {
|
186 |
+
int reduced_col_count = nnz_per_row_[r];
|
187 |
+
for (int c = 0; c < reduced_col_count; ++c, ++delta_index) {
|
188 |
+
col += col_deltas_[delta_index] / sizeof(RhsType);
|
189 |
+
if ((start_col < end_col && start_col <= col && col < end_col) ||
|
190 |
+
(start_col > end_col && (col < end_col || col >= start_col))) {
|
191 |
+
++new_nnz[r];
|
192 |
+
new_deltas.push_back((col - prev_col) * sizeof(RhsType));
|
193 |
+
prev_col = col;
|
194 |
+
for (int i = 0; i < block_width_ * block_height_;
|
195 |
+
++i, ++weight_index) {
|
196 |
+
new_weights.push_back(weights_[weight_index]);
|
197 |
+
}
|
198 |
+
} else {
|
199 |
+
weight_index += block_width_ * block_height_;
|
200 |
+
}
|
201 |
+
}
|
202 |
+
}
|
203 |
+
int new_cols = keep_rhs_size ? cols_ : end_col - start_col;
|
204 |
+
return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz,
|
205 |
+
new_cols);
|
206 |
+
}
|
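// Illustrative sketch of SplitByColumn (not from the original source):
// splitting a matrix into two column halves so each half consumes a different
// part of the rhs. |full| is a hypothetical CsrBlockSparseMatrix with an even
// number of raw columns.
//
//   auto left = full.SplitByColumn(0, full.cols() / 2);
//   auto right = full.SplitByColumn(full.cols() / 2, full.cols(),
//                                   /*keep_rhs_size=*/true);
//
// Each piece produces the same number of output rows as |full|, but only a
// partial sum; the partial outputs must be combined (with the bias counted
// only once) to recover the full product.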
207 |
+
|
208 |
+
// Factory method takes a row slice out of *this and returns a sparse
|
209 |
+
// matrix that takes the sampe inputs as *this, and returns the outputs for
|
210 |
+
// the range [|start_row|, |end_row|).
|
211 |
+
// NOTE That |start_row| and |end_row| are in raw column coordinates, NOT
|
212 |
+
// block units.
|
213 |
+
CsrBlockSparseMatrix SplitByRow(int start_row, int end_row) const {
|
214 |
+
int start_reduced = start_row / block_height_;
|
215 |
+
int end_reduced = end_row / block_height_;
|
216 |
+
std::vector<int> new_nnz(nnz_per_row_.data() + start_reduced,
|
217 |
+
nnz_per_row_.data() + end_reduced);
|
218 |
+
int weight_start = 0;
|
219 |
+
for (int r = 0; r < start_reduced; ++r) {
|
220 |
+
weight_start += nnz_per_row_[r];
|
221 |
+
}
|
222 |
+
int weight_end = weight_start;
|
223 |
+
for (int r = start_reduced; r < end_reduced; ++r) {
|
224 |
+
weight_end += nnz_per_row_[r];
|
225 |
+
}
|
226 |
+
int delta_start = 0;
|
227 |
+
for (int i = 0; i < weight_start; ++i) {
|
228 |
+
delta_start += col_deltas_[i];
|
229 |
+
}
|
230 |
+
std::vector<DeltaType> new_deltas(col_deltas_.data() + weight_start,
|
231 |
+
col_deltas_.data() + weight_end);
|
232 |
+
new_deltas[0] += delta_start;
|
233 |
+
int block_size = block_height_ * block_width_;
|
234 |
+
std::vector<WeightType> new_weights(
|
235 |
+
weights_.data() + weight_start * block_size,
|
236 |
+
weights_.data() + weight_end * block_size);
|
237 |
+
return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz, cols_);
|
238 |
+
}
|
239 |
+
|
240 |
+
// Combines adjacent row blocks, doubling the block height.
|
241 |
+
// This necessarily involves adding zero weights where the blocks don't align
|
242 |
+
// across adjacent pairs of rows, so use with caution, as the resulting matrix
|
243 |
+
// is most likely to run slower if very sparse to begin with.
|
244 |
+
// In the few cases where the blocks do mostly align, the resulting matmul
|
245 |
+
// could be much faster, as the number of reads of the rhs will be halved.
|
246 |
+
void DoubleBlockHeight() {
|
247 |
+
int new_rows = reduced_rows_ / 2;
|
248 |
+
std::vector<int> new_nnz(new_rows);
|
249 |
+
std::vector<DeltaType> new_rhs_indices;
|
250 |
+
std::vector<WeightType> new_weights;
|
251 |
+
int rhs_index1 = 0;
|
252 |
+
int rhs_index2 = 0;
|
253 |
+
int block_size = block_height_ * block_width_;
|
254 |
+
for (int r = 0; r < new_rows; ++r) {
|
255 |
+
int start_nnz = new_rhs_indices.size();
|
256 |
+
rhs_index2 += nnz_per_row_[r * 2];
|
257 |
+
int end1 = rhs_index1 + nnz_per_row_[r * 2];
|
258 |
+
int end2 = rhs_index2 + nnz_per_row_[r * 2 + 1];
|
259 |
+
// Run over a pair of rows with 2 iterators, combining blocks as we go, or
|
260 |
+
// padding with zeros where the block positions don't match.
|
261 |
+
while (rhs_index1 < end1 || rhs_index2 < end2) {
|
262 |
+
int col1 = rhs_index1 < end1 ? rhs_indices_[rhs_index1] : reduced_cols_;
|
263 |
+
int col2 = rhs_index2 < end2 ? rhs_indices_[rhs_index2] : reduced_cols_;
|
264 |
+
if (col1 < col2) {
|
265 |
+
// Need zero weights for row2 to pad out weights block.
|
266 |
+
new_rhs_indices.push_back(col1);
|
267 |
+
new_weights.insert(new_weights.end(),
|
268 |
+
weights_.data() + rhs_index1 * block_size,
|
269 |
+
weights_.data() + (rhs_index1 + 1) * block_size);
|
270 |
+
new_weights.insert(new_weights.end(), block_size,
|
271 |
+
static_cast<WeightType>(0.0f));
|
272 |
+
++rhs_index1;
|
273 |
+
} else if (col1 > col2) {
|
274 |
+
// Need zero weights for row1 to pad out weights block.
|
275 |
+
new_rhs_indices.push_back(col2);
|
276 |
+
new_weights.insert(new_weights.end(), block_size,
|
277 |
+
static_cast<WeightType>(0.0f));
|
278 |
+
new_weights.insert(new_weights.end(),
|
279 |
+
weights_.data() + rhs_index2 * block_size,
|
280 |
+
weights_.data() + (rhs_index2 + 1) * block_size);
|
281 |
+
++rhs_index2;
|
282 |
+
} else {
|
283 |
+
// Combine weights for both row1 and row2.
|
284 |
+
new_rhs_indices.push_back(col1);
|
285 |
+
new_weights.insert(new_weights.end(),
|
286 |
+
weights_.data() + rhs_index1 * block_size,
|
287 |
+
weights_.data() + (rhs_index1 + 1) * block_size);
|
288 |
+
new_weights.insert(new_weights.end(),
|
289 |
+
weights_.data() + rhs_index2 * block_size,
|
290 |
+
weights_.data() + (rhs_index2 + 1) * block_size);
|
291 |
+
++rhs_index1;
|
292 |
+
++rhs_index2;
|
293 |
+
}
|
294 |
+
}
|
295 |
+
rhs_index1 = rhs_index2;
|
296 |
+
new_nnz[r] = new_rhs_indices.size() - start_nnz;
|
297 |
+
}
|
298 |
+
block_height_ *= 2;
|
299 |
+
reduced_rows_ /= 2;
|
300 |
+
weights_ = CacheAlignedVector<WeightType>(new_weights);
|
301 |
+
rhs_indices_ = CacheAlignedVector<DeltaType>(new_rhs_indices);
|
302 |
+
nnz_per_row_ = CacheAlignedVector<int>(new_nnz);
|
303 |
+
sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_);
|
304 |
+
ComputeColDeltas();
|
305 |
+
if (num_threads_ > 0) {
|
306 |
+
int num_threads = num_threads_;
|
307 |
+
num_threads_ = 0;
|
308 |
+
PrepareForThreads(num_threads);
|
309 |
+
}
|
310 |
+
}
|
311 |
+
|
312 |
+
// Allocates memory and fills buffer.
|
313 |
+
// Caller is responsible for the memory de-allocation.
|
314 |
+
// TODO(b/189958858): Both Read and Write need to eventually handle the
|
315 |
+
// different possible HalfType and DeltaType values, but punting for now as
|
316 |
+
// there is only one supported combination.
|
317 |
+
std::size_t WriteToFlatBuffer(std::string* csr_flatbuffer) {
|
318 |
+
std::size_t bytes = 0;
|
319 |
+
bytes += FixedParameterSize();
|
320 |
+
bytes += weights_.size() * sizeof(WeightType);
|
321 |
+
bytes += col_deltas_.size() * sizeof(DeltaType);
|
322 |
+
bytes += nnz_per_row_.size() * sizeof(int);
|
323 |
+
|
324 |
+
uint8_t* bytes_ptr_ptr =
|
325 |
+
reinterpret_cast<uint8_t*>(CHECK_NOTNULL(malloc(bytes)));
|
326 |
+
|
327 |
+
int* int_bytes_ptr = reinterpret_cast<int*>(bytes_ptr_ptr);
|
328 |
+
|
329 |
+
*int_bytes_ptr++ = rows_;
|
330 |
+
*int_bytes_ptr++ = cols_;
|
331 |
+
*int_bytes_ptr++ = reduced_rows_;
|
332 |
+
*int_bytes_ptr++ = reduced_cols_;
|
333 |
+
*int_bytes_ptr++ = block_width_;
|
334 |
+
*int_bytes_ptr++ = block_height_;
|
335 |
+
*int_bytes_ptr++ = col_multiple_;
|
336 |
+
*int_bytes_ptr++ = num_threads_;
|
337 |
+
*int_bytes_ptr++ = weights_.size();
|
338 |
+
*int_bytes_ptr++ = col_deltas_.size();
|
339 |
+
*int_bytes_ptr++ = nnz_per_row_.size();
|
340 |
+
|
341 |
+
float* float_bytes_ptr = reinterpret_cast<float*>(int_bytes_ptr);
|
342 |
+
*float_bytes_ptr++ = sparsity_;
|
343 |
+
|
344 |
+
uint8_t* bytes_ptr = reinterpret_cast<uint8_t*>(float_bytes_ptr);
|
345 |
+
|
346 |
+
memcpy(bytes_ptr, weights_.data(), weights_.size() * sizeof(WeightType));
|
347 |
+
bytes_ptr += weights_.size() * sizeof(WeightType);
|
348 |
+
|
349 |
+
memcpy(bytes_ptr, col_deltas_.data(),
|
350 |
+
col_deltas_.size() * sizeof(DeltaType));
|
351 |
+
bytes_ptr += col_deltas_.size() * sizeof(DeltaType);
|
352 |
+
|
353 |
+
memcpy(bytes_ptr, nnz_per_row_.data(), nnz_per_row_.size() * sizeof(int));
|
354 |
+
bytes_ptr += nnz_per_row_.size() * sizeof(int);
|
355 |
+
|
356 |
+
csr_flatbuffer->resize(bytes);
|
357 |
+
csr_flatbuffer->assign(reinterpret_cast<char*>(bytes_ptr_ptr), bytes);
|
358 |
+
free(bytes_ptr_ptr);
|
359 |
+
|
360 |
+
return bytes;
|
361 |
+
}
|
362 |
+
|
363 |
+
void ReadFromFlatBuffer(const uint8_t* const& bytes, const std::size_t& len) {
|
364 |
+
CHECK_GE(len, FixedParameterSize());
|
365 |
+
|
366 |
+
const int* int_bytes_ptr = reinterpret_cast<const int*>(bytes);
|
367 |
+
rows_ = *int_bytes_ptr++;
|
368 |
+
cols_ = *int_bytes_ptr++;
|
369 |
+
reduced_rows_ = *int_bytes_ptr++;
|
370 |
+
reduced_cols_ = *int_bytes_ptr++;
|
371 |
+
block_width_ = *int_bytes_ptr++;
|
372 |
+
block_height_ = *int_bytes_ptr++;
|
373 |
+
col_multiple_ = *int_bytes_ptr++;
|
374 |
+
int num_threads = *int_bytes_ptr++;
|
375 |
+
int32_t weights_size = *int_bytes_ptr++;
|
376 |
+
int32_t col_deltas_size = *int_bytes_ptr++;
|
377 |
+
int32_t nnz_per_row_size = *int_bytes_ptr++;
|
378 |
+
|
379 |
+
// Make sure negative sizes don't mess things up.
|
380 |
+
weights_size = std::max(0, weights_size);
|
381 |
+
col_deltas_size = std::max(0, col_deltas_size);
|
382 |
+
nnz_per_row_size = std::max(0, nnz_per_row_size);
|
383 |
+
|
384 |
+
const float* float_bytes_ptr =
|
385 |
+
reinterpret_cast<const float*>(int_bytes_ptr);
|
386 |
+
sparsity_ = *float_bytes_ptr++;
|
387 |
+
|
388 |
+
std::size_t total_bytes =
|
389 |
+
FixedParameterSize() + weights_size * sizeof(WeightType) +
|
390 |
+
col_deltas_size * sizeof(DeltaType) + nnz_per_row_size * sizeof(int);
|
391 |
+
|
392 |
+
CHECK_EQ(total_bytes, len)
|
393 |
+
<< "total bytes: " << total_bytes << ", actual len given: " << len;
|
394 |
+
|
395 |
+
const uint8_t* bytes_ptr =
|
396 |
+
reinterpret_cast<const uint8_t*>(float_bytes_ptr);
|
397 |
+
std::vector<WeightType> weights_raw(weights_size);
|
398 |
+
memcpy(weights_raw.data(), bytes_ptr, weights_size * sizeof(WeightType));
|
399 |
+
weights_ = CacheAlignedVector<WeightType>(weights_raw);
|
400 |
+
bytes_ptr += weights_size * sizeof(WeightType);
|
401 |
+
|
402 |
+
std::vector<DeltaType> deltas_raw(col_deltas_size);
|
403 |
+
memcpy(deltas_raw.data(), bytes_ptr, col_deltas_size * sizeof(DeltaType));
|
404 |
+
col_deltas_ = CacheAlignedVector<DeltaType>(deltas_raw);
|
405 |
+
bytes_ptr += col_deltas_size * sizeof(DeltaType);
|
406 |
+
|
407 |
+
std::vector<int> nnz_raw(nnz_per_row_size);
|
408 |
+
memcpy(nnz_raw.data(), bytes_ptr, nnz_per_row_size * sizeof(int));
|
409 |
+
nnz_per_row_ = CacheAlignedVector<int>(nnz_raw);
|
410 |
+
num_threads_ = 0;
|
411 |
+
PrepareForThreads(num_threads);
|
412 |
+
}
|
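// Summary of the flat buffer layout used by WriteToFlatBuffer and
// ReadFromFlatBuffer above (derived from the code; added here for reference):
//   int32  rows_, cols_, reduced_rows_, reduced_cols_,
//          block_width_, block_height_, col_multiple_, num_threads_,
//          weights_.size(), col_deltas_.size(), nnz_per_row_.size()
//   float  sparsity_
//   WeightType[weights_.size()]    weights_
//   DeltaType[col_deltas_.size()]  col_deltas_
//   int[nnz_per_row_.size()]       nnz_per_row_
// |rhs_indices_| and |thread_bounds_| are not serialized; they are rebuilt
// from this data on load.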
413 |
+
|
414 |
+
// Multiply a Sparse matrix by a possibly dense matrix. Often the matrix is
|
415 |
+
// a vector with a small number of columns, hence the term "fat vector".
|
416 |
+
// 1x1 and 4x4 have specializations for output columns (ie fatness) > 5,
|
417 |
+
// and often achieve twice as many GFlops when multiplying a right hand side
|
418 |
+
// that has 5 or more columns. (Best is a multiple of 5).
|
419 |
+
// 16x1 doesn't have enough registers and just loops over the width 1 kernel.
|
420 |
+
//
|
421 |
+
// |rhs| and |out| are COLUMN MAJOR.
|
422 |
+
|
423 |
+
// Fast Tuples WeightType, BiasType, RhsType, OutType are:
|
424 |
+
// (float, float, float, float)
|
425 |
+
// (bfloat16, float, float, float)
|
426 |
+
// and only on ARM64. All other cases use a slow generic implementation.
|
427 |
+
template <typename RhsClass, typename BiasClass, typename OutClass,
|
428 |
+
typename BiasType = typename BiasClass::value_type,
|
429 |
+
typename OutType = typename OutClass::value_type>
|
430 |
+
void SpMM_bias(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
|
431 |
+
bool relu = false, int tid = 0,
|
432 |
+
SpinBarrier* barrier = nullptr) const {
|
433 |
+
static_assert(std::is_same<typename RhsClass::value_type, RhsType>::value,
|
434 |
+
"Rhs types must match");
|
435 |
+
CHECK_LT(tid, num_threads_);
|
436 |
+
CHECK_EQ(rhs.cols(), out->cols());
|
437 |
+
CHECK_EQ(rhs.rows(), cols_);
|
438 |
+
CHECK_GE(out->rows(), rows_);
|
439 |
+
int cols_to_go = out->cols();
|
440 |
+
int rhs_index = *thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid);
|
441 |
+
const RhsType* rhs_ptr = rhs.data() + rhs_index * block_height_;
|
442 |
+
OutType* out_ptr = thread_bounds_.OffsetOutput(out->data(), tid);
|
443 |
+
const WeightType* weights_ptr =
|
444 |
+
thread_bounds_.OffsetWeights(weights_.data(), tid);
|
445 |
+
const DeltaType* delta_ptr =
|
446 |
+
thread_bounds_.OffsetRhsIndices(col_deltas_.data(), tid);
|
447 |
+
int offset = *delta_ptr / sizeof(RhsType);
|
448 |
+
rhs_ptr -= offset;
|
449 |
+
const int* nnz_ptr = nnz_per_row_.data() + thread_bounds_.StartRow(tid);
|
450 |
+
int assigned_rows =
|
451 |
+
thread_bounds_.StartRow(tid + 1) - thread_bounds_.StartRow(tid);
|
452 |
+
const BiasType* bias_ptr = thread_bounds_.OffsetBias(bias.data(), tid);
|
453 |
+
|
454 |
+
while (cols_to_go > 0) {
|
455 |
+
if (block_width_ == 4 && block_height_ == 4) {
|
456 |
+
if (cols_to_go >= 5) {
|
457 |
+
detail::SpMM5_4x4<WeightType, RhsType, OutType>(
|
458 |
+
weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
|
459 |
+
assigned_rows, out->col_stride(), rhs.col_stride(), relu);
|
460 |
+
} else {
|
461 |
+
detail::SpMV_4x4<WeightType, RhsType, OutType>(
|
462 |
+
weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
|
463 |
+
assigned_rows, out->col_stride(), rhs.col_stride(), relu);
|
464 |
+
}
|
465 |
+
} else {
|
466 |
+
if (cols_to_go >= 5) {
|
467 |
+
detail::SpMM5_1x1<WeightType, RhsType, OutType>(
|
468 |
+
weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
|
469 |
+
assigned_rows, out->col_stride(), rhs.col_stride(), relu);
|
470 |
+
} else {
|
471 |
+
detail::SpMV_1x1<WeightType, RhsType, OutType>(
|
472 |
+
weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
|
473 |
+
assigned_rows, out->col_stride(), rhs.col_stride(), relu);
|
474 |
+
}
|
475 |
+
}
|
476 |
+
|
477 |
+
if (cols_to_go >= 5) {
|
478 |
+
cols_to_go -= 5;
|
479 |
+
rhs_ptr += rhs.col_stride() * 5;
|
480 |
+
out_ptr += out->col_stride() * 5;
|
481 |
+
} else {
|
482 |
+
cols_to_go--;
|
483 |
+
rhs_ptr += rhs.col_stride();
|
484 |
+
out_ptr += out->col_stride();
|
485 |
+
}
|
486 |
+
if (barrier) barrier->barrier();
|
487 |
+
}
|
488 |
+
}
|
489 |
+
template <typename MVRhsType, typename MVBiasType, typename OutType>
|
490 |
+
void MatVec(const MVRhsType* rhs, const MVBiasType* bias, bool relu, int tid,
|
491 |
+
int replicas, int output_stride, OutType* output) {
|
492 |
+
CHECK_LT(tid, num_threads_);
|
493 |
+
CHECK_EQ(block_width_, 4) << "Block width must be 4!";
|
494 |
+
if (block_height_ == 8) {
|
495 |
+
matmul_.MatVec8x4(
|
496 |
+
thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs,
|
497 |
+
thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(),
|
498 |
+
thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid),
|
499 |
+
thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu,
|
500 |
+
replicas, output_stride, thread_bounds_.OffsetOutput(output, tid));
|
501 |
+
} else {
|
502 |
+
CHECK_EQ(block_height_, 4) << "Block height must be 4 or 8!";
|
503 |
+
matmul_.MatVec4x4(
|
504 |
+
thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs,
|
505 |
+
thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(),
|
506 |
+
thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid),
|
507 |
+
thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu,
|
508 |
+
replicas, output_stride, thread_bounds_.OffsetOutput(output, tid));
|
509 |
+
}
|
510 |
+
}
|
511 |
+
|
512 |
+
int rows() const { return rows_; }
|
513 |
+
int cols() const { return cols_; }
|
514 |
+
int block_height() const { return block_height_; }
|
515 |
+
int block_width() const { return block_width_; }
|
516 |
+
float sparsity() const { return sparsity_; }
|
517 |
+
int num_threads() const { return num_threads_; }
|
518 |
+
const ThreadBounds& thread_bounds() const { return thread_bounds_; }
|
519 |
+
const CacheAlignedVector<DeltaType>& rhs_indices() const {
|
520 |
+
return rhs_indices_;
|
521 |
+
}
|
522 |
+
const std::string& name() const { return name_; }
|
523 |
+
void set_name(const std::string& name) { name_ = name; }
|
524 |
+
const std::vector<int>& split_points() const {
|
525 |
+
return thread_bounds_.row_starts();
|
526 |
+
}
|
527 |
+
|
528 |
+
std::size_t bytes() const {
|
529 |
+
return weights_.size() * sizeof(WeightType) +
|
530 |
+
col_deltas_.size() * sizeof(DeltaType) +
|
531 |
+
nnz_per_row_.size() * sizeof(int);
|
532 |
+
}
|
533 |
+
|
534 |
+
// Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above,
|
535 |
+
// and then samples from the output (softmax distribution) layer.
|
536 |
+
template <typename RhsClass, typename BiasClass, typename OutClass,
|
537 |
+
typename BiasType = typename BiasClass::value_type,
|
538 |
+
typename OutType = typename OutClass::value_type>
|
539 |
+
typename std::enable_if<!IsFixed32Type<OutType>::value, int>::type
|
540 |
+
SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
|
541 |
+
float temperature, int tid, SpinBarrier* barrier,
|
542 |
+
std::minstd_rand* gen,
|
543 |
+
CacheAlignedVector<float>* scratch) const {
|
544 |
+
SpMM_bias(rhs, bias, out, /*relu=*/false, tid, barrier);
|
545 |
+
return out->Sample(temperature, gen, scratch);
|
546 |
+
}
|
547 |
+
// Fixed32 version.
|
548 |
+
template <typename RhsClass, typename BiasClass, typename OutClass,
|
549 |
+
typename BiasType = typename BiasClass::value_type,
|
550 |
+
typename OutType = typename OutClass::value_type>
|
551 |
+
typename std::enable_if<IsFixed32Type<OutType>::value, int>::type
|
552 |
+
SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
|
553 |
+
float temperature, int tid, SpinBarrier* barrier,
|
554 |
+
std::minstd_rand* gen,
|
555 |
+
CacheAlignedVector<float>* scratch) const {
|
556 |
+
// We don't pass the barrier on, as we have more work to do.
|
557 |
+
SpMM_bias(rhs, bias, out, /*relu=*/false, tid);
|
558 |
+
return out->ReducingSample(gen, scratch, tid, temperature, barrier);
|
559 |
+
}
|
560 |
+
|
561 |
+
void Print() const {
|
562 |
+
std::cout << "Weights\n";
|
563 |
+
weights_.Print();
|
564 |
+
std::cout << std::endl;
|
565 |
+
std::cout << "Deltas\n";
|
566 |
+
col_deltas_.Print();
|
567 |
+
std::cout << std::endl;
|
568 |
+
std::cout << "nnz\n";
|
569 |
+
nnz_per_row_.Print();
|
570 |
+
std::cout << std::endl;
|
571 |
+
}
|
572 |
+
|
573 |
+
// Split the computation amongst threads by rows based on the number of
|
574 |
+
// non zeros, with the addition of a constant to account for the work of the
|
575 |
+
// bias and the horizontal add at the end, and also guarantees that each
|
576 |
+
// thread writes only whole cache lines, based on the size of OutType.
|
577 |
+
// The |cache_line_size| arg is used only for testing. Normally it is provided
|
578 |
+
// through the architecture #defines.
|
579 |
+
// Each thread gets a contiguous row range (|split_points|).
|
580 |
+
// Thread t does rows [ split_points[t], split_points[t + 1] )
|
581 |
+
// Each thread also needs to know how many non zeros were before it to skip
|
582 |
+
// (|nnz_to_skip|). And finally it also needs to know what the offset into
|
583 |
+
// the rhs vector would have been at the split point (|rhs_to_skip|).
|
584 |
+
//
|
585 |
+
// Some tricky corner cases where the number of non-zeros doesn't split
|
586 |
+
// nicely amongst the number of requested threads are not handled and default
|
587 |
+
// to one thread; these cases are only going to happen in tests and not in
|
588 |
+
// the matrices that correspond to real models.
|
589 |
+
//
|
590 |
+
// Returns the maximum number of threads that can be used; <= |num_threads|.
|
591 |
+
template <typename OutType = int32_t>
|
592 |
+
int PrepareForThreads(int num_threads, int cache_line_size = -1) {
|
593 |
+
CHECK_GT(num_threads, 0);
|
594 |
+
// we've already prepared for this number of threads, nothing to do
|
595 |
+
if (num_threads == num_threads_) return num_threads_;
|
596 |
+
|
597 |
+
num_threads_ = num_threads;
|
598 |
+
thread_bounds_.PrepareForThreads(
|
599 |
+
block_width_, block_height_, num_threads_,
|
600 |
+
ReducedRowsPerCacheLine<OutType>(cache_line_size), reduced_rows_,
|
601 |
+
nnz_per_row_.data());
|
602 |
+
return num_threads_;
|
603 |
+
}
|
604 |
+
|
605 |
+
// Computes and stores the |rhs_indices_| from the |col_deltas_|.
|
606 |
+
void ComputeRHSIndices() {
|
607 |
+
std::vector<int> cumulative_deltas = CumulativeColDeltas();
|
608 |
+
std::vector<DeltaType> rhs_indices(cumulative_deltas.size() +
|
609 |
+
reduced_rows_);
|
610 |
+
int total_indices = 0;
|
611 |
+
int delta_index = 0;
|
612 |
+
for (int r = 0; r < reduced_rows_; ++r) {
|
613 |
+
for (int n = 0; n < nnz_per_row_[r]; ++n, ++delta_index) {
|
614 |
+
rhs_indices[total_indices++] =
|
615 |
+
cumulative_deltas[delta_index] / block_width_;
|
616 |
+
}
|
617 |
+
}
|
618 |
+
rhs_indices_ = CacheAlignedVector<DeltaType>(rhs_indices);
|
619 |
+
}
|
620 |
+
|
621 |
+
// Computes and stores the |col_deltas_| from the |rhs_indices_|.
|
622 |
+
void ComputeColDeltas() {
|
623 |
+
std::vector<int> col_deltas(rhs_indices_.size());
|
624 |
+
int prev_index = 0;
|
625 |
+
for (int i = 0; i < rhs_indices_.size(); ++i) {
|
626 |
+
int offset = rhs_indices_[i] - prev_index;
|
627 |
+
prev_index = rhs_indices_[i];
|
628 |
+
col_deltas[i] = offset * block_width_ * sizeof(RhsType);
|
629 |
+
}
|
630 |
+
col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas);
|
631 |
+
}
|
632 |
+
|
633 |
+
// Computes and returns the inclusive prefix sum of the deltas, ie absolute
|
634 |
+
// positions.
|
635 |
+
std::vector<int> CumulativeColDeltas() const {
|
636 |
+
std::vector<int> cum_col_deltas(col_deltas_.size());
|
637 |
+
for (int i = 0; i < col_deltas_.size(); ++i) {
|
638 |
+
cum_col_deltas[i] = col_deltas_[i] / sizeof(RhsType);
|
639 |
+
if (i > 0) cum_col_deltas[i] += cum_col_deltas[i - 1];
|
640 |
+
}
|
641 |
+
return cum_col_deltas;
|
642 |
+
}
|
643 |
+
|
644 |
+
private:
|
645 |
+
constexpr std::size_t FixedParameterSize() const {
|
646 |
+
return sizeof(int) // rows
|
647 |
+
+ sizeof(int) // cols
|
648 |
+
+ sizeof(int) // reduced_rows
|
649 |
+
+ sizeof(int) // reduced_cols
|
650 |
+
+ sizeof(int) // block_width
|
651 |
+
+ sizeof(int) // block_height
|
652 |
+
+ sizeof(float) // sparsity
|
653 |
+
+ sizeof(int) // col_multiple
|
654 |
+
+ sizeof(int) // num_threads_
|
655 |
+
+ sizeof(int) // weights_.size()
|
656 |
+
+ sizeof(int) // col_deltas_.size()
|
657 |
+
+ sizeof(int); // nnz_per_row_.size()
|
658 |
+
}
|
659 |
+
// Possible block sizes are only those that are supported by the computation;
|
660 |
+
// default is 1x1, other options are 4x4 and 16x1.
|
661 |
+
template <typename InputType>
|
662 |
+
void DetermineBlockSize(const MaskedSparseMatrix<InputType>& masked_matrix) {
|
663 |
+
const std::vector<std::pair<int, int>> kPreferredOrder = {{4, 4}};
|
664 |
+
int rows = masked_matrix.rows();
|
665 |
+
int cols = masked_matrix.cols();
|
666 |
+
|
667 |
+
for (const auto& block_size : kPreferredOrder) {
|
668 |
+
int block_height, block_width;
|
669 |
+
std::tie(block_height, block_width) = block_size;
|
670 |
+
if (cols % block_width != 0) continue;
|
671 |
+
|
672 |
+
int reduced_rows = (rows + block_height - 1) / block_height;
|
673 |
+
int reduced_cols = cols / block_width;
|
674 |
+
|
675 |
+
// For each possible block, confirm that it is either all 0s or all 1s.
|
676 |
+
bool all_same = true;
|
677 |
+
const auto& mask = masked_matrix.mask();
|
678 |
+
for (int r = 0; r < reduced_rows; ++r) {
|
679 |
+
for (int c = 0; c < reduced_cols; ++c) {
|
680 |
+
int val = mask[r * block_height * cols + c * block_width];
|
681 |
+
for (int i = 0; i < block_height; ++i) {
|
682 |
+
for (int j = 0; j < block_width; ++j) {
|
683 |
+
int index = (r * block_height + i) * cols + c * block_width + j;
|
684 |
+
if (index < masked_matrix.mask().size()) {
|
685 |
+
all_same &= (masked_matrix.mask()[index] == val);
|
686 |
+
}
|
687 |
+
}
|
688 |
+
}
|
689 |
+
}
|
690 |
+
}
|
691 |
+
|
692 |
+
// If this block configuration is possible, accept it.
|
693 |
+
if (all_same) {
|
694 |
+
block_height_ = block_height;
|
695 |
+
block_width_ = block_width;
|
696 |
+
return;
|
697 |
+
}
|
698 |
+
}
|
699 |
+
|
700 |
+
// No large blocks were found, default to 1x1.
|
701 |
+
block_height_ = 1;
|
702 |
+
block_width_ = 1;
|
703 |
+
}
|
704 |
+
|
705 |
+
// CSR descriptors are for the reduced matrix, weights is the full matrix.
|
706 |
+
template <typename InputType>
|
707 |
+
void MakeColumnsMultiple(const std::vector<int>& row_offsets,
|
708 |
+
std::vector<int>* reduced_mask,
|
709 |
+
std::vector<InputType>* weights) {
|
710 |
+
if (col_multiple_ > 0) {
|
711 |
+
// Make sure each row has a number of columns that is a multiple of
|
712 |
+
// |col_multiple|.
|
713 |
+
for (int r = 1; r < row_offsets.size(); ++r) {
|
714 |
+
int num_row = row_offsets[r] - row_offsets[r - 1];
|
715 |
+
int num_needed = col_multiple_ - num_row % col_multiple_;
|
716 |
+
if (num_needed < col_multiple_) {
|
717 |
+
// Find gaps in the columns where we can insert a column of 0 weights.
|
718 |
+
int num_added = 0;
|
719 |
+
for (int c = 0; c < reduced_cols_; ++c) {
|
720 |
+
if ((*reduced_mask)[(r - 1) * reduced_cols_ + c] == 0) {
|
721 |
+
(*reduced_mask)[(r - 1) * reduced_cols_ + c] = 1;
|
722 |
+
|
723 |
+
// Zero out the weights that correspond to this block.
|
724 |
+
for (int i = 0; i < block_height_; ++i) {
|
725 |
+
for (int j = 0; j < block_width_; ++j) {
|
726 |
+
(*weights)[((r - 1) * block_height_ + i) * cols_ +
|
727 |
+
block_width_ * c + j] = InputType(0.f);
|
728 |
+
}
|
729 |
+
}
|
730 |
+
num_added++;
|
731 |
+
}
|
732 |
+
|
733 |
+
if (num_added == num_needed) break;
|
734 |
+
}
|
735 |
+
}
|
736 |
+
}
|
737 |
+
}
|
738 |
+
}
|
739 |
+
|
740 |
+
// Given the final dense mask and weights, convert to the compressed
|
741 |
+
// block CSR representation.
|
742 |
+
template <typename InputType>
|
743 |
+
void MaskAndWeightsToCsr(const std::vector<int>& mask,
|
744 |
+
const std::vector<InputType>& weights,
|
745 |
+
std::vector<int>* nnz_per_row,
|
746 |
+
std::vector<int>* col_indices,
|
747 |
+
std::vector<WeightType>* weights_csr) {
|
748 |
+
std::vector<int> row_offsets = {0};
|
749 |
+
int nnz = 0;
|
750 |
+
// Standard CSR format.
|
751 |
+
if (block_width_ == 1 && block_height_ == 1) {
|
752 |
+
for (int r = 0; r < rows_; ++r) {
|
753 |
+
for (int c = 0; c < cols_; ++c) {
|
754 |
+
if (mask[r * cols_ + c] == 1) {
|
755 |
+
nnz++;
|
756 |
+
col_indices->push_back(c);
|
757 |
+
weights_csr->push_back(WeightType(weights[r * cols_ + c]));
|
758 |
+
}
|
759 |
+
}
|
760 |
+
row_offsets.push_back(nnz);
|
761 |
+
}
|
762 |
+
} else if (block_width_ == 4 && block_height_ == 4) {
|
763 |
+
// Weights are stored contiguously for each block in this case.
|
764 |
+
for (int r = 0; r < reduced_rows_; ++r) {
|
765 |
+
for (int c = 0; c < reduced_cols_; ++c) {
|
766 |
+
if (mask[r * reduced_cols_ + c] == 1) {
|
767 |
+
col_indices->push_back(c);
|
768 |
+
nnz++;
|
769 |
+
for (int i = 0; i < block_height_; ++i) {
|
770 |
+
for (int j = 0; j < block_width_; ++j) {
|
771 |
+
int row_index = (block_height_ * r + i) * cols_;
|
772 |
+
int w_index = row_index + block_width_ * c + j;
|
773 |
+
WeightType weight = w_index < weights.size()
|
774 |
+
? WeightType(weights[w_index])
|
775 |
+
: WeightType(0.0f);
|
776 |
+
weights_csr->push_back(weight);
|
777 |
+
}
|
778 |
+
}
|
779 |
+
}
|
780 |
+
}
|
781 |
+
row_offsets.push_back(nnz);
|
782 |
+
}
|
783 |
+
}
|
784 |
+
for (int i = 1; i < row_offsets.size(); ++i)
|
785 |
+
nnz_per_row->push_back(row_offsets[i] - row_offsets[i - 1]);
|
786 |
+
}
|
787 |
+
|
788 |
+
// Returns the number of block rows per cache line. This is the minimum unit
|
789 |
+
// into which the calculation is broken for threads.
|
790 |
+
template <typename OutType>
|
791 |
+
int ReducedRowsPerCacheLine(int override_cache_line_size = -1) const {
|
792 |
+
int line_size = kCacheLineSize;
|
793 |
+
if (override_cache_line_size >= 1) line_size = override_cache_line_size;
|
794 |
+
return std::max<int>(line_size / (block_height_ * sizeof(OutType)), 1);
|
795 |
+
}
|
796 |
+
|
797 |
+
int col_multiple_;
|
798 |
+
int rows_;
|
799 |
+
int cols_;
|
800 |
+
int reduced_rows_;
|
801 |
+
int reduced_cols_;
|
802 |
+
float sparsity_;
|
803 |
+
int block_width_;
|
804 |
+
int block_height_;
|
805 |
+
int num_threads_;
|
806 |
+
std::string name_;
|
807 |
+
|
808 |
+
CacheAlignedVector<WeightType> weights_;
|
809 |
+
CacheAlignedVector<DeltaType> col_deltas_;
|
810 |
+
CacheAlignedVector<int> nnz_per_row_;
|
811 |
+
// |thread_bounds_| and |rhs_indices_| don't need to be serialized as they are
|
812 |
+
// always recalculated from serialized data.
|
813 |
+
CacheAlignedVector<DeltaType> rhs_indices_;
|
814 |
+
Matmul<WeightType, RhsType> matmul_;
|
815 |
+
ThreadBounds thread_bounds_;
|
816 |
+
static constexpr int kCacheLineSize = 64;
|
817 |
+
};
|
818 |
+
|
819 |
+
// Converts a sparse matrix represented with (|mask|, |weights|, |size|) into
|
820 |
+
// the CSR format, and returns that as a serialized string.
|
821 |
+
template <typename MaskType>
|
822 |
+
std::string ConvertDenseToSparseRepresentation_Int16Deltas(
|
823 |
+
const std::vector<MaskType>& mask, const std::vector<float>& weights,
|
824 |
+
const int rows, const int cols) {
|
825 |
+
MaskedSparseMatrix<float> masked_weights(rows, cols, mask.data(),
|
826 |
+
weights.data());
|
827 |
+
CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
|
828 |
+
sparse_masked_weights(masked_weights);
|
829 |
+
std::string buffer;
|
830 |
+
sparse_masked_weights.WriteToFlatBuffer(&buffer);
|
831 |
+
return buffer;
|
832 |
+
}
|
833 |
+
|
834 |
+
} // namespace csrblocksparse
|
835 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_CSR_BLOCKSPARSE_MATRIX_H_
|
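A minimal end-to-end sketch of the usage outlined in the class comment above (assembled from this header and the test added later in this commit; the 8x8 all-ones mask and weights are made-up example data, and the function name is hypothetical):

#include <vector>
#include "sparse_matmul/layers/csr_blocksparse_matrix.h"

void SparseMatVecExample() {
  constexpr int kRows = 8;
  constexpr int kCols = 8;
  // Hypothetical dense mask and weights; every entry present for brevity.
  std::vector<int> mask(kRows * kCols, 1);
  std::vector<float> values(kRows * kCols, 1.f);

  // Build the masked (dense) representation, then compress it.
  csrblocksparse::MaskedSparseMatrix<float> masked(kRows, kCols, mask.data(),
                                                   values.data());
  csrblocksparse::CsrBlockSparseMatrix<float, float> matrix(masked);

  csrblocksparse::CacheAlignedVector<float> rhs(kCols);
  csrblocksparse::CacheAlignedVector<float> bias(kRows);
  csrblocksparse::CacheAlignedVector<float> out(kRows);
  rhs.FillOnes();
  bias.FillZero();
  out.FillZero();

  // out = W * rhs + bias, single-threaded (tid 0 by default).
  matrix.SpMM_bias(rhs, bias, &out);
}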
sparse_matmul/layers/csrblocksparse_test.cc
ADDED
@@ -0,0 +1,977 @@
1 |
+
// Copyright 2021 Google LLC
|
2 |
+
//
|
3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
// you may not use this file except in compliance with the License.
|
5 |
+
// You may obtain a copy of the License at
|
6 |
+
//
|
7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
//
|
9 |
+
// Unless required by applicable law or agreed to in writing, software
|
10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
// See the License for the specific language governing permissions and
|
13 |
+
// limitations under the License.
|
14 |
+
|
15 |
+
#include <array>
|
16 |
+
#include <cstdint>
|
17 |
+
#include <tuple>
|
18 |
+
#include <vector>
|
19 |
+
|
20 |
+
// Placeholder for get runfiles header.
|
21 |
+
#include "absl/status/status.h"
|
22 |
+
#include "absl/strings/str_cat.h"
|
23 |
+
#include "absl/strings/string_view.h"
|
24 |
+
#include "absl/types/span.h"
|
25 |
+
#include "gtest/gtest.h"
|
26 |
+
#include "include/ghc/filesystem.hpp"
|
27 |
+
#include "sparse_matmul/compute/matmul.h"
|
28 |
+
#include "sparse_matmul/layers/utils.h"
|
29 |
+
#include "sparse_matmul/numerics/test_utils.h"
|
30 |
+
#include "sparse_matmul/os/coop_threads.h"
|
31 |
+
|
32 |
+
namespace csrblocksparse {
|
33 |
+
namespace {
|
34 |
+
|
35 |
+
inline constexpr absl::string_view kTestdataPath = "layers/testdata";
|
36 |
+
|
37 |
+
TEST(CSRBlockSparseMatrix, FlatBufferSerialization) {
|
38 |
+
const int kRows = 8;
|
39 |
+
const int kCols = 8;
|
40 |
+
std::vector<int> mask = {1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
|
41 |
+
1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0,
|
42 |
+
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1,
|
43 |
+
0, 0, 0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1};
|
44 |
+
std::vector<float> values(kRows * kCols, 1.f);
|
45 |
+
values[1] = 2.f;
|
46 |
+
values[3] = 3.f;
|
47 |
+
values[36] = -1.f;
|
48 |
+
values[45] = -2.f;
|
49 |
+
|
50 |
+
csrblocksparse::CacheAlignedVector<float> bias(kRows);
|
51 |
+
csrblocksparse::CacheAlignedVector<float> rhs(kCols);
|
52 |
+
csrblocksparse::CacheAlignedVector<float> out_ref(kRows);
|
53 |
+
csrblocksparse::CacheAlignedVector<float> out_test(kRows);
|
54 |
+
|
55 |
+
bias.FillZero();
|
56 |
+
rhs.FillOnes();
|
57 |
+
|
58 |
+
csrblocksparse::MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(),
|
59 |
+
values.data());
|
60 |
+
|
61 |
+
matrix.SpMM_bias(rhs, bias, &out_ref);
|
62 |
+
|
63 |
+
csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
|
64 |
+
block_sparse_matrix(matrix);
|
65 |
+
|
66 |
+
std::string buffer;
|
67 |
+
std::size_t num_bytes = block_sparse_matrix.WriteToFlatBuffer(&buffer);
|
68 |
+
|
69 |
+
csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
|
70 |
+
new_block_sparse_matrix(reinterpret_cast<const uint8_t*>(buffer.c_str()),
|
71 |
+
num_bytes);
|
72 |
+
|
73 |
+
new_block_sparse_matrix.SpMM_bias(rhs, bias, &out_test);
|
74 |
+
|
75 |
+
CheckResult(out_ref, out_test, kCols);
|
76 |
+
}
|
77 |
+
|
78 |
+
template <typename ComputeType, typename RhsType, typename OutType>
|
79 |
+
void CorrectnessCheckBlockSpMM(int rows, int cols, int block_height,
|
80 |
+
int block_width, float sparsity,
|
81 |
+
bool use_relu = false, int num_threads = 1,
|
82 |
+
int fatness = 1, bool test_matmul = false) {
|
83 |
+
using BiasType = typename TypeOfProduct<ComputeType, RhsType>::type;
|
84 |
+
MaskedSparseMatrix<float> matrix(rows, cols, sparsity, block_height,
|
85 |
+
block_width);
|
86 |
+
matrix.CastWeights<ComputeType>();
|
87 |
+
FatCacheAlignedVector<RhsType> rhs(cols, fatness);
|
88 |
+
CacheAlignedVector<BiasType> bias(rows);
|
89 |
+
FatCacheAlignedVector<OutType> out(rows, fatness);
|
90 |
+
|
91 |
+
bias.FillRandom();
|
92 |
+
rhs.FillRandom();
|
93 |
+
out.FillZero();
|
94 |
+
FatCacheAlignedVector<OutType> out_reference = out;
|
95 |
+
|
96 |
+
matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);
|
97 |
+
|
98 |
+
CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
|
99 |
+
|
100 |
+
SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
|
101 |
+
std::move(sparse_matrix), std::move(bias));
|
102 |
+
num_threads = sparse_linear_layer.PrepareForThreads(num_threads);
|
103 |
+
|
104 |
+
// Checks that the result of applying each thread's portion serially is
|
105 |
+
// correct.
|
106 |
+
for (int thread_id = 0; thread_id < num_threads; ++thread_id) {
|
107 |
+
sparse_linear_layer.SpMM_bias(rhs, &out, use_relu, thread_id);
|
108 |
+
}
|
109 |
+
|
110 |
+
CheckResult(out_reference, out, sparse_linear_layer.cols());
|
111 |
+
|
112 |
+
if (test_matmul) {
|
113 |
+
for (int thread_id = 0; thread_id < num_threads; ++thread_id) {
|
114 |
+
sparse_linear_layer.MatVec(rhs, use_relu, thread_id,
|
115 |
+
/*replicas=*/1, /*output_stride=*/0, &out);
|
116 |
+
}
|
117 |
+
|
118 |
+
CheckResult(out_reference, out, sparse_linear_layer.cols());
|
119 |
+
}
|
120 |
+
}
|
121 |
+
|
122 |
+
// Does:
|
123 |
+
// y = Ax + b;
|
124 |
+
// x = Ay + b;
|
125 |
+
// y = Ax + b;
|
126 |
+
//
|
127 |
+
// to make sure that dependent multiplies are correct.
|
128 |
+
template <typename ComputeType, typename RhsType, typename OutType>
|
129 |
+
void ThreadBody(
|
130 |
+
SpinBarrier* spin_barrier, int tid,
|
131 |
+
const SparseLinearLayer<ComputeType, RhsType>& sparse_linear_layer,
|
132 |
+
FatCacheAlignedVector<RhsType>* rhs, FatCacheAlignedVector<OutType>* out,
|
133 |
+
bool use_relu) {
|
134 |
+
sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid);
|
135 |
+
spin_barrier->barrier();
|
136 |
+
sparse_linear_layer.SpMM_bias(*out, rhs, use_relu, tid);
|
137 |
+
spin_barrier->barrier();
|
138 |
+
sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid);
|
139 |
+
}
|
140 |
+
|
141 |
+
template <typename ComputeType, typename RhsType, typename OutType>
|
142 |
+
void CorrectnessCheckBlockSpMM_MultiThread(int rows, int cols, int block_height,
|
143 |
+
int block_width, float sparsity,
|
144 |
+
bool use_relu = false,
|
145 |
+
int num_threads = 1,
|
146 |
+
int fatness = 1) {
|
147 |
+
typedef typename TypeOfProduct<ComputeType, RhsType>::type BiasType;
|
148 |
+
CHECK(rows == cols);
|
149 |
+
MaskedSparseMatrix<float> matrix(rows, cols, sparsity, block_height,
|
150 |
+
block_width);
|
151 |
+
matrix.CastWeights<ComputeType>();
|
152 |
+
FatCacheAlignedVector<RhsType> rhs(cols, fatness);
|
153 |
+
FatCacheAlignedVector<RhsType> rhs_mt(cols, fatness);
|
154 |
+
CacheAlignedVector<BiasType> bias(rows);
|
155 |
+
FatCacheAlignedVector<OutType> out(rows, fatness);
|
156 |
+
|
157 |
+
bias.FillOnes();
|
158 |
+
rhs.FillOnes();
|
159 |
+
rhs_mt.FillOnes();
|
160 |
+
out.FillZero();
|
161 |
+
FatCacheAlignedVector<OutType> out_reference = out;
|
162 |
+
|
163 |
+
matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);
|
164 |
+
matrix.SpMM_bias(out_reference, bias, &rhs, use_relu);
|
165 |
+
matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);
|
166 |
+
|
167 |
+
CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
|
168 |
+
|
169 |
+
num_threads = sparse_matrix.PrepareForThreads(num_threads,
|
170 |
+
/*cache_line_size=*/1);
|
171 |
+
|
172 |
+
SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
|
173 |
+
std::move(sparse_matrix), std::move(bias));
|
174 |
+
|
175 |
+
csrblocksparse::LaunchOnThreadsWithBarrier(
|
176 |
+
num_threads, ThreadBody<ComputeType, RhsType, OutType>,
|
177 |
+
sparse_linear_layer, &rhs_mt, &out, use_relu);
|
178 |
+
|
179 |
+
CheckResult(out_reference, out, cols);
|
180 |
+
}
|
181 |
+
|
182 |
+
} // namespace
|
183 |
+
|
184 |
+
TEST(MaskedSparseCorrectness, HandCoded) {
|
185 |
+
const int kRows = 8;
|
186 |
+
const int kCols = 8;
|
187 |
+
// clang-format off
|
188 |
+
std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
|
189 |
+
0, 1, 0, 1, 0, 1, 0, 1,
|
190 |
+
1, 0, 0, 1, 1, 1, 1, 0,
|
191 |
+
0, 0, 0, 0, 0, 0, 0, 0,
|
192 |
+
1, 1, 1, 1, 1, 1, 1, 1,
|
193 |
+
0, 0, 0, 0, 1, 1, 0, 0,
|
194 |
+
1, 1, 0, 0, 1, 1, 0, 0,
|
195 |
+
1, 0, 0, 0, 0, 1, 0, 1};
|
196 |
+
// clang-format on
|
197 |
+
std::vector<float> values(kRows * kCols, 1.f);
|
198 |
+
|
199 |
+
std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};
|
200 |
+
|
201 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
|
202 |
+
CacheAlignedVector<float> rhs(kCols);
|
203 |
+
CacheAlignedVector<float> bias(kRows);
|
204 |
+
CacheAlignedVector<float> out(kRows);
|
205 |
+
|
206 |
+
bias.FillOnes();
|
207 |
+
rhs.FillOnes();
|
208 |
+
out.FillZero();
|
209 |
+
|
210 |
+
MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
|
211 |
+
std::move(bias));
|
212 |
+
|
213 |
+
masked_linear_layer.SpMM_bias(rhs, &out);
|
214 |
+
|
215 |
+
for (int i = 0; i < kRows; ++i) {
|
216 |
+
EXPECT_EQ(answer[i], out[i]);
|
217 |
+
}
|
218 |
+
}
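The hand-coded |answer| values above are easy to verify: every kept weight, every input element, and the bias are all 1, so each output row is simply the number of ones in its mask row plus one. A minimal standalone sketch, independent of the library, that reproduces 6, 5, 6, 1, 9, 3, 5, 4:

#include <cstdio>
#include <vector>

int main() {
  const int kRows = 8, kCols = 8;
  // Same mask as in the test above.
  std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
                           0, 1, 0, 1, 0, 1, 0, 1,
                           1, 0, 0, 1, 1, 1, 1, 0,
                           0, 0, 0, 0, 0, 0, 0, 0,
                           1, 1, 1, 1, 1, 1, 1, 1,
                           0, 0, 0, 0, 1, 1, 0, 0,
                           1, 1, 0, 0, 1, 1, 0, 0,
                           1, 0, 0, 0, 0, 1, 0, 1};
  for (int r = 0; r < kRows; ++r) {
    int expected = 1;  // the bias contributes 1 to every row
    for (int c = 0; c < kCols; ++c) expected += mask[r * kCols + c];
    std::printf("row %d -> %d\n", r, expected);
  }
  return 0;
}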
|
219 |
+
|
220 |
+
TEST(MaskedSparseCorrectness, HandCodedFatVector) {
|
221 |
+
const int kRows = 8;
|
222 |
+
const int kCols = 8;
|
223 |
+
// clang-format off
|
224 |
+
std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
|
225 |
+
0, 1, 0, 1, 0, 1, 0, 1,
|
226 |
+
1, 0, 0, 1, 1, 1, 1, 0,
|
227 |
+
0, 0, 0, 0, 0, 0, 0, 0,
|
228 |
+
1, 1, 1, 1, 1, 1, 1, 1,
|
229 |
+
0, 0, 0, 0, 1, 1, 0, 0,
|
230 |
+
1, 1, 0, 0, 1, 1, 0, 0,
|
231 |
+
1, 0, 0, 0, 0, 1, 0, 1};
|
232 |
+
// clang-format on
|
233 |
+
|
234 |
+
std::vector<float> values(kRows * kCols, 1.f);
|
235 |
+
std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};
|
236 |
+
|
237 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
|
238 |
+
const int kMaxWidth = 5;
|
239 |
+
for (int width = 5; width <= kMaxWidth; ++width) {
|
240 |
+
FatCacheAlignedVector<float> rhs(kCols, width);
|
241 |
+
CacheAlignedVector<float> bias(kRows);
|
242 |
+
FatCacheAlignedVector<float> out(kRows, width);
|
243 |
+
|
244 |
+
bias.FillOnes();
|
245 |
+
rhs.FillOnes();
|
246 |
+
out.FillZero();
|
247 |
+
|
248 |
+
MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
|
249 |
+
std::move(bias));
|
250 |
+
|
251 |
+
masked_linear_layer.SpMM_bias(rhs, &out);
|
252 |
+
|
253 |
+
for (int i = 0; i < kRows; ++i) {
|
254 |
+
for (int width = 0; width < kMaxWidth; ++width) {
|
255 |
+
EXPECT_EQ(answer[i], out[i + width * kRows]);
|
256 |
+
}
|
257 |
+
}
|
258 |
+
}
|
259 |
+
}
|
260 |
+
|
261 |
+
TEST(CsrBlockSparseMatrix, HandCodedMultiThread) {
|
262 |
+
const int kRows = 8;
|
263 |
+
const int kCols = 8;
|
264 |
+
// clang-format off
|
265 |
+
std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
|
266 |
+
0, 1, 0, 1, 0, 1, 0, 1,
|
267 |
+
1, 0, 0, 1, 1, 1, 1, 0,
|
268 |
+
0, 0, 0, 0, 0, 0, 0, 0,
|
269 |
+
1, 1, 1, 1, 1, 1, 1, 1,
|
270 |
+
0, 0, 0, 0, 1, 1, 0, 0,
|
271 |
+
1, 1, 0, 0, 1, 1, 0, 0,
|
272 |
+
1, 0, 0, 0, 0, 1, 0, 1};
|
273 |
+
// clang-format on
|
274 |
+
std::vector<float> values(kRows * kCols, 1.f);
|
275 |
+
|
276 |
+
std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};
|
277 |
+
|
278 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
|
279 |
+
CacheAlignedVector<float> rhs(kCols);
|
280 |
+
CacheAlignedVector<float> bias(kRows);
|
281 |
+
CacheAlignedVector<float> out(kRows);
|
282 |
+
|
283 |
+
bias.FillOnes();
|
284 |
+
rhs.FillOnes();
|
285 |
+
out.FillZero();
|
286 |
+
|
287 |
+
CacheAlignedVector<float> bias_csr = bias;
|
288 |
+
|
289 |
+
CsrBlockSparseMatrix<bfloat16, float> sparse_matrix(matrix);
|
290 |
+
|
291 |
+
MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
|
292 |
+
std::move(bias));
|
293 |
+
|
294 |
+
masked_linear_layer.SpMM_bias(rhs, &out);
|
295 |
+
|
296 |
+
SparseLinearLayer<bfloat16, float> sparse_linear_layer(
|
297 |
+
std::move(sparse_matrix), std::move(bias_csr));
|
298 |
+
sparse_linear_layer.PrepareForThreads(2, /*cache_line_size=*/1);
|
299 |
+
|
300 |
+
CacheAlignedVector<float> out_tmp(kRows);
|
301 |
+
const bool kUseRelu = false;
|
302 |
+
sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/0);
|
303 |
+
sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/1);
|
304 |
+
|
305 |
+
for (int i = 0; i < kRows; ++i) {
|
306 |
+
EXPECT_EQ(answer[i], out_tmp[i]);
|
307 |
+
}
|
308 |
+
}
|
309 |
+
|
310 |
+
TEST(TestCasts, TestBfloat16) {
|
311 |
+
const int kRows = 1000;
|
312 |
+
const int kCols = 100;
|
313 |
+
const float kSparsity = 0.f;
|
314 |
+
|
315 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
|
316 |
+
MaskedSparseMatrix<float> matrix_bfloat16(kRows, kCols, matrix.mask().data(),
|
317 |
+
matrix.values().data());
|
318 |
+
|
319 |
+
matrix_bfloat16.CastWeights<bfloat16>();
|
320 |
+
|
321 |
+
CheckResult(matrix.values(), matrix_bfloat16.values(), kCols);
|
322 |
+
}
|
323 |
+
|
324 |
+
TEST(TestCasts, TestFP16) {
|
325 |
+
const int kRows = 1000;
|
326 |
+
const int kCols = 100;
|
327 |
+
const float kSparsity = 0.f;
|
328 |
+
|
329 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
|
330 |
+
#if !defined __arm__ && !defined __aarch64__
|
331 |
+
// Conversion doesn't handle denormals, so flush denormals to zero first.
|
332 |
+
for (int i = 0; i < matrix.values().size(); ++i) {
|
333 |
+
if (matrix.data()[i] < 1. / static_cast<float>(1 << 14))
|
334 |
+
matrix.data()[i] = 0.f;
|
335 |
+
}
|
336 |
+
#endif
|
337 |
+
MaskedSparseMatrix<float> matrix_fp16(kRows, kCols, matrix.mask().data(),
|
338 |
+
matrix.values().data());
|
339 |
+
|
340 |
+
matrix_fp16.CastWeights<csrblocksparse::fp16>();
|
341 |
+
|
342 |
+
CheckResult(matrix.values(), matrix_fp16.values(), kCols);
|
343 |
+
}
|
344 |
+
|
345 |
+
TEST(TestCasts, TestFixed16) {
|
346 |
+
const int kRows = 100000;
|
347 |
+
const int kCols = 1;
|
348 |
+
const float kSparsity = 0.f;
|
349 |
+
|
350 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
|
351 |
+
|
352 |
+
// Relative error for fixed point is high near 0.
|
353 |
+
for (int i = 0; i < matrix.values().size(); ++i) {
|
354 |
+
// 1.1e-3 is based on the max error of .013 and a grid spacing of 1 / 2**16
|
355 |
+
// == 3e-5. 3e-5 / .013 / 2 = 1.1e-3.
|
356 |
+
if (std::abs(matrix.data()[i]) < 1.1e-3) {
|
357 |
+
matrix.data()[i] = 0.f;
|
358 |
+
}
|
359 |
+
}
|
360 |
+
|
361 |
+
MaskedSparseMatrix<float> matrix_fixed16 = matrix;
|
362 |
+
|
363 |
+
matrix_fixed16.CastWeights<csrblocksparse::fixed16</*ExponentBits=*/0>>();
|
364 |
+
|
365 |
+
CheckResult(matrix.values(), matrix_fixed16.values(), kCols);
|
366 |
+
}
|
367 |
+
|
368 |
+
TEST(TestCasts, TestFixed32) {
|
369 |
+
const int kRows = 100000;
|
370 |
+
const int kCols = 1;
|
371 |
+
const float kSparsity = 0.f;
|
372 |
+
|
373 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
|
374 |
+
MaskedSparseMatrix<float> matrix_fixed32 = matrix;
|
375 |
+
|
376 |
+
matrix_fixed32.CastWeights<csrblocksparse::fixed32</*ExponentBits=*/0>>();
|
377 |
+
|
378 |
+
CheckResult(matrix.values(), matrix_fixed32.values(), kCols);
|
379 |
+
}
|
380 |
+
|
381 |
+
template <typename ComputeType, typename RhsType, typename OutType>
|
382 |
+
void TestSpMM(int block_width, int block_height, int fatness,
|
383 |
+
bool test_matmul = false) {
|
384 |
+
std::array<bool, 2> use_relu = {false, true};
|
385 |
+
std::vector<float> sparsity_levels = {.5, .8, .9, .95, .98};
|
386 |
+
std::vector<std::pair<int, int>> sizes = {{8, 8}, {128, 128}, {128, 64},
|
387 |
+
{256, 192}, {512, 512}, {1024, 512},
|
388 |
+
{384, 384}, {512, 384}};
|
389 |
+
for (int num_threads = 1; num_threads < 2 + test_matmul; ++num_threads) {
|
390 |
+
for (const auto& relu : use_relu) {
|
391 |
+
for (const auto& sparsity : sparsity_levels) {
|
392 |
+
for (const auto& size : sizes) {
|
393 |
+
int rows, cols;
|
394 |
+
std::tie(rows, cols) = size;
|
395 |
+
CorrectnessCheckBlockSpMM<ComputeType, RhsType, OutType>(
|
396 |
+
rows, cols, block_height, block_width, sparsity, relu,
|
397 |
+
num_threads, fatness, test_matmul);
|
398 |
+
}
|
399 |
+
}
|
400 |
+
}
|
401 |
+
}
|
402 |
+
}
|
403 |
+
|
404 |
+
template <typename ComputeType, typename RhsType, typename OutType>
|
405 |
+
void TestSpMM_MultiThread(int block_width, int block_height, int fatness) {
|
406 |
+
std::array<bool, 2> use_relu = {false, true};
|
407 |
+
std::vector<float> sparsity_levels = {.5, .8, .9, .95, .98};
|
408 |
+
std::vector<std::pair<int, int>> sizes = {
|
409 |
+
{48, 48}, {128, 128}, {512, 512}, {384, 384}};
|
410 |
+
for (int num_threads = 1; num_threads < 5; ++num_threads) {
|
411 |
+
for (const auto& relu : use_relu) {
|
412 |
+
for (const auto& sparsity : sparsity_levels) {
|
413 |
+
for (const auto& size : sizes) {
|
414 |
+
int rows, cols;
|
415 |
+
std::tie(rows, cols) = size;
|
416 |
+
CorrectnessCheckBlockSpMM_MultiThread<ComputeType, RhsType, OutType>(
|
417 |
+
rows, cols, block_height, block_width, sparsity, relu,
|
418 |
+
num_threads, fatness);
|
419 |
+
}
|
420 |
+
}
|
421 |
+
}
|
422 |
+
}
|
423 |
+
}
|
424 |
+
|
425 |
+
template <typename DataType>
|
426 |
+
void TestSumVectors(int start = 0, int end = -1, int size = 6) {
|
427 |
+
std::vector<DataType> values;
|
428 |
+
std::vector<DataType> answer;
|
429 |
+
|
430 |
+
for (int i = 1; i < size + 1; ++i) {
|
431 |
+
const float x = static_cast<float>(i);
|
432 |
+
values.push_back(static_cast<DataType>(x));
|
433 |
+
answer.push_back(static_cast<DataType>(x * 2));
|
434 |
+
}
|
435 |
+
|
436 |
+
if (end == -1) {
|
437 |
+
end = values.size();
|
438 |
+
}
|
439 |
+
|
440 |
+
csrblocksparse::CacheAlignedVector<DataType> result(values.size());
|
441 |
+
csrblocksparse::CacheAlignedVector<DataType> values_aligned(values);
|
442 |
+
detail::SumVectors(start, end, values_aligned.data(), values_aligned.data(),
|
443 |
+
result.data());
|
444 |
+
for (int i = start; i < end; ++i) {
|
445 |
+
EXPECT_EQ(static_cast<float>(answer[i]), static_cast<float>(result[i]));
|
446 |
+
}
|
447 |
+
}
|
448 |
+
|
449 |
+
TEST(CsrBlockSparseMatrix, SumVectors_Generic) {
|
450 |
+
TestSumVectors<float>();
|
451 |
+
TestSumVectors<float>(1);
|
452 |
+
TestSumVectors<float>(1, 4);
|
453 |
+
}
|
454 |
+
|
455 |
+
TEST(CsrBlockSparseMatrix, SumVectors_Bfloat16) {
|
456 |
+
TestSumVectors<csrblocksparse::bfloat16>();
|
457 |
+
TestSumVectors<csrblocksparse::bfloat16>(1);
|
458 |
+
TestSumVectors<csrblocksparse::bfloat16>(1, 4);
|
459 |
+
}
|
460 |
+
|
461 |
+
// For SIMD-optimized SumVectors, the memory of the vector should be at least
|
462 |
+
// |kSIMDWidth * sizeof(float)| long, and the start position has to be an
|
463 |
+
// aligned memory location. So setting |size| to be 100 to be safe and
|
464 |
+
// |start| to be 0 (|start| == 1 is not aligned).
|
465 |
+
TEST(CsrBlockSparseMatrix, SumVectors_Fixed16) {
|
466 |
+
TestSumVectors<csrblocksparse::fixed16<8>>(0, -1, 100);
|
467 |
+
TestSumVectors<csrblocksparse::fixed16<8>>(0, 4, 100);
|
468 |
+
}
|
469 |
+
|
470 |
+
TEST(CsrBlockSparseMatrix, SumVectors_Fixed32) {
|
471 |
+
TestSumVectors<csrblocksparse::fixed32<11>>(0, -1, 100);
|
472 |
+
TestSumVectors<csrblocksparse::fixed32<11>>(0, 4, 100);
|
473 |
+
}
|
474 |
+
|
475 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block4x4_Bfloat16) {
|
476 |
+
TestSpMM<csrblocksparse::bfloat16, float, float>(/*block_width=*/4,
|
477 |
+
/*block_height=*/4,
|
478 |
+
/*fatness=*/7);
|
479 |
+
}
|
480 |
+
|
481 |
+
// This actually uses multiple threads, and uses the output as the input for
|
482 |
+
// multiple steps to test that synchronization and memory visibility is
|
483 |
+
// working correctly. Requires square matrices.
|
484 |
+
TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_Bfloat16) {
|
485 |
+
TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
|
486 |
+
/*block_width=*/4,
|
487 |
+
/*block_height=*/4,
|
488 |
+
/*fatness=*/1);
|
489 |
+
}
|
490 |
+
|
491 |
+
TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_Bfloat16) {
|
492 |
+
TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
|
493 |
+
/*block_width=*/4,
|
494 |
+
/*block_height=*/4,
|
495 |
+
/*fatness=*/7);
|
496 |
+
}
|
497 |
+
|
498 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block1x1_Bfloat16) {
|
499 |
+
TestSpMM<csrblocksparse::bfloat16, float, float>(/*block_width=*/1,
|
500 |
+
/*block_height=*/1,
|
501 |
+
/*fatness=*/1);
|
502 |
+
}
|
503 |
+
|
504 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block1x1_Bfloat16) {
|
505 |
+
TestSpMM<csrblocksparse::bfloat16, float, float>(/*block_width=*/1,
|
506 |
+
/*block_height=*/1,
|
507 |
+
/*fatness=*/7);
|
508 |
+
}
|
509 |
+
|
510 |
+
// This actually uses multiple threads, and uses the output as the input for
|
511 |
+
// multiple steps to test that synchronization and memory visibility is
|
512 |
+
// working correctly. Requires square matrices.
|
513 |
+
TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_Bfloat16) {
|
514 |
+
TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
|
515 |
+
/*block_width=*/1,
|
516 |
+
/*block_height=*/1,
|
517 |
+
/*fatness=*/1);
|
518 |
+
}
|
519 |
+
|
520 |
+
TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_Bfloat16) {
|
521 |
+
TestSpMM_MultiThread<csrblocksparse::bfloat16, float, float>(
|
522 |
+
/*block_width=*/1,
|
523 |
+
/*block_height=*/1,
|
524 |
+
/*fatness=*/7);
|
525 |
+
}
|
526 |
+
|
527 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block4x4_float) {
|
528 |
+
TestSpMM<float, float, float>(/*block_width=*/4,
|
529 |
+
/*block_height=*/4,
|
530 |
+
/*fatness=*/1,
|
531 |
+
/*test_matmul=*/true);
|
532 |
+
}
|
533 |
+
|
534 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block4x4_float) {
|
535 |
+
TestSpMM<float, float, float>(/*block_width=*/4,
|
536 |
+
/*block_height=*/4,
|
537 |
+
/*fatness=*/7);
|
538 |
+
}
|
539 |
+
|
540 |
+
// This actually uses multiple threads, and uses the output as the input for
|
541 |
+
// multiple steps to test that synchronization and memory visibility is
|
542 |
+
// working correctly. Requires square matrices.
|
543 |
+
TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_float) {
|
544 |
+
TestSpMM_MultiThread<float, float, float>(/*block_width=*/4,
|
545 |
+
/*block_height=*/4,
|
546 |
+
/*fatness=*/1);
|
547 |
+
}
|
548 |
+
|
549 |
+
TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_float) {
|
550 |
+
TestSpMM_MultiThread<float, float, float>(/*block_width=*/4,
|
551 |
+
/*block_height=*/4,
|
552 |
+
/*fatness=*/7);
|
553 |
+
}
|
554 |
+
|
555 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block1x1_float) {
|
556 |
+
TestSpMM<float, float, float>(/*block_width=*/1,
|
557 |
+
/*block_height=*/1,
|
558 |
+
/*fatness=*/1);
|
559 |
+
}
|
560 |
+
|
561 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block1x1_float) {
|
562 |
+
TestSpMM<float, float, float>(/*block_width=*/1,
|
563 |
+
/*block_height=*/1,
|
564 |
+
/*fatness=*/7);
|
565 |
+
}
|
566 |
+
|
567 |
+
// This actually uses multiple threads, and uses the output as the input for
|
568 |
+
// multiple steps to test that synchronization and memory visibility is
|
569 |
+
// working correctly. Requires square matrices.
|
570 |
+
TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_float) {
|
571 |
+
TestSpMM_MultiThread<float, float, float>(/*block_width=*/1,
|
572 |
+
/*block_height=*/1,
|
573 |
+
/*fatness=*/1);
|
574 |
+
}
|
575 |
+
|
576 |
+
TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_float) {
|
577 |
+
TestSpMM_MultiThread<float, float, float>(/*block_width=*/1,
|
578 |
+
/*block_height=*/1,
|
579 |
+
/*fatness=*/7);
|
580 |
+
}
|
581 |
+
|
582 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32) {
|
583 |
+
TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
|
584 |
+
typename csrblocksparse::TypeOfProduct<
|
585 |
+
csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
|
586 |
+
/*block_width=*/4,
|
587 |
+
/*block_height=*/4,
|
588 |
+
/*fatness=*/1,
|
589 |
+
/*test_matmul=*/true);
|
590 |
+
}
|
591 |
+
|
592 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32) {
|
593 |
+
TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
|
594 |
+
typename csrblocksparse::TypeOfProduct<
|
595 |
+
csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
|
596 |
+
/*block_width=*/4,
|
597 |
+
/*block_height=*/4,
|
598 |
+
/*fatness=*/7);
|
599 |
+
}
|
600 |
+
|
601 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32) {
|
602 |
+
TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
|
603 |
+
typename csrblocksparse::TypeOfProduct<
|
604 |
+
csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
|
605 |
+
/*block_width=*/1,
|
606 |
+
/*block_height=*/1,
|
607 |
+
/*fatness=*/1);
|
608 |
+
}
|
609 |
+
|
610 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32) {
|
611 |
+
TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
|
612 |
+
typename csrblocksparse::TypeOfProduct<
|
613 |
+
csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
|
614 |
+
/*block_width=*/1,
|
615 |
+
/*block_height=*/1,
|
616 |
+
/*fatness=*/7);
|
617 |
+
}
|
618 |
+
|
619 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_16) {
|
620 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
621 |
+
csrblocksparse::fixed16<8>>(
|
622 |
+
/*block_width=*/4,
|
623 |
+
/*block_height=*/4,
|
624 |
+
/*fatness=*/1,
|
625 |
+
/*test_matmul=*/true);
|
626 |
+
}
|
627 |
+
|
628 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_16) {
|
629 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
630 |
+
csrblocksparse::fixed16<8>>(
|
631 |
+
/*block_width=*/4,
|
632 |
+
/*block_height=*/4,
|
633 |
+
/*fatness=*/7);
|
634 |
+
}
|
635 |
+
|
636 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_16) {
|
637 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
638 |
+
csrblocksparse::fixed16<8>>(
|
639 |
+
/*block_width=*/1,
|
640 |
+
/*block_height=*/1,
|
641 |
+
/*fatness=*/1);
|
642 |
+
}
|
643 |
+
|
644 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_16) {
|
645 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
646 |
+
csrblocksparse::fixed16<8>>(
|
647 |
+
/*block_width=*/1,
|
648 |
+
/*block_height=*/1,
|
649 |
+
/*fatness=*/7);
|
650 |
+
}
|
651 |
+
|
652 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32_unmatched) {
|
653 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
654 |
+
csrblocksparse::fixed32<13>>(
|
655 |
+
/*block_width=*/4,
|
656 |
+
/*block_height=*/4,
|
657 |
+
/*fatness=*/1,
|
658 |
+
/*test_matmul=*/true);
|
659 |
+
}
|
660 |
+
|
661 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32_unmatched) {
|
662 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
663 |
+
csrblocksparse::fixed32<13>>(
|
664 |
+
/*block_width=*/4,
|
665 |
+
/*block_height=*/4,
|
666 |
+
/*fatness=*/7);
|
667 |
+
}
|
668 |
+
|
669 |
+
TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32_unmatched) {
|
670 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
671 |
+
csrblocksparse::fixed32<13>>(
|
672 |
+
/*block_width=*/1,
|
673 |
+
/*block_height=*/1,
|
674 |
+
/*fatness=*/1);
|
675 |
+
}
|
676 |
+
|
677 |
+
TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32_unmatched) {
|
678 |
+
TestSpMM<csrblocksparse::fixed16<5>, csrblocksparse::fixed16<5>,
|
679 |
+
csrblocksparse::fixed32<13>>(
|
680 |
+
/*block_width=*/1,
|
681 |
+
/*block_height=*/1,
|
682 |
+
/*fatness=*/7);
|
683 |
+
}
|
684 |
+
|
685 |
+
TEST(CsrBlockSparseMatrix, RhsIndicesDeltasRoundTrip) {
|
686 |
+
MaskedSparseMatrix<float> matrix(/*rows=*/256, /*cols=*/256,
|
687 |
+
/*sparsity=*/0.9, /*block_height=*/4,
|
688 |
+
/*block_width=*/4);
|
689 |
+
CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
|
690 |
+
CacheAlignedVector<int16_t> copy_indices = sparse_matrix.rhs_indices();
|
691 |
+
sparse_matrix.ComputeColDeltas();
|
692 |
+
sparse_matrix.ComputeRHSIndices();
|
693 |
+
// They get padded when created, so the newer one could be bigger.
|
694 |
+
EXPECT_LE(copy_indices.size(), sparse_matrix.rhs_indices().size());
|
695 |
+
for (int i = 0; i < copy_indices.size(); ++i) {
|
696 |
+
EXPECT_EQ(copy_indices[i], sparse_matrix.rhs_indices()[i]) << "i=" << i;
|
697 |
+
}
|
698 |
+
}
|
699 |
+
|
700 |
+
// Tests that a Layer that is split into 2 by columns (inputs) computes the same
|
701 |
+
// result as the original layer.
|
702 |
+
TEST(CsrBlockSparseMatrix, SplitByCol) {
|
703 |
+
int kRows = 1024;
|
704 |
+
int kCols = 1024;
|
705 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, 0.95, /*block_height=*/4,
|
706 |
+
/*block_width=*/4);
|
707 |
+
FatCacheAlignedVector<float> rhs(kCols, /*cols=*/1);
|
708 |
+
CacheAlignedVector<float> bias(kRows);
|
709 |
+
FatCacheAlignedVector<float> out1(kRows, /*cols=*/1);
|
710 |
+
FatCacheAlignedVector<float> out2(kRows, /*cols=*/1);
|
711 |
+
|
712 |
+
bias.FillRandom();
|
713 |
+
rhs.FillRandom();
|
714 |
+
out1.FillZero();
|
715 |
+
out2.FillZero();
|
716 |
+
FatCacheAlignedVector<float> out_reference = out1;
|
717 |
+
|
718 |
+
CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
|
719 |
+
|
720 |
+
SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
|
721 |
+
std::move(bias));
|
722 |
+
sparse_linear_layer.PrepareForThreads(1);
|
723 |
+
sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false,
|
724 |
+
/*tid=*/0);
|
725 |
+
// Split the layer into 2 parts.
|
726 |
+
SparseLinearLayer<float, float> part1, part2;
|
727 |
+
sparse_linear_layer.SplitInputs(&part1, &part2);
|
728 |
+
part1.PrepareForThreads(1);
|
729 |
+
part2.PrepareForThreads(1);
|
730 |
+
EXPECT_EQ(kRows, part1.rows());
|
731 |
+
EXPECT_EQ(kCols / 2, part1.cols());
|
732 |
+
EXPECT_EQ(kRows, part2.rows());
|
733 |
+
EXPECT_EQ(kCols / 2, part2.cols());
|
734 |
+
MutableVectorView<float> rhs1(&rhs, 0, kCols / 2);
|
735 |
+
MutableVectorView<float> rhs2(&rhs, kCols / 2, kCols / 2);
|
736 |
+
for (int i = 0; i < kCols / 2; ++i) {
|
737 |
+
EXPECT_FLOAT_EQ(rhs[i], rhs1.data()[i]);
|
738 |
+
EXPECT_FLOAT_EQ(rhs[i + kCols / 2], rhs2.data()[i]);
|
739 |
+
}
|
740 |
+
part1.SpMM_bias(rhs1, &out1, /*relu=*/false, /*tid=*/0);
|
741 |
+
part2.SpMM_bias(rhs2, &out2, /*relu=*/false, /*tid=*/0);
|
742 |
+
// Check that out1 + out2 = out_reference.
|
743 |
+
for (int i = 0; i < kRows; ++i) {
|
744 |
+
EXPECT_NEAR(out_reference[i], out1[i] + out2[i], 2e-5)
|
745 |
+
<< " i=" << i << " out1=" << out1[i] << " out2=" << out2[i];
|
746 |
+
}
|
747 |
+
}
|
748 |
+
// Tests that a Layer that is split into 2 by rows (outputs) computes the same
|
749 |
+
// result as the original layer.
|
750 |
+
TEST(CsrBlockSparseMatrix, SplitByRow) {
|
751 |
+
int kRows = 1024;
|
752 |
+
int kCols = 1024;
|
753 |
+
MaskedSparseMatrix<float> matrix(kRows, kCols, 0.95, /*block_height=*/4,
|
754 |
+
/*block_width=*/4);
|
755 |
+
FatCacheAlignedVector<float> rhs(kCols, /*cols=*/1);
|
756 |
+
CacheAlignedVector<float> bias(kRows);
|
757 |
+
FatCacheAlignedVector<float> out1(kRows, /*cols=*/1);
|
758 |
+
FatCacheAlignedVector<float> out2(kRows, /*cols=*/1);
|
759 |
+
|
760 |
+
bias.FillRandom();
|
761 |
+
rhs.FillRandom();
|
762 |
+
out1.FillZero();
|
763 |
+
out2.FillZero();
|
764 |
+
FatCacheAlignedVector<float> out_reference = out1;
|
765 |
+
|
766 |
+
CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
|
767 |
+
|
768 |
+
SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
|
769 |
+
std::move(bias));
|
770 |
+
sparse_linear_layer.PrepareForThreads(1);
|
771 |
+
sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false,
|
772 |
+
/*tid=*/0);
|
773 |
+
// Split the layer into 2 parts.
|
774 |
+
SparseLinearLayer<float, float> part1, part2;
|
775 |
+
sparse_linear_layer.SplitOutputs(&part1, &part2);
|
776 |
+
part1.PrepareForThreads(1);
|
777 |
+
part2.PrepareForThreads(1);
|
778 |
+
EXPECT_EQ(kRows / 2, part1.rows());
|
779 |
+
EXPECT_EQ(kCols, part1.cols());
|
780 |
+
EXPECT_EQ(kRows / 2, part2.rows());
|
781 |
+
EXPECT_EQ(kCols, part2.cols());
|
782 |
+
MutableVectorView<float> out2a(&out2, 0, kRows / 2);
|
783 |
+
MutableVectorView<float> out2b(&out2, kRows / 2, kRows / 2);
|
784 |
+
part1.SpMM_bias(rhs, &out2a, /*relu=*/false, /*tid=*/0);
|
785 |
+
part2.SpMM_bias(rhs, &out2b, /*relu=*/false, /*tid=*/0);
|
786 |
+
// Check that out2 = out_reference.
|
787 |
+
for (int i = 0; i < kRows; ++i) {
|
788 |
+
EXPECT_NEAR(out_reference[i], out2[i], 2e-5)
|
789 |
+
<< " i=" << i << " out1=" << out_reference[i] << " out2=" << out2[i];
|
790 |
+
}
|
791 |
+
}
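The two split tests rest on plain block-matrix algebra: splitting by columns means A*x = A1*x1 + A2*x2 (so the two partial outputs must be summed), while splitting by rows stacks the partial products (so the two outputs are simply concatenated). A tiny standalone demo of the column case, unrelated to the library's SplitInputs/SplitOutputs machinery:

#include <cmath>
#include <cstdio>

int main() {
  // A 2x4 matrix A split by columns into two 2x2 blocks A1 | A2, applied to x.
  const float a[2][4] = {{1, 2, 3, 4}, {5, 6, 7, 8}};
  const float x[4] = {1, -1, 2, 0.5f};
  for (int r = 0; r < 2; ++r) {
    float full = 0.f, left = 0.f, right = 0.f;
    for (int c = 0; c < 4; ++c) full += a[r][c] * x[c];   // A * x
    for (int c = 0; c < 2; ++c) left += a[r][c] * x[c];   // A1 * x1
    for (int c = 2; c < 4; ++c) right += a[r][c] * x[c];  // A2 * x2
    std::printf("row %d: %g vs %g (diff %g)\n", r, full, left + right,
                std::fabs(full - (left + right)));
  }
  return 0;
}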
|
792 |
+
|
793 |
+
TEST(CsrBlockSparseMatrix, MutableVectorView) {
|
794 |
+
const int kRows = 1024;
|
795 |
+
const int kCols = 1024;
|
796 |
+
const int kFatness = 2;
|
797 |
+
|
798 |
+
std::vector<float> values(kRows * kCols, 1.f);
|
799 |
+
std::vector<int> mask(kRows * kCols);
|
800 |
+
for (int i = 0; i < mask.size(); ++i) mask[i] = i % 2;
|
801 |
+
|
802 |
+
auto masked_matrix =
|
803 |
+
MaskedSparseMatrix<float>(kRows, kCols, mask.data(), values.data());
|
804 |
+
auto sparse_matrix = CsrBlockSparseMatrix<bfloat16, float>(masked_matrix);
|
805 |
+
FatCacheAlignedVector<float> x(kCols, kFatness);
|
806 |
+
x.FillOnes();
|
807 |
+
|
808 |
+
CacheAlignedVector<float> bias(kRows);
|
809 |
+
bias.FillZero();
|
810 |
+
|
811 |
+
// First check that we can use spans as output. Split a multiplication
|
812 |
+
// into upper and lower halves times the full vector:
|
813 |
+
// --------------- x t
|
814 |
+
// | | x t
|
815 |
+
// | | x t
|
816 |
+
// --------------- =
|
817 |
+
// | | x b
|
818 |
+
// | | x b
|
819 |
+
// --------------- x b
|
820 |
+
|
821 |
+
FatCacheAlignedVector<float> out(kRows, kFatness);
|
822 |
+
FatCacheAlignedVector<float> out_view(kRows, kFatness);
|
823 |
+
|
824 |
+
MutableVectorView<float> out_view_top(&out_view, 0, kRows / 2);
|
825 |
+
MutableVectorView<float> out_view_bottom(&out_view, kRows / 2, kRows / 2);
|
826 |
+
|
827 |
+
sparse_matrix.SpMM_bias(x, bias, &out);
|
828 |
+
|
829 |
+
auto masked_matrix_top =
|
830 |
+
MaskedSparseMatrix<float>(kRows / 2, kCols, mask.data(), values.data());
|
831 |
+
auto masked_matrix_bottom = MaskedSparseMatrix<float>(
|
832 |
+
kRows / 2, kCols, mask.data() + kRows * kCols / 2,
|
833 |
+
values.data() + kRows * kCols / 2);
|
834 |
+
auto sparse_matrix_top =
|
835 |
+
CsrBlockSparseMatrix<bfloat16, float>(masked_matrix_top);
|
836 |
+
auto sparse_matrix_bottom =
|
837 |
+
CsrBlockSparseMatrix<bfloat16, float>(masked_matrix_bottom);
|
838 |
+
|
839 |
+
sparse_matrix_top.SpMM_bias(x, bias, &out_view_top);
|
840 |
+
sparse_matrix_bottom.SpMM_bias(x, bias, &out_view_bottom);
|
841 |
+
|
842 |
+
CheckResult(out, out_view, kCols);
|
843 |
+
|
844 |
+
// Check that we can use a span as an input vector. Multiply upper left
|
845 |
+
// portion of the matrix by the top half of the vector.
|
846 |
+
// ---------------
|
847 |
+
// |oooooo | x q
|
848 |
+
// |oooooo | x q
|
849 |
+
// | | =
|
850 |
+
// | |
|
851 |
+
// ---------------
|
852 |
+
|
853 |
+
auto masked_matrix_quarter = MaskedSparseMatrix<float>(
|
854 |
+
kRows / 2, kCols / 2, mask.data(), values.data());
|
855 |
+
auto sparse_matrix_quarter =
|
856 |
+
CsrBlockSparseMatrix<bfloat16, float>(masked_matrix_quarter);
|
857 |
+
|
858 |
+
MutableVectorView<float> x_top(&x, 0, kCols / 2);
|
859 |
+
FatCacheAlignedVector<float> out_correct(kRows / 2, /*cols=*/2);
|
860 |
+
|
861 |
+
for (int i = 0; i < kFatness * (kRows / 2); ++i) out_correct[i] = 256.f;
|
862 |
+
|
863 |
+
MutableVectorView<float> bias_top(&bias, 0, kRows / 2);
|
864 |
+
FatCacheAlignedVector<float> out_quarter(kRows / 2, kFatness);
|
865 |
+
|
866 |
+
sparse_matrix_quarter.SpMM_bias(x_top, bias_top, &out_quarter);
|
867 |
+
|
868 |
+
CheckResult(out_correct, out_quarter, kCols / 2);
|
869 |
+
}
|
870 |
+
|
871 |
+
namespace {
|
872 |
+
|
873 |
+
bool skip_test(const absl::Status& status, absl::string_view msg) {
|
874 |
+
if (!status.ok()) {
|
875 |
+
LOG(INFO) << "Couldn't load " << msg << ", skipping test " << status;
|
876 |
+
return true;
|
877 |
+
}
|
878 |
+
|
879 |
+
return false;
|
880 |
+
}
|
881 |
+
|
882 |
+
} // namespace
|
883 |
+
|
884 |
+
TEST(CsrBlockSparseMatrix, ModelMatrices_Bfloat16) {
|
885 |
+
std::vector<std::string> names = {
|
886 |
+
"768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_",
|
887 |
+
"768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_",
|
888 |
+
"768_512_95_4x4_finelogit_", "lyra_conv1d_"};
|
889 |
+
const std::string kPath =
|
890 |
+
#if defined __arm__ || defined __aarch64__
|
891 |
+
"/data/local/tmp/";
|
892 |
+
#else
|
893 |
+
(ghc::filesystem::current_path() / kTestdataPath).string();
|
894 |
+
#endif
|
895 |
+
for (auto& layer_name : names) {
|
896 |
+
SparseLinearLayer<bfloat16, float> sparse_linear_layer;
|
897 |
+
auto status = LoadSparseLayer<bfloat16, float>(layer_name, /*zipped=*/true,
|
898 |
+
&sparse_linear_layer, kPath);
|
899 |
+
// If the files don't exist on the device we're running on, just skip this
|
900 |
+
// test and log that it was skipped.
|
901 |
+
if (skip_test(status, layer_name)) return;
|
902 |
+
|
903 |
+
int rows = sparse_linear_layer.rows();
|
904 |
+
int cols = sparse_linear_layer.cols();
|
905 |
+
|
906 |
+
MaskedLinearLayer<float> masked_linear_layer;
|
907 |
+
status = LoadMaskedLayer<float>(layer_name, /*zipped=*/true,
|
908 |
+
&masked_linear_layer, kPath);
|
909 |
+
if (skip_test(status, layer_name)) return;
|
910 |
+
masked_linear_layer.CastWeights<csrblocksparse::bfloat16>();
|
911 |
+
|
912 |
+
CacheAlignedVector<float> rhs(cols);
|
913 |
+
CacheAlignedVector<float> out_ref(rows);
|
914 |
+
CacheAlignedVector<float> out_spmv(rows);
|
915 |
+
|
916 |
+
rhs.FillRandom();
|
917 |
+
out_ref.FillZero();
|
918 |
+
out_spmv.FillZero();
|
919 |
+
|
920 |
+
std::array<bool, 2> use_relus = {false, true};
|
921 |
+
for (bool use_relu : use_relus) {
|
922 |
+
masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu);
|
923 |
+
sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu);
|
924 |
+
|
925 |
+
CheckResult(out_ref, out_spmv, cols);
|
926 |
+
}
|
927 |
+
}
|
928 |
+
}
|
929 |
+
|
930 |
+
TEST(CsrBlockSparseMatrix, ModelMatrices_float) {
|
931 |
+
std::vector<std::string> names = {
|
932 |
+
"768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_",
|
933 |
+
"768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_",
|
934 |
+
"768_512_95_4x4_finelogit_", "lyra_conv1d_"};
|
935 |
+
const std::string kPath =
|
936 |
+
#if defined __arm__ || defined __aarch64__
|
937 |
+
"/data/local/tmp/";
|
938 |
+
#else
|
939 |
+
(ghc::filesystem::current_path() / kTestdataPath).string();
|
940 |
+
#endif
|
941 |
+
for (auto& layer_name : names) {
|
942 |
+
SparseLinearLayer<float, float> sparse_linear_layer;
|
943 |
+
auto status = LoadSparseLayer<float, float>(layer_name, /*zipped=*/true,
|
944 |
+
&sparse_linear_layer, kPath);
|
945 |
+
// If the files don't exist on the device we're running on, just skip this
|
946 |
+
// test and log that it was skipped.
|
947 |
+
if (skip_test(status, layer_name)) return;
|
948 |
+
|
949 |
+
int rows = sparse_linear_layer.rows();
|
950 |
+
int cols = sparse_linear_layer.cols();
|
951 |
+
|
952 |
+
MaskedLinearLayer<float> masked_linear_layer;
|
953 |
+
status = LoadMaskedLayer<float>(layer_name, /*zipped=*/true,
|
954 |
+
&masked_linear_layer, kPath);
|
955 |
+
if (skip_test(status, layer_name)) return;
|
956 |
+
|
957 |
+
CacheAlignedVector<float> rhs(cols);
|
958 |
+
CacheAlignedVector<float> out_ref(rows);
|
959 |
+
CacheAlignedVector<float> out_spmv(rows);
|
960 |
+
|
961 |
+
rhs.FillRandom();
|
962 |
+
out_ref.FillZero();
|
963 |
+
out_spmv.FillZero();
|
964 |
+
|
965 |
+
std::array<bool, 2> use_relus = {false, true};
|
966 |
+
for (bool use_relu : use_relus) {
|
967 |
+
masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu);
|
968 |
+
sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu);
|
969 |
+
|
970 |
+
CheckResult(out_ref, out_spmv, cols);
|
971 |
+
}
|
972 |
+
}
|
973 |
+
}
|
974 |
+
|
975 |
+
#undef SKIP_TEST
|
976 |
+
|
977 |
+
} // namespace csrblocksparse
|
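The multi-threaded tests above chain dependent products (y = Ax + b, then x = Ay + b, then y = Ax + b) with a SpinBarrier between steps, so every thread must observe the previous step's full output before reading it. The sketch below reproduces only that synchronization shape with plain std::thread and C++20 std::barrier; it is an illustration of the pattern, not of the library's SpinBarrier or LaunchOnThreadsWithBarrier API.

#include <barrier>
#include <thread>
#include <vector>

int main() {
  const int n = 8;
  const int num_threads = 2;
  std::vector<float> a(n * n, 0.5f), x(n, 1.0f), y(n, 0.0f), bias(n, 1.0f);
  std::barrier sync(num_threads);

  auto body = [&](int tid) {
    auto step = [&](const std::vector<float>& in, std::vector<float>& out) {
      // Each thread owns a contiguous block of output rows.
      const int begin = tid * n / num_threads;
      const int end = (tid + 1) * n / num_threads;
      for (int r = begin; r < end; ++r) {
        float sum = bias[r];
        for (int c = 0; c < n; ++c) sum += a[r * n + c] * in[c];
        out[r] = sum;
      }
      sync.arrive_and_wait();  // publish this step before anyone reads it
    };
    step(x, y);  // y = Ax + b
    step(y, x);  // x = Ay + b
    step(x, y);  // y = Ax + b
  };

  std::vector<std::thread> workers;
  for (int tid = 0; tid < num_threads; ++tid) workers.emplace_back(body, tid);
  for (auto& t : workers) t.join();
  return 0;
}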
sparse_matmul/layers/errno_mapping.cc
ADDED
@@ -0,0 +1,195 @@
1 |
+
// Copyright 2021 Google LLC
|
2 |
+
//
|
3 |
+
// Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
// you may not use this file except in compliance with the License.
|
5 |
+
// You may obtain a copy of the License at
|
6 |
+
//
|
7 |
+
// http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
//
|
9 |
+
// Unless required by applicable law or agreed to in writing, software
|
10 |
+
// distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
// See the License for the specific language governing permissions and
|
13 |
+
// limitations under the License.
|
14 |
+
|
15 |
+
#include "sparse_matmul/layers/errno_mapping.h"
|
16 |
+
|
17 |
+
#include <string>
|
18 |
+
|
19 |
+
#include "absl/strings/str_cat.h"
|
20 |
+
|
21 |
+
namespace csrblocksparse {
|
22 |
+
|
23 |
+
namespace {
|
24 |
+
|
25 |
+
absl::StatusCode ErrnoToCode(int error_number) {
|
26 |
+
switch (error_number) {
|
27 |
+
case 0:
|
28 |
+
return absl::StatusCode::kOk;
|
29 |
+
case EINVAL: // Invalid argument
|
30 |
+
case ENAMETOOLONG: // Filename too long
|
31 |
+
case E2BIG: // Argument list too long
|
32 |
+
case EDESTADDRREQ: // Destination address required
|
33 |
+
case EDOM: // Mathematics argument out of domain of function
|
34 |
+
case EFAULT: // Bad address
|
35 |
+
case EILSEQ: // Illegal byte sequence
|
36 |
+
case ENOPROTOOPT: // Protocol not available
|
37 |
+
case ENOSTR: // Not a STREAM
|
38 |
+
case ENOTSOCK: // Not a socket
|
39 |
+
case ENOTTY: // Inappropriate I/O control operation
|
40 |
+
case EPROTOTYPE: // Protocol wrong type for socket
|
41 |
+
case ESPIPE: // Invalid seek
|
42 |
+
return absl::StatusCode::kInvalidArgument;
|
43 |
+
case ETIMEDOUT: // Connection timed out
|
44 |
+
case ETIME: // Timer expired
|
45 |
+
return absl::StatusCode::kDeadlineExceeded;
|
46 |
+
case ENODEV: // No such device
|
47 |
+
case ENOENT: // No such file or directory
|
48 |
+
#ifdef ENOMEDIUM
|
49 |
+
case ENOMEDIUM: // No medium found
|
50 |
+
#endif
|
51 |
+
case ENXIO: // No such device or address
|
52 |
+
case ESRCH: // No such process
|
53 |
+
return absl::StatusCode::kNotFound;
|
54 |
+
case EEXIST: // File exists
|
55 |
+
case EADDRNOTAVAIL: // Address not available
|
56 |
+
case EALREADY: // Connection already in progress
|
57 |
+
#ifdef ENOTUNIQ
|
58 |
+
case ENOTUNIQ: // Name not unique on network
|
59 |
+
#endif
|
60 |
+
return absl::StatusCode::kAlreadyExists;
|
61 |
+
case EPERM: // Operation not permitted
|
62 |
+
case EACCES: // Permission denied
|
63 |
+
#ifdef ENOKEY
|
64 |
+
case ENOKEY: // Required key not available
|
65 |
+
#endif
|
66 |
+
case EROFS: // Read only file system
|
67 |
+
return absl::StatusCode::kPermissionDenied;
|
68 |
+
case ENOTEMPTY: // Directory not empty
|
69 |
+
case EISDIR: // Is a directory
|
70 |
+
case ENOTDIR: // Not a directory
|
71 |
+
case EADDRINUSE: // Address already in use
|
72 |
+
case EBADF: // Invalid file descriptor
|
73 |
+
#ifdef EBADFD
|
74 |
+
case EBADFD: // File descriptor in bad state
|
75 |
+
#endif
|
76 |
+
case EBUSY: // Device or resource busy
|
77 |
+
case ECHILD: // No child processes
|
78 |
+
case EISCONN: // Socket is connected
|
79 |
+
#ifdef EISNAM
|
80 |
+
case EISNAM: // Is a named type file
|
81 |
+
#endif
|
82 |
+
#ifdef ENOTBLK
|
83 |
+
case ENOTBLK: // Block device required
|
84 |
+
#endif
|
85 |
+
case ENOTCONN: // The socket is not connected
|
86 |
+
case EPIPE: // Broken pipe
|
87 |
+
#ifdef ESHUTDOWN
|
88 |
+
case ESHUTDOWN: // Cannot send after transport endpoint shutdown
|
89 |
+
#endif
|
90 |
+
case ETXTBSY: // Text file busy
|
91 |
+
#ifdef EUNATCH
|
92 |
+
case EUNATCH: // Protocol driver not attached
|
93 |
+
#endif
|
94 |
+
return absl::StatusCode::kFailedPrecondition;
|
95 |
+
case ENOSPC: // No space left on device
|
96 |
+
#ifdef EDQUOT
|
97 |
+
case EDQUOT: // Disk quota exceeded
|
98 |
+
#endif
|
99 |
+
case EMFILE: // Too many open files
|
100 |
+
case EMLINK: // Too many links
|
101 |
+
case ENFILE: // Too many open files in system
|
102 |
+
case ENOBUFS: // No buffer space available
|
103 |
+
case ENODATA: // No message is available on the STREAM read queue
|
104 |
+
case ENOMEM: // Not enough space
|
105 |
+
case ENOSR: // No STREAM resources
|
106 |
+
#ifdef EUSERS
|
107 |
+
case EUSERS: // Too many users
|
108 |
+
#endif
|
109 |
+
return absl::StatusCode::kResourceExhausted;
|
110 |
+
#ifdef ECHRNG
|
111 |
+
case ECHRNG: // Channel number out of range
|
112 |
+
#endif
|
113 |
+
case EFBIG: // File too large
|
114 |
+
case EOVERFLOW: // Value too large to be stored in data type
|
115 |
+
case ERANGE: // Result too large
|
116 |
+
return absl::StatusCode::kOutOfRange;
|
117 |
+
#ifdef ENOPKG
|
118 |
+
case ENOPKG: // Package not installed
|
119 |
+
#endif
|
120 |
+
case ENOSYS: // Function not implemented
|
121 |
+
case ENOTSUP: // Operation not supported
|
122 |
+
case EAFNOSUPPORT: // Address family not supported
|
123 |
+
#ifdef EPFNOSUPPORT
|
124 |
+
case EPFNOSUPPORT: // Protocol family not supported
|
125 |
+
#endif
|
126 |
+
case EPROTONOSUPPORT: // Protocol not supported
|
127 |
+
#ifdef ESOCKTNOSUPPORT
|
128 |
+
case ESOCKTNOSUPPORT: // Socket type not supported
|
129 |
+
#endif
|
130 |
+
case EXDEV: // Improper link
|
131 |
+
return absl::StatusCode::kUnimplemented;
|
132 |
+
case EAGAIN: // Resource temporarily unavailable
|
133 |
+
#ifdef ECOMM
|
134 |
+
case ECOMM: // Communication error on send
|
135 |
+
#endif
|
136 |
+
case ECONNREFUSED: // Connection refused
|
137 |
+
case ECONNABORTED: // Connection aborted
|
138 |
+
case ECONNRESET: // Connection reset
|
139 |
+
case EINTR: // Interrupted function call
|
140 |
+
#ifdef EHOSTDOWN
|
141 |
+
case EHOSTDOWN: // Host is down
|
142 |
+
#endif
|
143 |
+
case EHOSTUNREACH: // Host is unreachable
|
144 |
+
case ENETDOWN: // Network is down
|
145 |
+
case ENETRESET: // Connection aborted by network
|
146 |
+
case ENETUNREACH: // Network unreachable
|
147 |
+
case ENOLCK: // No locks available
|
148 |
+
case ENOLINK: // Link has been severed
|
149 |
+
#ifdef ENONET
|
150 |
+
case ENONET: // Machine is not on the network
|
151 |
+
#endif
|
152 |
+
return absl::StatusCode::kUnavailable;
|
153 |
+
case EDEADLK: // Resource deadlock avoided
|
154 |
+
#ifdef ESTALE
|
155 |
+
case ESTALE: // Stale file handle
|
156 |
+
#endif
|
157 |
+
return absl::StatusCode::kAborted;
|
158 |
+
case ECANCELED: // Operation cancelled
|
159 |
+
return absl::StatusCode::kCancelled;
|
160 |
+
default:
|
161 |
+
return absl::StatusCode::kUnknown;
|
162 |
+
}
|
163 |
+
}
|
164 |
+
|
165 |
+
// POSIX `strerror_r()` returns `int`.
|
166 |
+
ABSL_ATTRIBUTE_UNUSED std::string StrErrorResult(int result, const char* buffer,
|
167 |
+
int error_code) {
|
168 |
+
if (ABSL_PREDICT_FALSE(result != 0)) {
|
169 |
+
return absl::StrCat("Unknown error ", error_code);
|
170 |
+
}
|
171 |
+
return buffer;
|
172 |
+
}
|
173 |
+
|
174 |
+
// GNU `strerror_r()` returns `char*`.
|
175 |
+
ABSL_ATTRIBUTE_UNUSED std::string StrErrorResult(char* result,
|
176 |
+
const char* buffer,
|
177 |
+
int error_code) {
|
178 |
+
return result;
|
179 |
+
}
|
180 |
+
|
181 |
+
std::string StrError(int error_code) {
|
182 |
+
char message[256];
|
183 |
+
return StrErrorResult(strerror_r(error_code, message, sizeof(message)),
|
184 |
+
message, error_code);
|
185 |
+
}
|
186 |
+
|
187 |
+
} // namespace
|
188 |
+
|
189 |
+
absl::Status ErrnoToCanonicalStatus(int error_number,
|
190 |
+
absl::string_view message) {
|
191 |
+
return absl::Status(ErrnoToCode(error_number),
|
192 |
+
absl::StrCat(message, ": ", StrError(error_number)));
|
193 |
+
}
|
194 |
+
|
195 |
+
} // namespace csrblocksparse
|
sparse_matmul/layers/errno_mapping.h
ADDED
@@ -0,0 +1,29 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_
#define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_

#include "absl/status/status.h"
#include "absl/strings/string_view.h"

namespace csrblocksparse {

// Converts |error_number| value to absl::Status.
absl::Status ErrnoToCanonicalStatus(int error_number,
                                    absl::string_view message);

}  // namespace csrblocksparse

#endif  // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_ERRNO_MAPPING_H_
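A short usage sketch for the declaration above. The OpenForRead wrapper and the file name are hypothetical; only ErrnoToCanonicalStatus comes from this commit.

#include <cerrno>
#include <cstdio>

#include "absl/status/status.h"
#include "sparse_matmul/layers/errno_mapping.h"

absl::Status OpenForRead(const char* filename, std::FILE** out) {
  std::FILE* f = std::fopen(filename, "rb");
  if (f == nullptr) {
    // The errno left by fopen is mapped to the matching canonical code,
    // e.g. ENOENT -> kNotFound, EACCES -> kPermissionDenied.
    return csrblocksparse::ErrnoToCanonicalStatus(errno, "fopen failed");
  }
  *out = f;
  return absl::OkStatus();
}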
sparse_matmul/layers/masked_sparse_matrix.h
ADDED
@@ -0,0 +1,206 @@
1 |
+
/*
|
2 |
+
* Copyright 2021 Google LLC
|
3 |
+
*
|
4 |
+
* Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
* you may not use this file except in compliance with the License.
|
6 |
+
* You may obtain a copy of the License at
|
7 |
+
*
|
8 |
+
* http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
*
|
10 |
+
* Unless required by applicable law or agreed to in writing, software
|
11 |
+
* distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
* See the License for the specific language governing permissions and
|
14 |
+
* limitations under the License.
|
15 |
+
*/
|
16 |
+
|
17 |
+
#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
|
18 |
+
#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
|
19 |
+
|
20 |
+
#include <algorithm>
|
21 |
+
#include <cstdio>
|
22 |
+
#include <numeric>
|
23 |
+
#include <vector>
|
24 |
+
|
25 |
+
#include "absl/strings/str_format.h"
|
26 |
+
#include "sparse_matmul/vector/cache_aligned_vector.h"
|
27 |
+
|
28 |
+
namespace csrblocksparse {
|
29 |
+
|
30 |
+
// MaskedSparseMatrix serves two purposes:
|
31 |
+
// 1) It is useful as a reference implementation of SpMV for correctness
|
32 |
+
// checking the much more complicated implementations in CSRBlockSparseMatrix
|
33 |
+
// 2) This is the format in which sparse matrices are represented after pruning
|
34 |
+
// in TF. This class provides a bridge to getting these parameters into
|
35 |
+
// a compressed form suitable for computation and serialization.
|
36 |
+
//
|
37 |
+
// MaskedSparseMatrix<float> matrix(rows, cols, mask_from_tf, values_from_tf);
|
38 |
+
// CSRBlockSparseMatrix<float, bfloat16, int16_t> csr_matrix(matrix);
|
39 |
+
// csr_matrix.Multiply(rhs, bias, &out);
|
40 |
+
template <typename T>
|
41 |
+
class MaskedSparseMatrix {
|
42 |
+
public:
|
43 |
+
MaskedSparseMatrix() {}
|
44 |
+
|
45 |
+
// Construct a MaskedSparseMatrix of the given size, sparsity and block size.
|
46 |
+
// This is mainly useful for testing.
|
47 |
+
MaskedSparseMatrix(int rows, int cols, float sparsity, int block_height = 1,
|
48 |
+
int block_width = 1, float constant = 1.f,
|
49 |
+
bool random = true)
|
50 |
+
: rows_(rows), cols_(cols), sparsity_(sparsity) {
|
51 |
+
CHECK_EQ(rows % block_height, 0);
|
52 |
+
CHECK_EQ(cols % block_width, 0);
|
53 |
+
|
54 |
+
init(sparsity, block_height, block_width, constant, random);
|
55 |
+
}
|
56 |
+
|
57 |
+
// Construct from an existing mask and values (most likely from a TF model).
|
58 |
+
template <typename MaskType>
|
59 |
+
MaskedSparseMatrix(int rows, int cols, const MaskType* mask, const T* values)
|
60 |
+
: rows_(rows), cols_(cols) {
|
61 |
+
mask_.resize(rows * cols);
|
62 |
+
values_.resize(rows * cols);
|
63 |
+
std::copy_n(mask, rows * cols, mask_.begin());
|
64 |
+
std::copy_n(values, rows * cols, values_.begin());
|
65 |
+
sparsity_ =
|
66 |
+
1.f - std::accumulate(mask_.begin(), mask_.end(), 0.f) / mask_.size();
|
67 |
+
}
|
68 |
+
|
69 |
+
const std::vector<int>& mask() const { return mask_; }
|
70 |
+
const std::vector<T>& values() const { return values_; }
|
71 |
+
T* data() { return values_.data(); }
|
72 |
+
const T* data() const { return values_.data(); }
|
73 |
+
|
74 |
+
int rows() const { return rows_; }
|
75 |
+
int cols() const { return cols_; }
|
76 |
+
float sparsity() const { return sparsity_; }
|
77 |
+
|
78 |
+
void Print() const {
|
79 |
+
absl::PrintF("-------Values---------\n");
|
80 |
+
for (int r = 0; r < rows_; ++r) {
|
81 |
+
for (int c = 0; c < cols_; ++c) {
|
82 |
+
absl::PrintF("%+6.3f ", static_cast<float>(values_[r * cols_ + c]));
|
83 |
+
}
|
84 |
+
absl::PrintF("\n");
|
85 |
+
}
|
86 |
+
absl::PrintF("-------Mask---------\n");
|
87 |
+
for (int r = 0; r < rows_; ++r) {
|
88 |
+
for (int c = 0; c < cols_; ++c) {
|
89 |
+
printf("%2d ", mask_[r * cols_ + c]);
|
90 |
+
}
|
91 |
+
absl::PrintF("\n");
|
92 |
+
}
|
93 |
+
}
|
94 |
+
|
95 |
+
// This routine is useful for rounding the possibly higher precision values
|
96 |
+
// stored in this class to a lower precision, so that correctness checks
|
97 |
+
// between this class and CSRBlockSparseMatrix can have a tighter tolerance.
|
98 |
+
template <typename U>
|
99 |
+
void CastWeights() {
|
100 |
+
for (int i = 0; i < values_.size(); ++i) {
|
101 |
+
values_[i] = static_cast<T>(U(values_[i]));
|
102 |
+
}
|
103 |
+
}
|
104 |
+
|
105 |
+
// Only meant for correctness checking.
|
106 |
+
// RhsClassType is meant to be either CacheAlignedVector OR
|
107 |
+
// FatCacheAlignedVector.
|
108 |
+
// The weight matrix is ROW MAJOR and RhsClassType is COLUMN MAJOR.
|
109 |
+
// |bias| is broadcast if |rhs| has more than one column.
|
110 |
+
template <typename RhsClassType, typename BiasType, typename OutClassType,
|
111 |
+
typename RhsType = typename RhsClassType::value_type,
|
112 |
+
typename OutType = typename OutClassType::value_type>
|
113 |
+
void SpMM_bias(const RhsClassType& rhs,
|
114 |
+
const CacheAlignedVector<BiasType>& bias, OutClassType* out,
|
115 |
+
bool relu = false) {
|
116 |
+
for (int r = 0; r < rows_; ++r) {
|
117 |
+
for (int n = 0; n < rhs.cols(); ++n) {
|
118 |
+
float sum = 0.f;
|
119 |
+
const RhsType* rhs_ptr = rhs.data() + n * rhs.rows();
|
120 |
+
OutType* out_ptr = out->data() + n * out->rows();
|
121 |
+
const int* mask_ptr = mask_.data() + r * cols_;
|
122 |
+
const T* value_ptr = values_.data() + r * cols_;
|
123 |
+
for (int c = 0; c < cols_; ++c) {
|
124 |
+
sum += mask_ptr[c] * static_cast<float>(value_ptr[c]) *
|
125 |
+
static_cast<float>(rhs_ptr[c]);
|
126 |
+
}
|
127 |
+
out_ptr[r] = static_cast<OutType>(
|
128 |
+
relu ? std::max(sum + static_cast<float>(bias[r]), 0.f)
|
129 |
+
: sum + static_cast<float>(bias[r]));
|
130 |
+
}
|
131 |
+
}
|
132 |
+
}
|
133 |
+
|
134 |
+
private:
|
135 |
+
// Generate a random matrix with the specified sparsity.
|
136 |
+
// Useful for testing.
|
137 |
+
void init(float sparsity, int block_height, int block_width, float constant,
|
138 |
+
bool random = true) {
|
139 |
+
int reduced_rows = rows_ / block_height;
|
140 |
+
int reduced_cols = cols_ / block_width;
|
141 |
+
mask_.resize(rows_ * cols_, 0);
|
142 |
+
|
143 |
+
// Fill with non-zero value to make sure masking works.
|
144 |
+
values_.resize(rows_ * cols_, static_cast<T>(2.f));
|
145 |
+
|
146 |
+
std::mt19937 generator(0);
|
147 |
+
std::uniform_real_distribution<float> dist_sparsity;
|
148 |
+
std::uniform_real_distribution<float> dist_value(-1.f, 1.f);
|
149 |
+
int nnz = 0;
|
150 |
+
while (nnz == 0) {
|
151 |
+
for (int r = 0; r < reduced_rows; ++r) {
|
152 |
+
for (int c = 0; c < reduced_cols; ++c) {
|
153 |
+
if (dist_sparsity(generator) > sparsity) {
|
154 |
+
nnz++;
|
155 |
+
for (int i = 0; i < block_height; ++i) {
|
156 |
+
for (int j = 0; j < block_width; ++j) {
|
157 |
+
mask_[(r * block_height + i) * cols_ + block_width * c + j] = 1;
|
158 |
+
values_[(r * block_height + i) * cols_ + block_width * c + j] =
|
159 |
+
static_cast<T>(random ? dist_value(generator) : constant);
|
160 |
+
}
|
161 |
+
}
|
162 |
+
}
|
163 |
+
}
|
164 |
+
}
|
165 |
+
}
|
166 |
+
}
|
167 |
+
|
168 |
+
std::vector<int> mask_;
|
169 |
+
std::vector<T> values_;
|
170 |
+
int rows_;
|
171 |
+
int cols_;
|
172 |
+
float sparsity_;
|
173 |
+
};
|
174 |
+
|
175 |
+
template <typename T>
|
176 |
+
class MaskedLinearLayer {
|
177 |
+
public:
|
178 |
+
MaskedLinearLayer(MaskedSparseMatrix<T>&& weights,
|
179 |
+
CacheAlignedVector<T>&& bias)
|
180 |
+
: weights_(std::move(weights)), bias_(std::move(bias)) {}
|
181 |
+
|
182 |
+
MaskedLinearLayer() {}
|
183 |
+
|
184 |
+
template <typename U>
|
185 |
+
void CastWeights() {
|
186 |
+
weights_.template CastWeights<U>();
|
187 |
+
}
|
188 |
+
|
189 |
+
// Does Ax + b where A is a masked sparse ROW MAJOR matrix and
|
190 |
+
// x is a COLUMN MAJOR dense vector or matrix. Bias is a vector that is
|
191 |
+
// broadcast if rhs has more than one column.
|
192 |
+
template <typename FatVector>
|
193 |
+
void SpMM_bias(const FatVector& rhs, FatVector* out, bool relu = false) {
|
194 |
+
static_assert(std::is_same<typename FatVector::value_type, T>::value,
|
195 |
+
"FatVector value_type must match masked_linear_layer type");
|
196 |
+
weights_.SpMM_bias(rhs, bias_, out, relu);
|
197 |
+
}
|
198 |
+
|
199 |
+
private:
|
200 |
+
MaskedSparseMatrix<T> weights_;
|
201 |
+
CacheAlignedVector<T> bias_;
|
202 |
+
};
|
203 |
+
|
204 |
+
} // namespace csrblocksparse
|
205 |
+
|
206 |
+
#endif // LYRA_CODEC_SPARSE_MATMUL_LAYERS_MASKED_SPARSE_MATRIX_H_
|
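A compact sketch of the reference path described in the header's own usage comment: build a mask-plus-values matrix (as exported from a pruned TF model) and run the reference SpMM_bias. The sizes and fill values below are made up for illustration.

#include <vector>

#include "sparse_matmul/layers/masked_sparse_matrix.h"
#include "sparse_matmul/vector/cache_aligned_vector.h"

void ReferenceMultiply() {
  const int kRows = 4, kCols = 4;
  std::vector<int> mask(kRows * kCols, 1);        // fully dense mask
  std::vector<float> values(kRows * kCols, 0.5f);
  csrblocksparse::MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(),
                                                   values.data());

  csrblocksparse::CacheAlignedVector<float> rhs(kCols);
  csrblocksparse::CacheAlignedVector<float> bias(kRows);
  csrblocksparse::CacheAlignedVector<float> out(kRows);
  rhs.FillOnes();
  bias.FillZero();
  out.FillZero();

  // out = (mask .* values) * rhs + bias, computed with the reference loop.
  matrix.SpMM_bias(rhs, bias, &out, /*relu=*/false);
}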
sparse_matmul/layers/read_array_ifstream.h
ADDED
@@ -0,0 +1,66 @@
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// Low-level array reading function using std::ifstream.

#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_
#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_

#include <cstdint>
#include <fstream>
#include <sstream>
#include <string>

#include "absl/status/status.h"
#include "absl/strings/substitute.h"
#include "include/ghc/filesystem.hpp"

namespace csrblocksparse {
namespace detail {

template <typename T>
absl::Status ReadArrayIfstream(const std::string& file_name,
                               const std::string& path, std::vector<T>* array,
                               int64_t* length) {
  ghc::filesystem::path complete_path(path);
  complete_path /= file_name;
  std::ifstream in_stream(complete_path.u8string(), std::ios::binary);
  if (!in_stream.is_open()) {
    return absl::UnknownError(
        absl::Substitute("Error opening $0", complete_path.string()));
  }

  std::stringstream buffer;
  buffer << in_stream.rdbuf();
  if (buffer.str().empty()) {
    LOG(ERROR) << "File " << complete_path << " was empty.";
    return absl::UnknownError(
        absl::Substitute("File $0 was empty", complete_path.string()));
  }
  std::string contents = buffer.str();
  *length = contents.length();
  int64_t elem = (*length + sizeof(T) - 1) / sizeof(T);
  array->resize(elem);
  std::move(contents.begin(), contents.end(),
            reinterpret_cast<char*>(array->data()));

  return absl::OkStatus();
}

}  // namespace detail
}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_LAYERS_READ_ARRAY_IFSTREAM_H_
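A minimal caller sketch for ReadArrayIfstream. The file and directory names are placeholders for illustration, not paths shipped with the commit.

#include <cstdint>
#include <string>
#include <vector>

#include "absl/status/status.h"
#include "sparse_matmul/layers/read_array_ifstream.h"

// Reads a raw little-endian float blob from |path|/|file_name| into |out|.
absl::Status LoadRawFloats(std::vector<float>* out) {
  int64_t byte_length = 0;
  return csrblocksparse::detail::ReadArrayIfstream<float>(
      "lyra_conv1d_weights.raw", "/tmp/lyra_model", out, &byte_length);
}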
sparse_matmul/layers/sparse_linear_layer.h
ADDED
@@ -0,0 +1,365 @@
/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_
#define LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_

#include <cstdint>

#include "absl/memory/memory.h"
#include "glog/logging.h"
#include "sparse_matmul/layers/csr_blocksparse_matrix.h"
#include "sparse_matmul/layers/masked_sparse_matrix.h"
#include "sparse_matmul/numerics/type_utils.h"
#include "sparse_matmul/os/coop_threads.h"
#include "sparse_matmul/vector/cache_aligned_vector.h"

namespace csrblocksparse {

template <typename WeightType, typename RhsType,
          typename BiasType = typename TypeOfProduct<WeightType, RhsType>::type,
          typename DeltaType = int16_t>
class SparseLinearLayer {
 public:
  SparseLinearLayer() {}

  SparseLinearLayer(CsrBlockSparseMatrix<WeightType, RhsType>&& sparse_matrix,
                    CacheAlignedVector<BiasType>&& bias)
      : sparse_matrix_(std::move(sparse_matrix)), full_bias_(std::move(bias)) {
    CHECK_EQ(sparse_matrix_.rows(), full_bias_.size());
    // Some kernels expect that the bias is divided by 4, so we store a second
    // copy of a quarter of the bias.
    // TODO(b/189958858): Remove the quartered bias if it can be done without
    // loss of speed, and rename the |full_bias_| member back to |bias_|.
    bias_ = full_bias_;
    for (int i = 0; i < bias_.size(); ++i) {
      bias_[i] = static_cast<BiasType>(.25f * static_cast<float>(bias_[i]));
    }
  }
  SparseLinearLayer(
      const SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>& src) {
    *this = src;
  }
  SparseLinearLayer& operator=(
      const SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>& src) {
    sparse_matrix_ = src.sparse_matrix_;
    bias_ = src.bias_;
    full_bias_ = src.full_bias_;
    mid_output_ = src.mid_output_;
    thread_layers_ = src.thread_layers_;
    num_threads_ = src.num_threads_;
    if (src.split_pc_) {
      split_pc_ = absl::make_unique<ProducerConsumer>(
          src.split_pc_->num_producers(), src.split_pc_->num_consumers());
    }
    return *this;
  }

  // Does Ax + b where A is a block sparse compressed sparse row matrix and
  // x is a COLUMN MAJOR dense vector or matrix. Bias is a vector that is
  // broadcast if rhs has more than one column.
  template <typename RhsClassType, typename OutType>
  void SpMM_bias(const RhsClassType& rhs, OutType* out, bool relu = false,
                 int tid = 0, SpinBarrier* barrier = nullptr) const {
    static_assert(
        std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
    sparse_matrix_.SpMM_bias(rhs, bias_, out, relu, tid, barrier);
  }
  // Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above,
  // and then samples from the output (softmax distribution) layer.
  template <typename RhsClassType, typename OutType>
  int SpMM_bias_Sample(const RhsClassType& rhs, OutType* out, float temperature,
                       int tid, SpinBarrier* barrier, std::minstd_rand* gen,
                       CacheAlignedVector<float>* scratch) const {
    static_assert(
        std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
    return sparse_matrix_.SpMM_bias_Sample(rhs, bias_, out, temperature, tid,
                                           barrier, gen, scratch);
  }
  template <typename RhsClassType, typename OutType>
  void MatVec(const RhsClassType& rhs, bool relu, int tid, int replicas,
              int output_stride, OutType* output,
              SpinBarrier* barrier = nullptr) {
    static_assert(
        std::is_same<typename RhsClassType::value_type, RhsType>::value, "");
#ifdef __AVX2__
    if (block_width() == 4 && (block_height() == 4 || block_height() == 8) &&
        !IsCustomFloatType<WeightType>::value) {
      if (!IsSplit()) {
        sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu,
                              tid, replicas, output_stride, output->data());
        if (barrier != nullptr) barrier->barrier();
        return;
      }
      // NOTE: Until the quartered bias is removed it is a bad idea to split
      // for ARM in the same way, as we would have to quarter the output of
      // the first part of the split before running the second part.
      // Signal completion of the previous MatVec.
      split_pc_->produce();
      PartLinearLayer& thread_part = thread_layers_[tid];
      auto offset_output =
          sparse_matrix_.thread_bounds().OffsetOutput(output->data(), tid);
      auto mid_output =
          sparse_matrix_.thread_bounds().OffsetOutput(mid_output_.data(), tid);
      auto offset_bias = sparse_matrix_.thread_bounds().OffsetOutput(
          mid_output_.cast_data(), tid);
      // We can continue to consume the data that this thread produced and
      // compute just the |self_matrix| part.
      // No |relu| or |replicas|, as this is only a partial matmul.
      // |tid| is always zero because the matrix has been split by tid.
      thread_part.self_matrix.MatVec(
          rhs.cast_data(), thread_part.full_bias.cast_data(), /*relu=*/false,
          /*tid=*/0, /*replicas=*/1, output_stride, mid_output);
      // We have to wait for the other threads to finish working on the previous
      // MatMul before consuming the rest of |rhs|.
      split_pc_->consume();
      thread_part.other_matrix.MatVec(rhs.cast_data(), offset_bias, relu,
                                      /*tid=*/0, replicas, output_stride,
                                      offset_output);
      return;
    }
#endif
    DCHECK_EQ(replicas, 1) << "Must have single replica for SpMM API";
    if (IsSplit()) {
      // Generics aren't setup to use a split matrix. This will be inefficient.
      split_pc_->produce();
      split_pc_->consume();
    }
    if (block_height() == 8) {
      // We are currently forced to use MatVec generics for this case.
      LOG(WARNING) << "Need to implement MatVec for 8x4 for non-AVX2 targets!!";
      sparse_matrix_.MatVec(rhs.cast_data(), full_bias_.cast_data(), relu, tid,
                            replicas, output_stride, output->data());
      if (barrier != nullptr) barrier->barrier();
    } else {
      sparse_matrix_.SpMM_bias(rhs, bias_, output, relu, tid, barrier);
    }
  }

  int rows() const { return sparse_matrix_.rows(); }
  int cols() const { return sparse_matrix_.cols(); }
  float sparsity() const { return sparse_matrix_.sparsity(); }
  int block_width() const { return sparse_matrix_.block_width(); }
  int block_height() const { return sparse_matrix_.block_height(); }
  int num_threads() const { return sparse_matrix_.num_threads(); }
  const CacheAlignedVector<BiasType>& bias() const { return bias_; }
  const std::vector<int>& split_points() const {
    return sparse_matrix_.split_points();
  }
  bool IsSplit() const {
    return !thread_layers_.empty() && split_pc_ != nullptr;
  }

  std::size_t bytes() const { return sparse_matrix_.bytes() + bias_.bytes(); }
  void Print() const {
    printf("Matrix\n");
    sparse_matrix_.Print();
    printf("Bias\n");
    bias_.Print();
  }

  // Combines adjacent row blocks, doubling the block height.
  // This necessarily involves adding zero weights where the blocks don't align
  // across adjacent pairs of rows, so use with caution, as the resulting matrix
  // is most likely to run slower if very sparse to begin with.
  // In the few cases where the blocks do mostly align, the resulting matmul
  // could be much faster, as the number of reads of the rhs will be halved.
  void DoubleBlockHeight() { sparse_matrix_.DoubleBlockHeight(); }

  // Cache_line_size is provided only for testing. Normally uses a value for
  // the current architecture.
  int PrepareForThreads(int num_threads, int cache_line_size = -1) {
    num_threads_ = num_threads;
    if (num_threads_ > 1) {
      split_pc_ =
          absl::make_unique<ProducerConsumer>(num_threads_, num_threads_);
    } else {
      split_pc_.reset(nullptr);
    }
    return sparse_matrix_.PrepareForThreads(num_threads, cache_line_size);
  }

  // Partitions the matrix into pieces by thread.
  // In this matrix, we can go ahead and calculate the part that only depends
  // on rhs inputs that were generated by this thread in the previous matvec,
  // without having to use any thread synchronization, and only after that do we
  // have to wait for the other threads to finish the previous matvec.
  // So we split the matrix using the |split_points| from the previous matrix
  // into 2 * |num_threads_| pieces: self and other for each thread, being the
  // parts that can be calculated before and after the other threads have
  // completed their calculation of the previous matvec.
  // We then have to use a ProducerConsumer lock instead of a SpinBarrier to
  // synchronize the data produced by the other threads.
  void SliceForThreads(const std::vector<int>& split_points) {
    thread_layers_.clear();
    thread_layers_.reserve(num_threads_);
    LOG(INFO) << "Slicing " << rows() << "x" << cols() << " matrix for "
              << num_threads_ << " threads";
    for (int tid = 0; tid < num_threads_; ++tid) {
      thread_layers_.emplace_back(
          sparse_matrix_, full_bias_, bias_, tid,
          split_points[tid] * sparse_matrix_.block_height(),
          split_points[tid + 1] * sparse_matrix_.block_height());
    }
    mid_output_ =
        std::move(csrblocksparse::CacheAlignedVector<BiasType>(rows()));
    mid_output_.FillZero();
  }

  // Splits the layer by inputs into 2 equal pieces. Each of the resulting
  // layers should be computed independently on the first and second halves of
  // the inputs respectively and the results added to achieve the same effect
  // as the original layer.
  void SplitInputs(
      SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part1,
      SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part2) {
    CsrBlockSparseMatrix<WeightType, RhsType> matrix1(
        sparse_matrix_.SplitByColumn(0, sparse_matrix_.cols() / 2));
    CsrBlockSparseMatrix<WeightType, RhsType> matrix2(
        sparse_matrix_.SplitByColumn(sparse_matrix_.cols() / 2,
                                     sparse_matrix_.cols()));
    *part1 =
        std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
            std::move(matrix1),
            std::move(CacheAlignedVector<BiasType>(full_bias_))));
    CacheAlignedVector<BiasType> bias2(sparse_matrix_.rows());
    bias2.FillZero();
    *part2 =
        std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
            std::move(matrix2), std::move(bias2)));
  }

  // Splits the layer by outputs into 2 equal pieces. Each of the resulting
  // layers should be computed independently on the full inputs and the results
  // concatenated to achieve the same effect as the original layer.
  void SplitOutputs(
      SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part1,
      SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>* part2) {
    LOG(INFO) << "input rows=" << sparse_matrix_.rows()
              << ", cols=" << sparse_matrix_.cols();
    CsrBlockSparseMatrix<WeightType, RhsType> matrix1(
        sparse_matrix_.SplitByRow(0, sparse_matrix_.rows() / 2));
    CsrBlockSparseMatrix<WeightType, RhsType> matrix2(sparse_matrix_.SplitByRow(
        sparse_matrix_.rows() / 2, sparse_matrix_.rows()));
    CacheAlignedVector<BiasType> bias1(full_bias_, 0, full_bias_.size() / 2);
    *part1 =
        std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
            std::move(matrix1), std::move(bias1)));
    CacheAlignedVector<BiasType> bias2(full_bias_, full_bias_.size() / 2,
                                       full_bias_.size());
    *part2 =
        std::move(SparseLinearLayer<WeightType, RhsType, BiasType, DeltaType>(
            std::move(matrix2), std::move(bias2)));
  }

 private:
  // Simple struct to hold a partitioned layer.
  struct PartLinearLayer {
    // The original matrix is first split by row to generate only the outputs
    // for the given tid. The |row_sub_matrix| is then split by column into two
    // partitions:
    // self is the part for which the rhs elements in [|start_col|, |end_col|)
    // were generated by this thread in some previous matmul.
    // |other| is the rest of the columns that require rhs elements from other
    // threads.
    // NOTE that |start_col|, |end_col| are in raw columns, not blocks.
    PartLinearLayer(const CsrBlockSparseMatrix<WeightType, RhsType>& matrix,
                    const CacheAlignedVector<BiasType>& bias,
                    const CacheAlignedVector<BiasType>& bias_4, int tid,
                    int start_col, int end_col) {
      int block_height = matrix.block_height();
      // Split the input matrix by row, selecting only the rows relevant to
      // thread tid.
      int start_row = matrix.split_points()[tid] * block_height;
      int end_row = matrix.split_points()[tid + 1] * block_height;
      LOG(INFO) << "input cols [" << start_col << "," << end_col << ") rows ["
                << start_row << "," << end_row << ")";
      CsrBlockSparseMatrix<WeightType, RhsType> row_sub_matrix =
          matrix.SplitByRow(start_row, end_row);
      // Partition into the columns that use rhs elements that thread tid
      // produced in a previous matmul, and the other rhs elements.
      // NOTE that we |keep_rhs_size|=true so that each matrix can operate on
      // the same rhs input vector. The self matrix just guarantees not to
      // access any of the elements that are generated by another thread.
      self_matrix = std::move(row_sub_matrix.SplitByColumn(
          start_col, end_col, /*keep_rhs_size=*/true));
      self_matrix.PrepareForThreads(1);
      // The reversed start and end slice out the complement of [start, end).
      other_matrix = std::move(row_sub_matrix.SplitByColumn(
          end_col, start_col, /*keep_rhs_size=*/true));
      other_matrix.PrepareForThreads(1);
      full_bias =
          std::move(CacheAlignedVector<BiasType>(bias, start_row, end_row));
      // TODO(b/189958858): Eliminate the quarter bias from all the code.
      quarter_bias =
          std::move(CacheAlignedVector<BiasType>(bias_4, start_row, end_row));
    }
    // The part of the matrix that only depends on this thread for rhs inputs.
    CsrBlockSparseMatrix<WeightType, RhsType> self_matrix;
    CacheAlignedVector<BiasType> full_bias;
    CacheAlignedVector<BiasType> quarter_bias;
    // The part of the matrix that uses rhs inputs from other threads.
    CsrBlockSparseMatrix<WeightType, RhsType> other_matrix;
  };
  CsrBlockSparseMatrix<WeightType, RhsType, DeltaType> sparse_matrix_;
  CacheAlignedVector<BiasType> bias_;
  CacheAlignedVector<BiasType> full_bias_;
  // Output from the self_matrix that will be given to |other_matrix| as bias.
  CacheAlignedVector<BiasType> mid_output_;
  // One partitioned pair of matrices for each thread.
  std::vector<PartLinearLayer> thread_layers_;
  // Producer-consumer lock used to wait between computing |self_matrix| and
  // |other_matrix| for the other threads to finish the *previous* matvec.
  std::unique_ptr<ProducerConsumer> split_pc_;
  int num_threads_ = 0;
};

template <typename WeightType, typename RhsType>
SparseLinearLayer<WeightType, RhsType> CreateRandomLayer(int rows, int cols,
                                                         float sparsity,
                                                         int block_height = 1,
                                                         int block_width = 1) {
  typedef typename TypeOfProduct<WeightType, RhsType>::type BiasType;
  CacheAlignedVector<BiasType> bias(rows);
  bias.FillRandom();

  auto masked_matrix = MaskedSparseMatrix<float>(rows, cols, sparsity,
                                                 block_height, block_width);
  auto sparse_matrix = CsrBlockSparseMatrix<WeightType, RhsType>(masked_matrix);

  return SparseLinearLayer<WeightType, RhsType>(std::move(sparse_matrix),
                                                std::move(bias));
}

template <typename WeightType, typename RhsType>
SparseLinearLayer<WeightType, RhsType> CreateConstantLayer(
    int rows, int cols, float sparsity, float constant = 1.f) {
  typedef typename TypeOfProduct<WeightType, RhsType>::type BiasType;
  CacheAlignedVector<BiasType> bias(rows);
  bias.FillOnes();

  MaskedSparseMatrix<float> masked_matrix(rows, cols, sparsity,
                                          /*block_height=*/1, /*block_width=*/1,
                                          constant, /*random=*/false);
  CsrBlockSparseMatrix<WeightType, RhsType> sparse_matrix(masked_matrix);

  return SparseLinearLayer<WeightType, RhsType>(std::move(sparse_matrix),
                                                std::move(bias));
}

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_LAYERS_SPARSE_LINEAR_LAYER_H_
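As a quick orientation for readers of this commit, here is a minimal single-threaded usage sketch of the SparseLinearLayer API above (illustrative only, not part of the added file; the function name RunRandomLayerOnce is made up). It builds a random layer with CreateRandomLayer and runs SpMM_bias once, much as the tests in sparse_linear_layer_test.cc below do.

// Hypothetical example driving the layer on one thread.
#include "sparse_matmul/layers/sparse_linear_layer.h"
#include "sparse_matmul/vector/cache_aligned_vector.h"

void RunRandomLayerOnce() {
  // 256x256 layer, 95% sparse, 4x4 blocks, float weights and activations.
  auto layer = csrblocksparse::CreateRandomLayer<float, float>(
      /*rows=*/256, /*cols=*/256, /*sparsity=*/0.95f,
      /*block_height=*/4, /*block_width=*/4);
  layer.PrepareForThreads(1);

  csrblocksparse::FatCacheAlignedVector<float> rhs(/*rows=*/256, /*cols=*/1);
  csrblocksparse::FatCacheAlignedVector<float> out(/*rows=*/256, /*cols=*/1);
  rhs.FillRandom();
  out.FillZero();

  // out = relu(A * rhs + bias), computed on the calling thread (tid 0).
  layer.SpMM_bias(rhs, &out, /*relu=*/true, /*tid=*/0);
}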
sparse_matmul/layers/sparse_linear_layer_test.cc
ADDED
@@ -0,0 +1,187 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "sparse_matmul/layers/sparse_linear_layer.h"

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "sparse_matmul/numerics/test_utils.h"

namespace csrblocksparse {
namespace {

constexpr int kBlockSize = 4;
constexpr int kSize = 256;
constexpr int kNumThreads = 4;
constexpr int kCols = 1;

void SlicedThreadBody(SpinBarrier* spin_barrier, int tid,
                      const FatCacheAlignedVector<float>& rhs,
                      SparseLinearLayer<float, float>* sparse_linear_layer,
                      FatCacheAlignedVector<float>* out, bool use_relu) {
  sparse_linear_layer->MatVec(rhs, use_relu, tid, /*replicas=*/1,
                              /*output_stride=*/0, out);
  spin_barrier->barrier();
}

// Tests that a Layer that has been SliceForThreads computes the same result as
// the original layer. This is a basic test that all the slicing didn't mess up
// any of the computations.
TEST(CsrBlockSparseMatrix, SliceForThreads) {
  MaskedSparseMatrix<float> matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize);
  FatCacheAlignedVector<float> rhs(kSize, kCols);
  CacheAlignedVector<float> bias(kSize);
  FatCacheAlignedVector<float> out1(kSize, kCols);

  bias.FillRandom();
  rhs.FillRandom();
  out1.FillZero();
  FatCacheAlignedVector<float> out_reference = out1;
  CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
  SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
                                                      std::move(bias));
  sparse_linear_layer.PrepareForThreads(1);
  sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
                             /*output_stride=*/0, &out_reference);
  std::vector<int> fake_split_points = {0, 48 / kBlockSize, 128 / kBlockSize,
                                        208 / kBlockSize, kSize / kBlockSize};
  sparse_linear_layer.PrepareForThreads(kNumThreads);
  sparse_linear_layer.SliceForThreads(fake_split_points);
  csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, SlicedThreadBody, rhs,
                                             &sparse_linear_layer, &out1,
                                             /*relu=*/true);

  CheckResult(out_reference, out1, kCols);
}

void LayersThreadBody(SpinBarrier* spin_barrier, int tid,
                      const FatCacheAlignedVector<float>& rhs,
                      SparseLinearLayer<float, float>* sparse_linear_layer1,
                      SparseLinearLayer<float, float>* sparse_linear_layer2,
                      FatCacheAlignedVector<float>* out1,
                      FatCacheAlignedVector<float>* out2, bool use_relu) {
  sparse_linear_layer1->MatVec(rhs, use_relu, tid, /*replicas=*/1,
                               /*output_stride=*/0, out1);
  // NOTE no barrier here!
  sparse_linear_layer2->MatVec(*out1, use_relu, tid, /*replicas=*/1,
                               /*output_stride=*/0, out2);
  spin_barrier->barrier();
}

// Tests that a pair of layers computes the same result whether or not the
// second layer has been SliceForThreads. This is a more critical test that
// the replacement of barriers with producer-consumer locks works.
// Must be run with tsan to really test it properly.
TEST(CsrBlockSparseMatrix, SliceForThreadsLayers) {
  MaskedSparseMatrix<float> matrix1(kSize, kSize, 0.95, kBlockSize, kBlockSize);
  FatCacheAlignedVector<float> rhs(kSize, kCols);
  CacheAlignedVector<float> bias1(kSize);
  FatCacheAlignedVector<float> out1(kSize, kCols);
  MaskedSparseMatrix<float> matrix2(kSize, kSize, 0.95, kBlockSize, kBlockSize);
  CacheAlignedVector<float> bias2(kSize);
  FatCacheAlignedVector<float> out2(kSize, kCols);

  bias1.FillRandom();
  rhs.FillRandom();
  bias2.FillRandom();
  out1.FillZero();
  out2.FillZero();
  FatCacheAlignedVector<float> out_reference = out2;
  CsrBlockSparseMatrix<float, float> sparse_matrix1(matrix1);
  SparseLinearLayer<float, float> layer1(std::move(sparse_matrix1),
                                         std::move(bias1));
  CsrBlockSparseMatrix<float, float> sparse_matrix2(matrix2);
  SparseLinearLayer<float, float> layer2(std::move(sparse_matrix2),
                                         std::move(bias2));
  layer1.PrepareForThreads(1);
  layer2.PrepareForThreads(1);
  layer1.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
                /*output_stride=*/0, &out1);
  layer2.MatVec(out1, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
                /*output_stride=*/0, &out_reference);
  layer1.PrepareForThreads(kNumThreads);
  layer2.PrepareForThreads(kNumThreads);
  layer2.SliceForThreads(layer1.split_points());
  csrblocksparse::LaunchOnThreadsWithBarrier(kNumThreads, LayersThreadBody, rhs,
                                             &layer1, &layer2, &out1, &out2,
                                             /*relu=*/true);

  CheckResult(out_reference, out2, kCols);
}

// Tests that a Layer that has been DoubleBlockHeight()-ed computes the same
// result as original layer. (Float compute type).
TEST(CsrBlockSparseMatrix, Float8x4) {
  using ComputeType = float;
  using RhsType = float;
  using BiasType = float;
  MaskedSparseMatrix<float> matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize);
  matrix.CastWeights<ComputeType>();
  FatCacheAlignedVector<RhsType> rhs(kSize, kCols);
  CacheAlignedVector<BiasType> bias(kSize);
  FatCacheAlignedVector<BiasType> out1(kSize, kCols);

  bias.FillRandom();
  rhs.FillRandom();
  out1.FillZero();
  FatCacheAlignedVector<BiasType> out_reference = out1;
  CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
  SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
      std::move(sparse_matrix), std::move(bias));
  sparse_linear_layer.PrepareForThreads(1);
  sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
                             /*output_stride=*/0, &out_reference);
  sparse_linear_layer.DoubleBlockHeight();
  sparse_linear_layer.PrepareForThreads(1);
  sparse_linear_layer.MatVec(rhs, /*relu=*/true, /*tid=*/0, /*replicas=*/1,
                             /*output_stride=*/0, &out1);
  CheckResult(out_reference, out1, kCols);
}

// Tests that a Layer that has been DoubleBlockHeight()-ed computes the same
// result as original layer. (Fixed16 compute type).
TEST(CsrBlockSparseMatrix, Fixed8x4) {
  using ComputeType = csrblocksparse::fixed16<4>;
  using RhsType = csrblocksparse::fixed16<4>;
  using BiasType = typename TypeOfProduct<ComputeType, RhsType>::type;
  MaskedSparseMatrix<float> matrix(kSize, kSize, 0.95, kBlockSize, kBlockSize);
  matrix.CastWeights<ComputeType>();
  FatCacheAlignedVector<RhsType> rhs(kSize, kCols);
  CacheAlignedVector<BiasType> bias(kSize);
  FatCacheAlignedVector<BiasType> out1(kSize, kCols);

  bias.FillRandom();
  rhs.FillRandom();
  out1.FillZero();
  FatCacheAlignedVector<BiasType> out_reference = out1;
  CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
  SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
      std::move(sparse_matrix), std::move(bias));
  sparse_linear_layer.PrepareForThreads(1);
  sparse_linear_layer.MatVec(rhs, /*relu=*/false, /*tid=*/0, /*replicas=*/1,
                             /*output_stride=*/0, &out_reference);
  sparse_linear_layer.DoubleBlockHeight();
  sparse_linear_layer.PrepareForThreads(1);
  sparse_linear_layer.MatVec(rhs, /*relu=*/false, /*tid=*/0, /*replicas=*/1,
                             /*output_stride=*/0, &out1);
  CheckResult(out_reference, out1, kCols);
}

TEST(SparseLinearLayerTest, PrintCompiles) {
  SparseLinearLayer<float, float> sparse_linear_layer;
  sparse_linear_layer.Print();
}

}  // namespace
}  // namespace csrblocksparse
sparse_matmul/layers/status_macros.h
ADDED
@@ -0,0 +1,34 @@
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_
#define THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_

#include "absl/status/status.h"
#include "absl/status/statusor.h"

#define SPARSE_MATMUL_RETURN_IF_ERROR(expr) \
  do {                                      \
    const absl::Status _status = (expr);    \
    if (!_status.ok()) return _status;      \
  } while (0)
template <typename T>
absl::Status DoAssignOrReturn(T& lhs, absl::StatusOr<T> result) {
  if (result.ok()) {
    lhs = result.value();
  }
  return result.status();
}

#endif  // THIRD_PARTY_LYRA_CODEC_SPARSE_MATMUL_LAYERS_STATUS_MACROS_H_
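A short sketch of how these helpers are intended to be used (illustrative only, not part of the commit; LoadModel and ParseHeader are hypothetical stand-ins for functions returning absl::Status and absl::StatusOr<int>):

#include <string>

#include "sparse_matmul/layers/status_macros.h"

// Hypothetical declarations standing in for real loaders.
absl::Status LoadModel(const std::string& path);
absl::StatusOr<int> ParseHeader(const std::string& path);

absl::Status InitFromFiles(const std::string& path) {
  // Early-returns the error status if LoadModel fails.
  SPARSE_MATMUL_RETURN_IF_ERROR(LoadModel(path));
  int num_layers = 0;
  // DoAssignOrReturn assigns on success; the macro propagates any error.
  SPARSE_MATMUL_RETURN_IF_ERROR(DoAssignOrReturn(num_layers, ParseHeader(path)));
  return absl::OkStatus();
}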
sparse_matmul/layers/testdata/768_512_95_4x4_QRhat_weights.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:50f861af29b1f767830d74ef83874944b18d80157b6b0256fdc4c14fa79ec936
size 20852
sparse_matmul/layers/testdata/768_512_95_4x4_What_weights.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a2d534bde2caf6e59990a46b4b1907088b8144c53d62d97de7e2b4bdc956da68
size 5133
sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_bias.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:11399f9d0e8f8dfbef6eb37e0c096f858658bc650f728a08f3135ccca44f0a5a
size 1062
sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_mask.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d3d971e067a6df985d68beac26bcf4e9a6cc13ff328599e84d50a0fc9a7c103b
size 2382
sparse_matmul/layers/testdata/768_512_95_4x4_coarselogit_weights.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d1376ef7a360699dae24a49f40a254990d4a70b844dadcdbe9dcbf1a306999a8
size 55829
sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_bias.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ffcc8ccf086fccfacc928877aa29ef03ce51cce0f0b7d2aacf81782b7b527089
size 2003
sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_mask.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7a16f98ba6f09031ea9fefb79fdc9ba90e44f0046ab70dab014ac971ca7f7186
size 4684
sparse_matmul/layers/testdata/768_512_95_4x4_coarseproj_weights.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b1b91304f5b6f7b53651ec7f9c827d4a2447366d1f990032adff46b18377741f
size 113777
sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_bias.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9ebb84ab4e16408f898b41a28c0d2c611f6735c8d9ad96a6805947c57cb547c7
size 1055
sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_mask.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:071159e5397eff604ff3f1fca3ba90980a1ff9ae12838022179709d2c50e4627
size 2322
sparse_matmul/layers/testdata/768_512_95_4x4_finelogit_weights.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1fdd0cbc0e79ea0a0dc1fc2ce8b10c5f25387fb4fd2ca019b66ac7ad7f44d219
size 51615
sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_bias.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:abd83a1795fd5e7044200029eae3ce6406b84095b7128288ac0dda1de5746b59
size 2001
sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_mask.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:455e1c142dd29bc4a4bb5a15c1f88ef3e0fbb580425620ef6f923b6e04faab01
size 4459
sparse_matmul/layers/testdata/768_512_95_4x4_fineproj_weights.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:171d1e86e04fbefeca7dcce59817ad82d30556a110b4552cd5757a9348405d1c
size 111636
sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_bias.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:fba804daa5c3c4d5c87ca1ff4060d118c33f8e2201077e6faa233822c5f0c511
size 10706
sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_mask.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:62c03b31f5f58eb67773dcc5b0bae5b4790a26dca1934d79802342b4175e7a74
size 50978
sparse_matmul/layers/testdata/768_512_95_4x4_wavernn_gru_weights.raw.gz
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:679c5bd2d5ca6abaae96225e8bab2ce9f9d57170027471465c85fc220c0c44a8
size 1361746