// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <array>
#include <cmath>
#include <cstdint>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

// Placeholder for get runfiles header.
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "gtest/gtest.h"
#include "include/ghc/filesystem.hpp"
#include "sparse_matmul/compute/matmul.h"
#include "sparse_matmul/layers/utils.h"
#include "sparse_matmul/numerics/test_utils.h"
#include "sparse_matmul/os/coop_threads.h"

namespace csrblocksparse {
namespace {

inline constexpr absl::string_view kTestdataPath = "layers/testdata";

TEST(CSRBlockSparseMatrix, FlatBufferSerialization) {
  const int kRows = 8;
  const int kCols = 8;
  std::vector<int> mask = {1, 1, 1, 1, 0, 0, 0, 0,
                           1, 1, 1, 1, 0, 0, 0, 0,
                           1, 1, 1, 1, 0, 0, 0, 0,
                           1, 1, 1, 1, 0, 0, 0, 0,
                           0, 0, 0, 0, 1, 1, 1, 1,
                           0, 0, 0, 0, 1, 1, 1, 1,
                           0, 0, 0, 0, 1, 1, 1, 1,
                           0, 0, 0, 0, 1, 1, 1, 1};
  std::vector<float> values(kRows * kCols, 1.f);
  values[1] = 2.f;
  values[3] = 3.f;
  values[36] = -1.f;
  values[45] = -2.f;

  csrblocksparse::CacheAlignedVector<float> bias(kRows);
  csrblocksparse::CacheAlignedVector<float> rhs(kCols);
  csrblocksparse::CacheAlignedVector<float> out_ref(kRows);
  csrblocksparse::CacheAlignedVector<float> out_test(kRows);

  bias.FillZero();
  rhs.FillOnes();

  csrblocksparse::MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(),
                                                   values.data());
  matrix.SpMM_bias(rhs, bias, &out_ref);

  csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float>
      block_sparse_matrix(matrix);

  std::string buffer;
  std::size_t num_bytes = block_sparse_matrix.WriteToFlatBuffer(&buffer);

  csrblocksparse::CsrBlockSparseMatrix<csrblocksparse::bfloat16, float>
      new_block_sparse_matrix(
          reinterpret_cast<const uint8_t*>(buffer.c_str()), num_bytes);

  new_block_sparse_matrix.SpMM_bias(rhs, bias, &out_test);

  CheckResult(out_ref, out_test, kCols);
}

template <typename ComputeType, typename RhsType, typename OutType>
void CorrectnessCheckBlockSpMM(int rows, int cols, int block_height,
                               int block_width, float sparsity,
                               bool use_relu = false, int num_threads = 1,
                               int fatness = 1, bool test_matmul = false) {
  using BiasType = typename TypeOfProduct<ComputeType, RhsType>::type;
  MaskedSparseMatrix<float> matrix(rows, cols, sparsity, block_height,
                                   block_width);
  matrix.CastWeights<ComputeType>();
  FatCacheAlignedVector<RhsType> rhs(cols, fatness);
  CacheAlignedVector<BiasType> bias(rows);
  FatCacheAlignedVector<OutType> out(rows, fatness);

  bias.FillRandom();
  rhs.FillRandom();
  out.FillZero();
  FatCacheAlignedVector<OutType> out_reference = out;

  matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);

  CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
  SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
      std::move(sparse_matrix), std::move(bias));
  num_threads = sparse_linear_layer.PrepareForThreads(num_threads);

  // Checks that the result of applying each thread's portion serially is
  // correct.
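  // (Each tid is expected to own a disjoint slice of the output, so applying
  // the per-thread partials one after another on a single thread should
  // reproduce the full multiply.)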
  for (int thread_id = 0; thread_id < num_threads; ++thread_id) {
    sparse_linear_layer.SpMM_bias(rhs, &out, use_relu, thread_id);
  }
  CheckResult(out_reference, out, sparse_linear_layer.cols());

  if (test_matmul) {
    for (int thread_id = 0; thread_id < num_threads; ++thread_id) {
      sparse_linear_layer.MatVec(rhs, use_relu, thread_id,
                                 /*replicas=*/1, /*output_stride=*/0, &out);
    }
    CheckResult(out_reference, out, sparse_linear_layer.cols());
  }
}

// Does:
//   y = Ax + b;
//   x = Ay + b;
//   y = Ax + b;
// to make sure that dependent multiplies are correct.
template <typename ComputeType, typename RhsType>
void ThreadBody(
    SpinBarrier* spin_barrier, int tid,
    const SparseLinearLayer<ComputeType, RhsType>& sparse_linear_layer,
    FatCacheAlignedVector<RhsType>* rhs, FatCacheAlignedVector<RhsType>* out,
    bool use_relu) {
  sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid);
  spin_barrier->barrier();
  sparse_linear_layer.SpMM_bias(*out, rhs, use_relu, tid);
  spin_barrier->barrier();
  sparse_linear_layer.SpMM_bias(*rhs, out, use_relu, tid);
}

template <typename ComputeType, typename RhsType>
void CorrectnessCheckBlockSpMM_MultiThread(int rows, int cols,
                                           int block_height, int block_width,
                                           float sparsity,
                                           bool use_relu = false,
                                           int num_threads = 1,
                                           int fatness = 1) {
  using BiasType = typename TypeOfProduct<ComputeType, RhsType>::type;
  CHECK(rows == cols);
  MaskedSparseMatrix<float> matrix(rows, cols, sparsity, block_height,
                                   block_width);
  matrix.CastWeights<ComputeType>();
  FatCacheAlignedVector<RhsType> rhs(cols, fatness);
  FatCacheAlignedVector<RhsType> rhs_mt(cols, fatness);
  CacheAlignedVector<BiasType> bias(rows);
  FatCacheAlignedVector<RhsType> out(rows, fatness);

  bias.FillOnes();
  rhs.FillOnes();
  rhs_mt.FillOnes();
  out.FillZero();
  FatCacheAlignedVector<RhsType> out_reference = out;

  // Serial reference: three dependent multiplies, matching ThreadBody.
  matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);
  matrix.SpMM_bias(out_reference, bias, &rhs, use_relu);
  matrix.SpMM_bias(rhs, bias, &out_reference, use_relu);

  CsrBlockSparseMatrix<ComputeType, RhsType> sparse_matrix(matrix);
  num_threads = sparse_matrix.PrepareForThreads(num_threads,
                                                /*cache_line_size=*/1);
  SparseLinearLayer<ComputeType, RhsType> sparse_linear_layer(
      std::move(sparse_matrix), std::move(bias));

  csrblocksparse::LaunchOnThreadsWithBarrier(
      num_threads, ThreadBody<ComputeType, RhsType>, sparse_linear_layer,
      &rhs_mt, &out, use_relu);

  CheckResult(out_reference, out, cols);
}

}  // namespace

TEST(MaskedSparseCorrectness, HandCoded) {
  const int kRows = 8;
  const int kCols = 8;
  // clang-format off
  std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
                           0, 1, 0, 1, 0, 1, 0, 1,
                           1, 0, 0, 1, 1, 1, 1, 0,
                           0, 0, 0, 0, 0, 0, 0, 0,
                           1, 1, 1, 1, 1, 1, 1, 1,
                           0, 0, 0, 0, 1, 1, 0, 0,
                           1, 1, 0, 0, 1, 1, 0, 0,
                           1, 0, 0, 0, 0, 1, 0, 1};
  // clang-format on
  std::vector<float> values(kRows * kCols, 1.f);
  std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};

  MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
  CacheAlignedVector<float> rhs(kCols);
  CacheAlignedVector<float> bias(kRows);
  CacheAlignedVector<float> out(kRows);

  bias.FillOnes();
  rhs.FillOnes();
  out.FillZero();

  MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
                                               std::move(bias));
  masked_linear_layer.SpMM_bias(rhs, &out);

  for (int i = 0; i < kRows; ++i) {
    EXPECT_EQ(answer[i], out[i]);
  }
}

TEST(MaskedSparseCorrectness, HandCodedFatVector) {
  const int kRows = 8;
  const int kCols = 8;
  // clang-format off
  std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
                           0, 1, 0, 1, 0, 1, 0, 1,
                           1, 0, 0, 1, 1, 1, 1, 0,
                           0, 0, 0, 0, 0, 0, 0, 0,
                           1, 1, 1, 1, 1, 1, 1, 1,
                           0, 0, 0, 0, 1, 1, 0, 0,
                           1, 1, 0, 0, 1, 1, 0, 0,
                           1, 0, 0, 0, 0, 1, 0, 1};
  // clang-format on
  std::vector<float> values(kRows * kCols, 1.f);
  std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};

  MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());

  const int kMaxWidth = 5;
  for (int width = 5; width <= kMaxWidth; ++width) {
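    // Note: |matrix| is moved into the layer below, so this loop relies on
    // executing exactly once (width starts at kMaxWidth).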
    FatCacheAlignedVector<float> rhs(kCols, width);
    CacheAlignedVector<float> bias(kRows);
    FatCacheAlignedVector<float> out(kRows, width);

    bias.FillOnes();
    rhs.FillOnes();
    out.FillZero();

    MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
                                                 std::move(bias));
    masked_linear_layer.SpMM_bias(rhs, &out);

    for (int i = 0; i < kRows; ++i) {
      for (int col = 0; col < kMaxWidth; ++col) {
        EXPECT_EQ(answer[i], out[i + col * kRows]);
      }
    }
  }
}

TEST(CsrBlockSparseMatrix, HandCodedMultiThread) {
  const int kRows = 8;
  const int kCols = 8;
  // clang-format off
  std::vector<int> mask = {1, 1, 0, 0, 0, 1, 1, 1,
                           0, 1, 0, 1, 0, 1, 0, 1,
                           1, 0, 0, 1, 1, 1, 1, 0,
                           0, 0, 0, 0, 0, 0, 0, 0,
                           1, 1, 1, 1, 1, 1, 1, 1,
                           0, 0, 0, 0, 1, 1, 0, 0,
                           1, 1, 0, 0, 1, 1, 0, 0,
                           1, 0, 0, 0, 0, 1, 0, 1};
  // clang-format on
  std::vector<float> values(kRows * kCols, 1.f);
  std::vector<float> answer = {6.f, 5.f, 6.f, 1.f, 9.f, 3.f, 5.f, 4.f};

  MaskedSparseMatrix<float> matrix(kRows, kCols, mask.data(), values.data());
  CacheAlignedVector<float> rhs(kCols);
  CacheAlignedVector<float> bias(kRows);
  CacheAlignedVector<float> out(kRows);

  bias.FillOnes();
  rhs.FillOnes();
  out.FillZero();

  CacheAlignedVector<float> bias_csr = bias;
  CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);

  MaskedLinearLayer<float> masked_linear_layer(std::move(matrix),
                                               std::move(bias));
  masked_linear_layer.SpMM_bias(rhs, &out);

  SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
                                                      std::move(bias_csr));
  sparse_linear_layer.PrepareForThreads(2, /*cache_line_size=*/1);

  CacheAlignedVector<float> out_tmp(kRows);
  const bool kUseRelu = false;
  sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/0);
  sparse_linear_layer.SpMM_bias(rhs, &out_tmp, kUseRelu, /*tid=*/1);

  for (int i = 0; i < kRows; ++i) {
    EXPECT_EQ(answer[i], out_tmp[i]);
  }
}

TEST(TestCasts, TestBfloat16) {
  const int kRows = 1000;
  const int kCols = 100;
  const float kSparsity = 0.f;
  MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
  MaskedSparseMatrix<float> matrix_bfloat16(kRows, kCols, matrix.mask().data(),
                                            matrix.values().data());

  matrix_bfloat16.CastWeights<csrblocksparse::bfloat16>();

  CheckResult(matrix.values(), matrix_bfloat16.values(), kCols);
}

TEST(TestCasts, TestFP16) {
  const int kRows = 1000;
  const int kCols = 100;
  const float kSparsity = 0.f;
  MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);

#if !defined __arm__ && !defined __aarch64__
  // Conversion doesn't handle denormals, so flush denormals to zero first.
  // (2^-14 is the smallest normal fp16 value.)
  for (int i = 0; i < matrix.values().size(); ++i) {
    if (std::abs(matrix.data()[i]) < 1.f / static_cast<float>(1 << 14)) {
      matrix.data()[i] = 0.f;
    }
  }
#endif

  MaskedSparseMatrix<float> matrix_fp16(kRows, kCols, matrix.mask().data(),
                                        matrix.values().data());

  matrix_fp16.CastWeights<csrblocksparse::fp16>();

  CheckResult(matrix.values(), matrix_fp16.values(), kCols);
}

TEST(TestCasts, TestFixed16) {
  const int kRows = 100000;
  const int kCols = 1;
  const float kSparsity = 0.f;
  MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);

  // Relative error for fixed point is high near 0.
  for (int i = 0; i < matrix.values().size(); ++i) {
    // 1.1e-3 is based on the max error of .013 and a grid spacing of
    // 1 / 2**15 == 3e-5:  3e-5 / .013 / 2 = 1.1e-3.
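    // Equivalently: keep x only if half a grid step is within the allowed
    // relative error, i.e. 1.5e-5 / |x| <= .013, so |x| >= 1.1e-3.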
    if (std::abs(matrix.data()[i]) < 1.1e-3) {
      matrix.data()[i] = 0.f;
    }
  }

  MaskedSparseMatrix<float> matrix_fixed16 = matrix;
  // fixed16<0>: 15 mantissa bits, giving the 3e-5 grid spacing above.
  matrix_fixed16.CastWeights<csrblocksparse::fixed16<0>>();

  CheckResult(matrix.values(), matrix_fixed16.values(), kCols);
}

TEST(TestCasts, TestFixed32) {
  const int kRows = 100000;
  const int kCols = 1;
  const float kSparsity = 0.f;
  MaskedSparseMatrix<float> matrix(kRows, kCols, kSparsity);
  MaskedSparseMatrix<float> matrix_fixed32 = matrix;
  matrix_fixed32.CastWeights<csrblocksparse::fixed32<0>>();

  CheckResult(matrix.values(), matrix_fixed32.values(), kCols);
}

template <typename ComputeType, typename RhsType, typename OutType>
void TestSpMM(int block_width, int block_height, int fatness,
              bool test_matmul = false) {
  std::array<bool, 2> use_relu = {false, true};
  std::vector<float> sparsity_levels = {.5, .8, .9, .95, .98};
  std::vector<std::pair<int, int>> sizes = {{8, 8},     {128, 128},
                                            {128, 64},  {256, 192},
                                            {512, 512}, {1024, 512},
                                            {384, 384}, {512, 384}};
  for (int num_threads = 1; num_threads < 2 + test_matmul; ++num_threads) {
    for (const auto& relu : use_relu) {
      for (const auto& sparsity : sparsity_levels) {
        for (const auto& size : sizes) {
          int rows, cols;
          std::tie(rows, cols) = size;
          CorrectnessCheckBlockSpMM<ComputeType, RhsType, OutType>(
              rows, cols, block_height, block_width, sparsity, relu,
              num_threads, fatness, test_matmul);
        }
      }
    }
  }
}

template <typename ComputeType, typename RhsType>
void TestSpMM_MultiThread(int block_width, int block_height, int fatness) {
  std::array<bool, 2> use_relu = {false, true};
  std::vector<float> sparsity_levels = {.5, .8, .9, .95, .98};
  std::vector<std::pair<int, int>> sizes = {
      {48, 48}, {128, 128}, {512, 512}, {384, 384}};
  for (int num_threads = 1; num_threads < 5; ++num_threads) {
    for (const auto& relu : use_relu) {
      for (const auto& sparsity : sparsity_levels) {
        for (const auto& size : sizes) {
          int rows, cols;
          std::tie(rows, cols) = size;
          CorrectnessCheckBlockSpMM_MultiThread<ComputeType, RhsType>(
              rows, cols, block_height, block_width, sparsity, relu,
              num_threads, fatness);
        }
      }
    }
  }
}

template <typename DataType>
void TestSumVectors(int start = 0, int end = -1, int size = 6) {
  std::vector<DataType> values;
  std::vector<DataType> answer;
  for (int i = 1; i < size + 1; ++i) {
    const float x = static_cast<float>(i);
    values.push_back(static_cast<DataType>(x));
    answer.push_back(static_cast<DataType>(x * 2));
  }
  if (end == -1) {
    end = values.size();
  }

  csrblocksparse::CacheAlignedVector<DataType> result(values.size());
  csrblocksparse::CacheAlignedVector<DataType> values_aligned(values);
  detail::SumVectors(start, end, values_aligned.data(), values_aligned.data(),
                     result.data());
  for (int i = start; i < end; ++i) {
    EXPECT_EQ(static_cast<float>(answer[i]), static_cast<float>(result[i]));
  }
}

TEST(CsrBlockSparseMatrix, SumVectors_Generic) {
  TestSumVectors<float>();
  TestSumVectors<float>(1);
  TestSumVectors<float>(1, 4);
}

TEST(CsrBlockSparseMatrix, SumVectors_Bfloat16) {
  TestSumVectors<csrblocksparse::bfloat16>();
  TestSumVectors<csrblocksparse::bfloat16>(1);
  TestSumVectors<csrblocksparse::bfloat16>(1, 4);
}

// For SIMD-optimized SumVectors, the memory of the vector should be at least
// |kSIMDWidth * sizeof(float)| long, and the start position has to be an
// aligned memory location. So |size| is set to 100 to be safe, and |start|
// to 0 (|start| == 1 is not aligned).
TEST(CsrBlockSparseMatrix, SumVectors_Fixed16) {
  // Exponent bits chosen so the fixed-point range covers the test values
  // (up to 2 * size = 200).
  TestSumVectors<csrblocksparse::fixed16<8>>(0, -1, 100);
  TestSumVectors<csrblocksparse::fixed16<8>>(0, 4, 100);
}

TEST(CsrBlockSparseMatrix, SumVectors_Fixed32) {
  TestSumVectors<csrblocksparse::fixed32<8>>(0, -1, 100);
  TestSumVectors<csrblocksparse::fixed32<8>>(0, 4, 100);
}

TEST(CsrBlockSparseMatrix, SpMM_Block4x4_Bfloat16) {
  TestSpMM<csrblocksparse::bfloat16, csrblocksparse::bfloat16, float>(
      /*block_width=*/4,
      /*block_height=*/4,
      /*fatness=*/7);
}

// This actually uses multiple threads, and uses the output as the input for
// multiple steps to test that synchronization and memory visibility are
// working correctly. Requires square matrices.
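// (Square matrices are required because ThreadBody feeds each product back in
// as the next input, so the output length must match the input length.)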
TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_Bfloat16) {
  TestSpMM_MultiThread<csrblocksparse::bfloat16, csrblocksparse::bfloat16>(
      /*block_width=*/4,
      /*block_height=*/4,
      /*fatness=*/1);
}

TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_Bfloat16) {
  TestSpMM_MultiThread<csrblocksparse::bfloat16, csrblocksparse::bfloat16>(
      /*block_width=*/4,
      /*block_height=*/4,
      /*fatness=*/7);
}

TEST(CsrBlockSparseMatrix, SpMV_Block1x1_Bfloat16) {
  TestSpMM<csrblocksparse::bfloat16, csrblocksparse::bfloat16, float>(
      /*block_width=*/1,
      /*block_height=*/1,
      /*fatness=*/1);
}

TEST(CsrBlockSparseMatrix, SpMM_Block1x1_Bfloat16) {
  TestSpMM<csrblocksparse::bfloat16, csrblocksparse::bfloat16, float>(
      /*block_width=*/1,
      /*block_height=*/1,
      /*fatness=*/7);
}

// This actually uses multiple threads, and uses the output as the input for
// multiple steps to test that synchronization and memory visibility are
// working correctly. Requires square matrices.
TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_Bfloat16) {
  TestSpMM_MultiThread<csrblocksparse::bfloat16, csrblocksparse::bfloat16>(
      /*block_width=*/1,
      /*block_height=*/1,
      /*fatness=*/1);
}

TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_Bfloat16) {
  TestSpMM_MultiThread<csrblocksparse::bfloat16, csrblocksparse::bfloat16>(
      /*block_width=*/1,
      /*block_height=*/1,
      /*fatness=*/7);
}

TEST(CsrBlockSparseMatrix, SpMV_Block4x4_float) {
  TestSpMM<float, float, float>(/*block_width=*/4,
                                /*block_height=*/4,
                                /*fatness=*/1,
                                /*test_matmul=*/true);
}

TEST(CsrBlockSparseMatrix, SpMM_Block4x4_float) {
  TestSpMM<float, float, float>(/*block_width=*/4,
                                /*block_height=*/4,
                                /*fatness=*/7);
}

// This actually uses multiple threads, and uses the output as the input for
// multiple steps to test that synchronization and memory visibility are
// working correctly. Requires square matrices.
TEST(CsrBlockSparseMatrix, SpMV_4x4MultiThreading_float) {
  TestSpMM_MultiThread<float, float>(/*block_width=*/4,
                                     /*block_height=*/4,
                                     /*fatness=*/1);
}

TEST(CsrBlockSparseMatrix, SpMM_4x4MultiThreading_float) {
  TestSpMM_MultiThread<float, float>(/*block_width=*/4,
                                     /*block_height=*/4,
                                     /*fatness=*/7);
}

TEST(CsrBlockSparseMatrix, SpMV_Block1x1_float) {
  TestSpMM<float, float, float>(/*block_width=*/1,
                                /*block_height=*/1,
                                /*fatness=*/1);
}

TEST(CsrBlockSparseMatrix, SpMM_Block1x1_float) {
  TestSpMM<float, float, float>(/*block_width=*/1,
                                /*block_height=*/1,
                                /*fatness=*/7);
}

// This actually uses multiple threads, and uses the output as the input for
// multiple steps to test that synchronization and memory visibility are
// working correctly. Requires square matrices.
TEST(CsrBlockSparseMatrix, SpMV_1x1MultiThreading_float) {
  TestSpMM_MultiThread<float, float>(/*block_width=*/1,
                                     /*block_height=*/1,
                                     /*fatness=*/1);
}

TEST(CsrBlockSparseMatrix, SpMM_1x1MultiThreading_float) {
  TestSpMM_MultiThread<float, float>(/*block_width=*/1,
                                     /*block_height=*/1,
                                     /*fatness=*/7);
}

TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32) {
  TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
           typename csrblocksparse::TypeOfProduct<
               csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
      /*block_width=*/4,
      /*block_height=*/4,
      /*fatness=*/1,
      /*test_matmul=*/true);
}

TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32) {
  TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
           typename csrblocksparse::TypeOfProduct<
               csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
      /*block_width=*/4,
      /*block_height=*/4,
      /*fatness=*/7);
}

TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32) {
  TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
           typename csrblocksparse::TypeOfProduct<
               csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
      /*block_width=*/1,
      /*block_height=*/1,
      /*fatness=*/1);
}

TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32) {
  TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>,
           typename csrblocksparse::TypeOfProduct<
               csrblocksparse::fixed16<4>, csrblocksparse::fixed16<4>>::type>(
      /*block_width=*/1,
      /*block_height=*/1,
      /*fatness=*/7);
}

TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_16) {
  TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<5>,
           csrblocksparse::fixed16<8>>(
      /*block_width=*/4,
      /*block_height=*/4,
      /*fatness=*/1,
      /*test_matmul=*/true);
}

TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_16) {
  TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<5>,
           csrblocksparse::fixed16<8>>(
      /*block_width=*/4,
      /*block_height=*/4,
      /*fatness=*/7);
}

TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_16) {
  TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<5>,
           csrblocksparse::fixed16<8>>(
      /*block_width=*/1,
      /*block_height=*/1,
      /*fatness=*/1);
}

TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_16) {
  TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<5>,
           csrblocksparse::fixed16<8>>(
      /*block_width=*/1,
      /*block_height=*/1,
      /*fatness=*/7);
}

TEST(CsrBlockSparseMatrix, SpMV_Block4x4_fixed16x16_32_unmatched) {
  TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<5>,
           csrblocksparse::fixed32<13>>(
      /*block_width=*/4,
      /*block_height=*/4,
      /*fatness=*/1,
      /*test_matmul=*/true);
}

TEST(CsrBlockSparseMatrix, SpMM_Block4x4_fixed16x16_32_unmatched) {
  TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<5>,
           csrblocksparse::fixed32<13>>(
      /*block_width=*/4,
      /*block_height=*/4,
      /*fatness=*/7);
}

TEST(CsrBlockSparseMatrix, SpMV_Block1x1_fixed16x16_32_unmatched) {
  TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<5>,
           csrblocksparse::fixed32<13>>(
      /*block_width=*/1,
      /*block_height=*/1,
      /*fatness=*/1);
}

TEST(CsrBlockSparseMatrix, SpMM_Block1x1_fixed16x16_32_unmatched) {
  TestSpMM<csrblocksparse::fixed16<4>, csrblocksparse::fixed16<5>,
           csrblocksparse::fixed32<13>>(
      /*block_width=*/1,
      /*block_height=*/1,
      /*fatness=*/7);
}

TEST(CsrBlockSparseMatrix, RhsIndicesDeltasRoundTrip) {
  MaskedSparseMatrix<float> matrix(/*rows=*/256, /*cols=*/256,
                                   /*sparsity=*/0.9,
                                   /*block_height=*/4, /*block_width=*/4);
  CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
  auto copy_indices = sparse_matrix.rhs_indices();
  sparse_matrix.ComputeColDeltas();
  sparse_matrix.ComputeRHSIndices();
  // They get padded when created, so the newer one could be bigger.
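  // The recomputed indices must reproduce the original values on the common
  // prefix; only trailing padding may differ.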
  EXPECT_LE(copy_indices.size(), sparse_matrix.rhs_indices().size());
  for (int i = 0; i < copy_indices.size(); ++i) {
    EXPECT_EQ(copy_indices[i], sparse_matrix.rhs_indices()[i]) << "i=" << i;
  }
}

// Tests that a Layer that is split into 2 by columns (inputs) computes the
// same result as the original layer.
TEST(CsrBlockSparseMatrix, SplitByCol) {
  const int kRows = 1024;
  const int kCols = 1024;
  MaskedSparseMatrix<float> matrix(kRows, kCols, /*sparsity=*/0.95,
                                   /*block_height=*/4, /*block_width=*/4);
  FatCacheAlignedVector<float> rhs(kCols, /*cols=*/1);
  CacheAlignedVector<float> bias(kRows);
  FatCacheAlignedVector<float> out1(kRows, /*cols=*/1);
  FatCacheAlignedVector<float> out2(kRows, /*cols=*/1);

  bias.FillRandom();
  rhs.FillRandom();
  out1.FillZero();
  out2.FillZero();
  FatCacheAlignedVector<float> out_reference = out1;

  CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
  SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
                                                      std::move(bias));
  sparse_linear_layer.PrepareForThreads(1);
  sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false,
                                /*tid=*/0);

  // Split the layer into 2 parts.
  SparseLinearLayer<float, float> part1, part2;
  sparse_linear_layer.SplitInputs(&part1, &part2);
  part1.PrepareForThreads(1);
  part2.PrepareForThreads(1);
  EXPECT_EQ(kRows, part1.rows());
  EXPECT_EQ(kCols / 2, part1.cols());
  EXPECT_EQ(kRows, part2.rows());
  EXPECT_EQ(kCols / 2, part2.cols());
  MutableVectorView<float> rhs1(&rhs, 0, kCols / 2);
  MutableVectorView<float> rhs2(&rhs, kCols / 2, kCols / 2);
  for (int i = 0; i < kCols / 2; ++i) {
    EXPECT_FLOAT_EQ(rhs[i], rhs1.data()[i]);
    EXPECT_FLOAT_EQ(rhs[i + kCols / 2], rhs2.data()[i]);
  }
  part1.SpMM_bias(rhs1, &out1, /*relu=*/false, /*tid=*/0);
  part2.SpMM_bias(rhs2, &out2, /*relu=*/false, /*tid=*/0);
  // Check that out1 + out2 = out_reference.
  for (int i = 0; i < kRows; ++i) {
    EXPECT_NEAR(out_reference[i], out1[i] + out2[i], 2e-5)
        << " i=" << i << " out1=" << out1[i] << " out2=" << out2[i];
  }
}

// Tests that a Layer that is split into 2 by rows (outputs) computes the same
// result as the original layer.
TEST(CsrBlockSparseMatrix, SplitByRow) {
  const int kRows = 1024;
  const int kCols = 1024;
  MaskedSparseMatrix<float> matrix(kRows, kCols, /*sparsity=*/0.95,
                                   /*block_height=*/4, /*block_width=*/4);
  FatCacheAlignedVector<float> rhs(kCols, /*cols=*/1);
  CacheAlignedVector<float> bias(kRows);
  FatCacheAlignedVector<float> out1(kRows, /*cols=*/1);
  FatCacheAlignedVector<float> out2(kRows, /*cols=*/1);

  bias.FillRandom();
  rhs.FillRandom();
  out1.FillZero();
  out2.FillZero();
  FatCacheAlignedVector<float> out_reference = out1;

  CsrBlockSparseMatrix<float, float> sparse_matrix(matrix);
  SparseLinearLayer<float, float> sparse_linear_layer(std::move(sparse_matrix),
                                                      std::move(bias));
  sparse_linear_layer.PrepareForThreads(1);
  sparse_linear_layer.SpMM_bias(rhs, &out_reference, /*relu=*/false,
                                /*tid=*/0);

  // Split the layer into 2 parts.
  SparseLinearLayer<float, float> part1, part2;
  sparse_linear_layer.SplitOutputs(&part1, &part2);
  part1.PrepareForThreads(1);
  part2.PrepareForThreads(1);
  EXPECT_EQ(kRows / 2, part1.rows());
  EXPECT_EQ(kCols, part1.cols());
  EXPECT_EQ(kRows / 2, part2.rows());
  EXPECT_EQ(kCols, part2.cols());
  MutableVectorView<float> out2a(&out2, 0, kRows / 2);
  MutableVectorView<float> out2b(&out2, kRows / 2, kRows / 2);
  part1.SpMM_bias(rhs, &out2a, /*relu=*/false, /*tid=*/0);
  part2.SpMM_bias(rhs, &out2b, /*relu=*/false, /*tid=*/0);
  // Check that out2 = out_reference.
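  // (out2a and out2b are views onto the two halves of out2, so the two
  // partial layers have together written the complete output in place.)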
  for (int i = 0; i < kRows; ++i) {
    EXPECT_NEAR(out_reference[i], out2[i], 2e-5)
        << " i=" << i << " reference=" << out_reference[i]
        << " out2=" << out2[i];
  }
}

TEST(CsrBlockSparseMatrix, MutableVectorView) {
  const int kRows = 1024;
  const int kCols = 1024;
  const int kFatness = 2;

  std::vector<float> values(kRows * kCols, 1.f);
  std::vector<int> mask(kRows * kCols);
  for (int i = 0; i < mask.size(); ++i) mask[i] = i % 2;

  auto masked_matrix =
      MaskedSparseMatrix<float>(kRows, kCols, mask.data(), values.data());
  auto sparse_matrix = CsrBlockSparseMatrix<float, float>(masked_matrix);

  FatCacheAlignedVector<float> x(kCols, kFatness);
  x.FillOnes();

  CacheAlignedVector<float> bias(kRows);
  bias.FillZero();

  // First check that we can use spans as output. Split a multiplication
  // into upper and lower halves times the full vector:
  //   ---------------     x     t
  //   |             |     x     t
  //   |             |     x     t
  //   ---------------  =
  //   |             |     x     b
  //   |             |     x     b
  //   ---------------     x     b
  FatCacheAlignedVector<float> out(kRows, kFatness);
  FatCacheAlignedVector<float> out_view(kRows, kFatness);
  MutableVectorView<float> out_view_top(&out_view, 0, kRows / 2);
  MutableVectorView<float> out_view_bottom(&out_view, kRows / 2, kRows / 2);

  sparse_matrix.SpMM_bias(x, bias, &out);

  auto masked_matrix_top =
      MaskedSparseMatrix<float>(kRows / 2, kCols, mask.data(), values.data());
  auto masked_matrix_bottom = MaskedSparseMatrix<float>(
      kRows / 2, kCols, mask.data() + kRows * kCols / 2,
      values.data() + kRows * kCols / 2);
  auto sparse_matrix_top =
      CsrBlockSparseMatrix<float, float>(masked_matrix_top);
  auto sparse_matrix_bottom =
      CsrBlockSparseMatrix<float, float>(masked_matrix_bottom);

  sparse_matrix_top.SpMM_bias(x, bias, &out_view_top);
  sparse_matrix_bottom.SpMM_bias(x, bias, &out_view_bottom);

  CheckResult(out, out_view, kCols);

  // Check that we can use a span as an input vector. Multiply the upper-left
  // quarter of the matrix by the top half of the vector:
  //   ---------------     x     q
  //   |oooooo       |     x     q
  //   |oooooo       |     x
  //   |             |  =
  //   |             |
  //   ---------------
  auto masked_matrix_quarter = MaskedSparseMatrix<float>(
      kRows / 2, kCols / 2, mask.data(), values.data());
  auto sparse_matrix_quarter =
      CsrBlockSparseMatrix<float, float>(masked_matrix_quarter);

  MutableVectorView<float> x_top(&x, 0, kCols / 2);

  // Every other of the kCols / 2 = 512 entries in each row of the quarter
  // matrix is 1, and x is all ones, so every output value is 256.
  FatCacheAlignedVector<float> out_correct(kRows / 2, /*cols=*/2);
  for (int i = 0; i < kFatness * (kRows / 2); ++i) out_correct[i] = 256.f;

  MutableVectorView<float> bias_top(&bias, 0, kRows / 2);
  FatCacheAlignedVector<float> out_quarter(kRows / 2, kFatness);
  sparse_matrix_quarter.SpMM_bias(x_top, bias_top, &out_quarter);

  CheckResult(out_correct, out_quarter, kCols / 2);
}

namespace {

bool skip_test(const absl::Status& status, absl::string_view msg) {
  if (!status.ok()) {
    LOG(INFO) << "Couldn't load " << msg << ", skipping test: " << status;
    return true;
  }
  return false;
}

}  // namespace

TEST(CsrBlockSparseMatrix, ModelMatrices_Bfloat16) {
  const std::vector<std::string> names = {
      "768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_",
      "768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_",
      "768_512_95_4x4_finelogit_",   "lyra_conv1d_"};
  const std::string kPath =
#if defined __arm__ || defined __aarch64__
      "/data/local/tmp/";
#else
      (ghc::filesystem::current_path() / kTestdataPath).string();
#endif
  for (const auto& layer_name : names) {
    SparseLinearLayer<csrblocksparse::bfloat16, float> sparse_linear_layer;
    auto status = LoadSparseLayer(layer_name, /*zipped=*/true,
                                  &sparse_linear_layer, kPath);
    // If the files don't exist on the device we're running on, just skip this
    // test and log that it was skipped.
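    // Note that returning skips the remaining layer names as well.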
    if (skip_test(status, layer_name)) return;

    const int rows = sparse_linear_layer.rows();
    const int cols = sparse_linear_layer.cols();

    MaskedLinearLayer<float> masked_linear_layer;
    status = LoadMaskedLayer(layer_name, /*zipped=*/true, &masked_linear_layer,
                             kPath);
    if (skip_test(status, layer_name)) return;
    masked_linear_layer.CastWeights<csrblocksparse::bfloat16>();

    CacheAlignedVector<float> rhs(cols);
    CacheAlignedVector<float> out_ref(rows);
    CacheAlignedVector<float> out_spmv(rows);
    rhs.FillRandom();
    out_ref.FillZero();
    out_spmv.FillZero();

    std::array<bool, 2> use_relus = {false, true};
    for (bool use_relu : use_relus) {
      masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu);
      sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu);

      CheckResult(out_ref, out_spmv, cols);
    }
  }
}

TEST(CsrBlockSparseMatrix, ModelMatrices_float) {
  const std::vector<std::string> names = {
      "768_512_95_4x4_wavernn_gru_", "768_512_95_4x4_coarseproj_",
      "768_512_95_4x4_coarselogit_", "768_512_95_4x4_fineproj_",
      "768_512_95_4x4_finelogit_",   "lyra_conv1d_"};
  const std::string kPath =
#if defined __arm__ || defined __aarch64__
      "/data/local/tmp/";
#else
      (ghc::filesystem::current_path() / kTestdataPath).string();
#endif
  for (const auto& layer_name : names) {
    SparseLinearLayer<float, float> sparse_linear_layer;
    auto status = LoadSparseLayer(layer_name, /*zipped=*/true,
                                  &sparse_linear_layer, kPath);
    // If the files don't exist on the device we're running on, just skip this
    // test and log that it was skipped.
    if (skip_test(status, layer_name)) return;

    const int rows = sparse_linear_layer.rows();
    const int cols = sparse_linear_layer.cols();

    MaskedLinearLayer<float> masked_linear_layer;
    status = LoadMaskedLayer(layer_name, /*zipped=*/true, &masked_linear_layer,
                             kPath);
    if (skip_test(status, layer_name)) return;

    CacheAlignedVector<float> rhs(cols);
    CacheAlignedVector<float> out_ref(rows);
    CacheAlignedVector<float> out_spmv(rows);
    rhs.FillRandom();
    out_ref.FillZero();
    out_spmv.FillZero();

    std::array<bool, 2> use_relus = {false, true};
    for (bool use_relu : use_relus) {
      masked_linear_layer.SpMM_bias(rhs, &out_ref, use_relu);
      sparse_linear_layer.SpMM_bias(rhs, &out_spmv, use_relu);

      CheckResult(out_ref, out_spmv, cols);
    }
  }
}

}  // namespace csrblocksparse