/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// IWYU pragma: begin_exports
// IWYU pragma: end_exports

namespace csrblocksparse {

// CsrBlockSparseMatrix stores a modified block compressed sparse row
// representation of a sparse matrix. The ordering of the weights is modified
// in the 16x1 and 1x1 cases so that a certain number (4 and 8 respectively)
// of columns of weights are stored contiguously before moving on to the next
// row. The 4x4 case stores each block contiguously.
//
// Currently it is constructed from a MaskedSparseMatrix, which uses a dense
// binary mask representation. The construction generates the compressed
// representation. Further iterations will support a direct serialization
// of the compressed representation.
//
// MaskedSparseMatrix masked_matrix(rows, cols, existing_mask, existing_values)
// CsrBlockSparseMatrix matrix(masked_matrix)
//
// matrix.SpMM_bias(rhs, bias, &out);
//
// This class is thread compatible.
template <typename WeightType, typename RhsType, typename DeltaType = int16_t>
class CsrBlockSparseMatrix {
 public:
  CsrBlockSparseMatrix() {}

  // Reference used to indicate that this is an input and not an output.
  CsrBlockSparseMatrix(const uint8_t* const& buffer, const std::size_t& len) {
    ReadFromFlatBuffer(buffer, len);
    ComputeRHSIndices();
  }
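
  // Example (illustrative sketch): constructing from a buffer previously
  // produced by WriteToFlatBuffer(). The template arguments must match those
  // used when the buffer was written; per the TODO on WriteToFlatBuffer only
  // one combination is currently supported, presumably the
  // (bfloat16, float, int16_t) used by
  // ConvertDenseToSparseRepresentation_Int16Deltas at the end of this file.
  //
  //   std::string serialized = ...;  // e.g. read from disk
  //   CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t> matrix(
  //       reinterpret_cast<const uint8_t*>(serialized.data()),
  //       serialized.size());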

  template <typename InputType>
  CsrBlockSparseMatrix(const MaskedSparseMatrix<InputType>& masked_matrix) {
    sparsity_ = masked_matrix.sparsity();
    rows_ = masked_matrix.rows();
    cols_ = masked_matrix.cols();

    DetermineBlockSize(masked_matrix);

    if (block_width_ == 1 && block_height_ == 1)
      col_multiple_ = 8;
    else
      col_multiple_ = 1;

    std::vector<InputType> weights(masked_matrix.values().begin(),
                                   masked_matrix.values().end());

    reduced_rows_ = (rows_ + block_height_ - 1) / block_height_;
    rows_ = reduced_rows_ * block_height_;
    reduced_cols_ = cols_ / block_width_;

    // Calculate the reduced CSR representation of the matrix.
    std::vector<int> reduced_mask(reduced_rows_ * reduced_cols_);
    std::vector<int> row_offsets = {0};
    int nnz = 0;
    const auto& mask = masked_matrix.mask();
    for (int r = 0; r < reduced_rows_; ++r) {
      for (int c = 0; c < reduced_cols_; ++c) {
        int mask_val = mask[r * block_height_ * cols_ + c * block_width_];
        reduced_mask[r * reduced_cols_ + c] = mask_val;
        nnz += mask_val;
      }
      row_offsets.push_back(nnz);
    }

    // Make sure the reduced representation has the correct number of columns.
    MakeColumnsMultiple(row_offsets, &reduced_mask, &weights);

    std::vector<int> col_indices;
    std::vector<WeightType> weights_csr;
    std::vector<int> nnz_per_row;
    MaskAndWeightsToCsr(reduced_mask, weights, &nnz_per_row, &col_indices,
                        &weights_csr);

    // Generate column deltas from |col_indices|.
    std::vector<DeltaType> col_deltas;
    for (int i = 0; i < col_indices.size(); ++i) {
      // |col_indices| are used to index the RHS vector, so the deltas are
      // scaled to byte offsets with sizeof(RhsType).
      int64_t diff = sizeof(RhsType);
      if (i == 0)
        diff *= block_width_ * (col_indices[i]);
      else
        diff *= block_width_ * (col_indices[i] - col_indices[i - 1]);
      CHECK(diff < std::numeric_limits<DeltaType>::max())
          << "delta between column indices in bytes " << diff
          << " exceeded the maximum size of the DeltaType "
          << std::numeric_limits<DeltaType>::max();
      col_deltas.push_back(static_cast<DeltaType>(diff));
    }

    // Because of pre-fetching we need some extra values at the end.
    col_deltas.insert(col_deltas.end(), std::max(2, col_multiple_ + 1), 0);
    nnz_per_row.insert(nnz_per_row.end(), 2, nnz_per_row.back());

    weights_ = CacheAlignedVector<WeightType>(weights_csr);
    col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas);
    nnz_per_row_ = CacheAlignedVector<int>(nnz_per_row);
    ComputeRHSIndices();

    num_threads_ = 0;
    PrepareForThreads(1);
  }
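
  // Worked example of the delta encoding above (illustrative): with 4x4 blocks
  // and a float rhs (4 bytes), a row whose nonzero block columns are {1, 3, 7}
  // produces |col_deltas| entries of
  //   4 * 4 * 1 = 16, 4 * 4 * (3 - 1) = 32, 4 * 4 * (7 - 3) = 64
  // bytes, i.e. each entry is how far to advance the rhs pointer (in bytes)
  // before processing the next block.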

  // Constructor makes a matrix from the given weights, deltas and nnz, taking
  // the other parameters from |src_matrix|. |cols| is the number of raw
  // columns (NOT blocks) of the new matrix.
  CsrBlockSparseMatrix(
      const CsrBlockSparseMatrix<WeightType, RhsType, DeltaType>& src_matrix,
      const std::vector<WeightType>& new_weights,
      const std::vector<DeltaType>& new_deltas,
      const std::vector<int>& new_nnz, int cols) {
    num_threads_ = 0;
    col_multiple_ = src_matrix.col_multiple_;
    block_width_ = src_matrix.block_width_;
    block_height_ = src_matrix.block_height_;
    reduced_rows_ = new_nnz.size();
    rows_ = reduced_rows_ * block_height_;
    cols_ = cols;
    reduced_cols_ = cols_ / block_width_;
    weights_ = CacheAlignedVector<WeightType>(new_weights);
    col_deltas_ = CacheAlignedVector<DeltaType>(new_deltas);
    nnz_per_row_ = CacheAlignedVector<int>(new_nnz);
    sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_);
    ComputeRHSIndices();
    name_ = src_matrix.name_;
    PrepareForThreads(1);
  }

  // Factory method takes a column slice out of *this and returns a sparse
  // matrix that takes as inputs [|start_col|, |end_col|) of *this, and
  // returns the same number of outputs, but only a partial result.
  // If |keep_rhs_size|, then the new matrix takes the same rhs as the current
  // matrix, but uses a subset of it, instead of expecting just the reduced rhs.
  // If |start_col| > |end_col|, then we slice out the complement of the defined
  // interval, i.e. [0, |end_col|) + [|start_col|, current end).
  // NOTE that |start_col| and |end_col| are in raw column coordinates, NOT
  // block units.
  CsrBlockSparseMatrix SplitByColumn(int start_col, int end_col,
                                     bool keep_rhs_size = false) const {
    int weight_index = 0;
    int delta_index = 0;
    std::vector<DeltaType> new_deltas;
    std::vector<WeightType> new_weights;
    std::vector<int> new_nnz(reduced_rows_);
    int col = 0;
    int prev_col = keep_rhs_size ? 0 : start_col;
    for (int r = 0; r < reduced_rows_; ++r) {
      int reduced_col_count = nnz_per_row_[r];
      for (int c = 0; c < reduced_col_count; ++c, ++delta_index) {
        col += col_deltas_[delta_index] / sizeof(RhsType);
        if ((start_col < end_col && start_col <= col && col < end_col) ||
            (start_col > end_col && (col < end_col || col >= start_col))) {
          ++new_nnz[r];
          new_deltas.push_back((col - prev_col) * sizeof(RhsType));
          prev_col = col;
          for (int i = 0; i < block_width_ * block_height_;
               ++i, ++weight_index) {
            new_weights.push_back(weights_[weight_index]);
          }
        } else {
          weight_index += block_width_ * block_height_;
        }
      }
    }
    int new_cols = keep_rhs_size ? cols_ : end_col - start_col;
    return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz,
                                new_cols);
  }
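
  // Example (illustrative): splitting the inputs in half so that two partial
  // products can be computed independently and summed afterwards. With
  // |keep_rhs_size| = true both halves read the same full-size rhs. Note that
  // SpMM_bias adds the bias in every call, so when summing partial results the
  // bias should only be supplied to one of them.
  //
  //   auto left = matrix.SplitByColumn(0, matrix.cols() / 2,
  //                                    /*keep_rhs_size=*/true);
  //   auto right = matrix.SplitByColumn(matrix.cols() / 2, matrix.cols(),
  //                                     /*keep_rhs_size=*/true);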

  // Factory method takes a row slice out of *this and returns a sparse
  // matrix that takes the same inputs as *this, and returns the outputs for
  // the range [|start_row|, |end_row|).
  // NOTE that |start_row| and |end_row| are in raw row coordinates, NOT
  // block units.
  CsrBlockSparseMatrix SplitByRow(int start_row, int end_row) const {
    int start_reduced = start_row / block_height_;
    int end_reduced = end_row / block_height_;
    std::vector<int> new_nnz(nnz_per_row_.data() + start_reduced,
                             nnz_per_row_.data() + end_reduced);
    int weight_start = 0;
    for (int r = 0; r < start_reduced; ++r) {
      weight_start += nnz_per_row_[r];
    }
    int weight_end = weight_start;
    for (int r = start_reduced; r < end_reduced; ++r) {
      weight_end += nnz_per_row_[r];
    }
    int delta_start = 0;
    for (int i = 0; i < weight_start; ++i) {
      delta_start += col_deltas_[i];
    }
    std::vector<DeltaType> new_deltas(col_deltas_.data() + weight_start,
                                      col_deltas_.data() + weight_end);
    new_deltas[0] += delta_start;
    int block_size = block_height_ * block_width_;
    std::vector<WeightType> new_weights(
        weights_.data() + weight_start * block_size,
        weights_.data() + weight_end * block_size);
    return CsrBlockSparseMatrix(*this, new_weights, new_deltas, new_nnz, cols_);
  }

  // Combines adjacent row blocks, doubling the block height.
  // This necessarily involves adding zero weights where the blocks don't align
  // across adjacent pairs of rows, so use with caution, as the resulting matrix
  // is most likely to run slower if very sparse to begin with.
  // In the few cases where the blocks do mostly align, the resulting matmul
  // could be much faster, as the number of reads of the rhs will be halved.
  void DoubleBlockHeight() {
    int new_rows = reduced_rows_ / 2;
    std::vector<int> new_nnz(new_rows);
    std::vector<DeltaType> new_rhs_indices;
    std::vector<WeightType> new_weights;
    int rhs_index1 = 0;
    int rhs_index2 = 0;
    int block_size = block_height_ * block_width_;
    for (int r = 0; r < new_rows; ++r) {
      int start_nnz = new_rhs_indices.size();
      rhs_index2 += nnz_per_row_[r * 2];
      int end1 = rhs_index1 + nnz_per_row_[r * 2];
      int end2 = rhs_index2 + nnz_per_row_[r * 2 + 1];
      // Run over a pair of rows with 2 iterators, combining blocks as we go, or
      // padding with zeros where the block positions don't match.
      while (rhs_index1 < end1 || rhs_index2 < end2) {
        int col1 = rhs_index1 < end1 ? rhs_indices_[rhs_index1] : reduced_cols_;
        int col2 = rhs_index2 < end2 ? rhs_indices_[rhs_index2] : reduced_cols_;
        if (col1 < col2) {
          // Need zero weights for row2 to pad out weights block.
          new_rhs_indices.push_back(col1);
          new_weights.insert(new_weights.end(),
                             weights_.data() + rhs_index1 * block_size,
                             weights_.data() + (rhs_index1 + 1) * block_size);
          new_weights.insert(new_weights.end(), block_size,
                             static_cast<WeightType>(0.0f));
          ++rhs_index1;
        } else if (col1 > col2) {
          // Need zero weights for row1 to pad out weights block.
          new_rhs_indices.push_back(col2);
          new_weights.insert(new_weights.end(), block_size,
                             static_cast<WeightType>(0.0f));
          new_weights.insert(new_weights.end(),
                             weights_.data() + rhs_index2 * block_size,
                             weights_.data() + (rhs_index2 + 1) * block_size);
          ++rhs_index2;
        } else {
          // Combine weights for both row1 and row2.
          new_rhs_indices.push_back(col1);
          new_weights.insert(new_weights.end(),
                             weights_.data() + rhs_index1 * block_size,
                             weights_.data() + (rhs_index1 + 1) * block_size);
          new_weights.insert(new_weights.end(),
                             weights_.data() + rhs_index2 * block_size,
                             weights_.data() + (rhs_index2 + 1) * block_size);
          ++rhs_index1;
          ++rhs_index2;
        }
      }
      rhs_index1 = rhs_index2;
      new_nnz[r] = new_rhs_indices.size() - start_nnz;
    }
    block_height_ *= 2;
    reduced_rows_ /= 2;
    weights_ = CacheAlignedVector<WeightType>(new_weights);
    rhs_indices_ = CacheAlignedVector<DeltaType>(new_rhs_indices);
    nnz_per_row_ = CacheAlignedVector<int>(new_nnz);
    sparsity_ = 1.0f - static_cast<float>(new_weights.size()) / (rows_ * cols_);
    ComputeColDeltas();
    if (num_threads_ > 0) {
      int num_threads = num_threads_;
      num_threads_ = 0;
      PrepareForThreads(num_threads);
    }
  }
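
  // Worked example (illustrative): merging a row with blocks at block-columns
  // {0, 2} and the next row with blocks at {1, 2} yields one double-height row
  // with blocks at {0, 1, 2}. The block at 0 is padded with zeros in its lower
  // half and the block at 1 in its upper half, so 2 + 2 single-height blocks
  // become 3 double-height blocks.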

  // Serializes the matrix into |csr_flatbuffer| and returns the number of
  // bytes written.
  // TODO(b/189958858): Both Read and Write need to eventually handle the
  // different possible HalfType and DeltaType values, but punting for now as
  // there is only one supported combination.
  std::size_t WriteToFlatBuffer(std::string* csr_flatbuffer) {
    std::size_t bytes = 0;
    bytes += FixedParameterSize();
    bytes += weights_.size() * sizeof(WeightType);
    bytes += col_deltas_.size() * sizeof(DeltaType);
    bytes += nnz_per_row_.size() * sizeof(int);

    uint8_t* bytes_ptr_ptr =
        reinterpret_cast<uint8_t*>(CHECK_NOTNULL(malloc(bytes)));

    int* int_bytes_ptr = reinterpret_cast<int*>(bytes_ptr_ptr);
    *int_bytes_ptr++ = rows_;
    *int_bytes_ptr++ = cols_;
    *int_bytes_ptr++ = reduced_rows_;
    *int_bytes_ptr++ = reduced_cols_;
    *int_bytes_ptr++ = block_width_;
    *int_bytes_ptr++ = block_height_;
    *int_bytes_ptr++ = col_multiple_;
    *int_bytes_ptr++ = num_threads_;
    *int_bytes_ptr++ = weights_.size();
    *int_bytes_ptr++ = col_deltas_.size();
    *int_bytes_ptr++ = nnz_per_row_.size();

    float* float_bytes_ptr = reinterpret_cast<float*>(int_bytes_ptr);
    *float_bytes_ptr++ = sparsity_;

    uint8_t* bytes_ptr = reinterpret_cast<uint8_t*>(float_bytes_ptr);
    memcpy(bytes_ptr, weights_.data(), weights_.size() * sizeof(WeightType));
    bytes_ptr += weights_.size() * sizeof(WeightType);

    memcpy(bytes_ptr, col_deltas_.data(),
           col_deltas_.size() * sizeof(DeltaType));
    bytes_ptr += col_deltas_.size() * sizeof(DeltaType);

    memcpy(bytes_ptr, nnz_per_row_.data(), nnz_per_row_.size() * sizeof(int));
    bytes_ptr += nnz_per_row_.size() * sizeof(int);

    csr_flatbuffer->resize(bytes);
    csr_flatbuffer->assign(reinterpret_cast<char*>(bytes_ptr_ptr), bytes);
    free(bytes_ptr_ptr);

    return bytes;
  }
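
  // Layout of the serialized buffer (fixed header, then three arrays):
  //   11 ints + 1 float of parameters (48 bytes with 4-byte int and float),
  //   weights_      (weights_.size() * sizeof(WeightType) bytes),
  //   col_deltas_   (col_deltas_.size() * sizeof(DeltaType) bytes),
  //   nnz_per_row_  (nnz_per_row_.size() * sizeof(int) bytes).
  // |rhs_indices_| and |thread_bounds_| are not serialized; the deserializing
  // constructor recomputes them via ComputeRHSIndices() and
  // PrepareForThreads().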

  void ReadFromFlatBuffer(const uint8_t* const& bytes,
                          const std::size_t& len) {
    CHECK_GE(len, FixedParameterSize());

    const int* int_bytes_ptr = reinterpret_cast<const int*>(bytes);
    rows_ = *int_bytes_ptr++;
    cols_ = *int_bytes_ptr++;
    reduced_rows_ = *int_bytes_ptr++;
    reduced_cols_ = *int_bytes_ptr++;
    block_width_ = *int_bytes_ptr++;
    block_height_ = *int_bytes_ptr++;
    col_multiple_ = *int_bytes_ptr++;
    int num_threads = *int_bytes_ptr++;
    int32_t weights_size = *int_bytes_ptr++;
    int32_t col_deltas_size = *int_bytes_ptr++;
    int32_t nnz_per_row_size = *int_bytes_ptr++;

    // Make sure negative sizes don't mess things up.
    weights_size = std::max(0, weights_size);
    col_deltas_size = std::max(0, col_deltas_size);
    nnz_per_row_size = std::max(0, nnz_per_row_size);

    const float* float_bytes_ptr =
        reinterpret_cast<const float*>(int_bytes_ptr);
    sparsity_ = *float_bytes_ptr++;

    std::size_t total_bytes =
        FixedParameterSize() + weights_size * sizeof(WeightType) +
        col_deltas_size * sizeof(DeltaType) + nnz_per_row_size * sizeof(int);
    CHECK_EQ(total_bytes, len)
        << "total bytes: " << total_bytes << ", actual len given: " << len;

    const uint8_t* bytes_ptr =
        reinterpret_cast<const uint8_t*>(float_bytes_ptr);
    std::vector<WeightType> weights_raw(weights_size);
    memcpy(weights_raw.data(), bytes_ptr, weights_size * sizeof(WeightType));
    weights_ = CacheAlignedVector<WeightType>(weights_raw);
    bytes_ptr += weights_size * sizeof(WeightType);

    std::vector<DeltaType> deltas_raw(col_deltas_size);
    memcpy(deltas_raw.data(), bytes_ptr, col_deltas_size * sizeof(DeltaType));
    col_deltas_ = CacheAlignedVector<DeltaType>(deltas_raw);
    bytes_ptr += col_deltas_size * sizeof(DeltaType);

    std::vector<int> nnz_raw(nnz_per_row_size);
    memcpy(nnz_raw.data(), bytes_ptr, nnz_per_row_size * sizeof(int));
    nnz_per_row_ = CacheAlignedVector<int>(nnz_raw);

    num_threads_ = 0;
    PrepareForThreads(num_threads);
  }

  // Multiplies a sparse matrix by a possibly dense matrix. Often the matrix is
  // a vector with a small number of columns, hence the term "fat vector".
  // 1x1 and 4x4 have specializations for output columns (i.e. fatness) >= 5,
  // and often achieve twice as many GFlops when multiplying a right hand side
  // that has 5 or more columns. (Best is a multiple of 5.)
  // 16x1 doesn't have enough registers and just loops over the width 1 kernel.
  //
  // |rhs| and |out| are COLUMN MAJOR.
  // Fast tuples of (WeightType, BiasType, RhsType, OutType) are:
  //   (float, float, float, float)
  //   (bfloat16, float, float, float)
  // and only on ARM64. All other cases use a slow generic implementation.
  template <typename RhsClass, typename BiasClass, typename OutClass,
            typename BiasType = typename BiasClass::value_type,
            typename OutType = typename OutClass::value_type>
  void SpMM_bias(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
                 bool relu = false, int tid = 0,
                 SpinBarrier* barrier = nullptr) const {
    static_assert(std::is_same<typename RhsClass::value_type, RhsType>::value,
                  "Rhs types must match");
    CHECK_LT(tid, num_threads_);
    CHECK_EQ(rhs.cols(), out->cols());
    CHECK_EQ(rhs.rows(), cols_);
    CHECK_GE(out->rows(), rows_);
    int cols_to_go = out->cols();
    int rhs_index = *thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid);
    const RhsType* rhs_ptr = rhs.data() + rhs_index * block_height_;
    OutType* out_ptr = thread_bounds_.OffsetOutput(out->data(), tid);
    const WeightType* weights_ptr =
        thread_bounds_.OffsetWeights(weights_.data(), tid);
    const DeltaType* delta_ptr =
        thread_bounds_.OffsetRhsIndices(col_deltas_.data(), tid);
    int offset = *delta_ptr / sizeof(RhsType);
    rhs_ptr -= offset;
    const int* nnz_ptr = nnz_per_row_.data() + thread_bounds_.StartRow(tid);
    int assigned_rows =
        thread_bounds_.StartRow(tid + 1) - thread_bounds_.StartRow(tid);
    const BiasType* bias_ptr = thread_bounds_.OffsetBias(bias.data(), tid);

    while (cols_to_go > 0) {
      if (block_width_ == 4 && block_height_ == 4) {
        if (cols_to_go >= 5) {
          detail::SpMM5_4x4<WeightType, RhsType, OutType>(
              weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
              assigned_rows, out->col_stride(), rhs.col_stride(), relu);
        } else {
          detail::SpMV_4x4<WeightType, RhsType, OutType>(
              weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
              assigned_rows, out->col_stride(), rhs.col_stride(), relu);
        }
      } else {
        if (cols_to_go >= 5) {
          detail::SpMM5_1x1<WeightType, RhsType, OutType>(
              weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
              assigned_rows, out->col_stride(), rhs.col_stride(), relu);
        } else {
          detail::SpMV_1x1<WeightType, RhsType, OutType>(
              weights_ptr, delta_ptr, nnz_ptr, rhs_ptr, bias_ptr, out_ptr,
              assigned_rows, out->col_stride(), rhs.col_stride(), relu);
        }
      }

      if (cols_to_go >= 5) {
        cols_to_go -= 5;
        rhs_ptr += rhs.col_stride() * 5;
        out_ptr += out->col_stride() * 5;
      } else {
        cols_to_go--;
        rhs_ptr += rhs.col_stride();
        out_ptr += out->col_stride();
      }
      if (barrier) barrier->barrier();
    }
  }
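
  // Example of single-threaded use (illustrative sketch; assumes a
  // single-column vector type such as CacheAlignedVector<float> that can be
  // constructed from a size and provides the rows()/cols()/data()/col_stride()
  // interface required above):
  //
  //   csrblocksparse::CacheAlignedVector<float> rhs(matrix.cols());
  //   csrblocksparse::CacheAlignedVector<float> bias(matrix.rows());
  //   csrblocksparse::CacheAlignedVector<float> out(matrix.rows());
  //   matrix.SpMM_bias(rhs, bias, &out, /*relu=*/true);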

  template <typename MVRhsType, typename MVBiasType, typename OutType>
  void MatVec(const MVRhsType* rhs, const MVBiasType* bias, bool relu, int tid,
              int replicas, int output_stride, OutType* output) {
    CHECK_LT(tid, num_threads_);
    CHECK_EQ(block_width_, 4) << "Block width must be 4!";
    if (block_height_ == 8) {
      matmul_.MatVec8x4(
          thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs,
          thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(),
          thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid),
          thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu,
          replicas, output_stride, thread_bounds_.OffsetOutput(output, tid));
    } else {
      CHECK_EQ(block_height_, 4) << "Block height must be 4 or 8!";
      matmul_.MatVec4x4(
          thread_bounds_.OffsetWeights(weights_.cast_data(), tid), rhs,
          thread_bounds_.OffsetBias(bias, tid), nnz_per_row_.data(),
          thread_bounds_.OffsetRhsIndices(rhs_indices_.data(), tid),
          thread_bounds_.StartRow(tid), thread_bounds_.StartRow(tid + 1), relu,
          replicas, output_stride, thread_bounds_.OffsetOutput(output, tid));
    }
  }

  int rows() const { return rows_; }
  int cols() const { return cols_; }
  int block_height() const { return block_height_; }
  int block_width() const { return block_width_; }
  float sparsity() const { return sparsity_; }
  int num_threads() const { return num_threads_; }
  const ThreadBounds& thread_bounds() const { return thread_bounds_; }
  const CacheAlignedVector<DeltaType>& rhs_indices() const {
    return rhs_indices_;
  }
  const std::string& name() const { return name_; }
  void set_name(const std::string& name) { name_ = name; }
  const std::vector<int>& split_points() const {
    return thread_bounds_.row_starts();
  }

  std::size_t bytes() const {
    return weights_.size() * sizeof(WeightType) +
           col_deltas_.size() * sizeof(DeltaType) +
           nnz_per_row_.size() * sizeof(int);
  }

  // Multiplies a sparse matrix by a possibly dense matrix, as SpMM_bias above,
  // and then samples from the output (softmax distribution) layer.
  template <typename RhsClass, typename BiasClass, typename OutClass,
            typename BiasType = typename BiasClass::value_type,
            typename OutType = typename OutClass::value_type>
  typename std::enable_if<!IsFixed32Type<OutType>::value, int>::type
  SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
                   float temperature, int tid, SpinBarrier* barrier,
                   std::minstd_rand* gen,
                   CacheAlignedVector<float>* scratch) const {
    SpMM_bias(rhs, bias, out, /*relu=*/false, tid, barrier);
    return out->Sample(temperature, gen, scratch);
  }

  // Fixed32 version.
  template <typename RhsClass, typename BiasClass, typename OutClass,
            typename BiasType = typename BiasClass::value_type,
            typename OutType = typename OutClass::value_type>
  typename std::enable_if<IsFixed32Type<OutType>::value, int>::type
  SpMM_bias_Sample(const RhsClass& rhs, const BiasClass& bias, OutClass* out,
                   float temperature, int tid, SpinBarrier* barrier,
                   std::minstd_rand* gen,
                   CacheAlignedVector<float>* scratch) const {
    // We don't pass the barrier on, as we have more work to do.
    SpMM_bias(rhs, bias, out, /*relu=*/false, tid);
    return out->ReducingSample(gen, scratch, tid, temperature, barrier);
  }

  void Print() const {
    std::cout << "Weights\n";
    weights_.Print();
    std::cout << std::endl;
    std::cout << "Deltas\n";
    col_deltas_.Print();
    std::cout << std::endl;
    std::cout << "nnz\n";
    nnz_per_row_.Print();
    std::cout << std::endl;
  }

  // Splits the computation amongst threads by rows, based on the number of
  // non-zeros plus a constant to account for the work of the bias and the
  // horizontal add at the end. It also guarantees that each thread writes only
  // whole cache lines, based on the size of OutType.
  // The |cache_line_size| arg is used only for testing. Normally it is
  // provided through the architecture #defines.
  // Each thread gets a contiguous row range (|split_points|).
  // Thread t does rows [ split_points[t], split_points[t + 1] ).
  // Each thread also needs to know how many non-zeros were before it to skip
  // (|nnz_to_skip|). And finally it also needs to know what the offset into
  // the rhs vector would have been at the split point (|rhs_to_skip|).
  //
  // Some tricky corner cases where the number of non-zeros doesn't split
  // nicely amongst the number of requested threads are not handled and default
  // to one thread; these cases are only going to happen in tests and not in
  // the matrices that occur in real models.
  //
  // Returns the maximum number of threads that can be used; <= |num_threads|.
  template <typename OutType = int32_t>
  int PrepareForThreads(int num_threads, int cache_line_size = -1) {
    CHECK_GT(num_threads, 0);
    // We have already prepared for this number of threads; nothing to do.
    if (num_threads == num_threads_) return num_threads_;
    num_threads_ = num_threads;
    thread_bounds_.PrepareForThreads(
        block_width_, block_height_, num_threads_,
        ReducedRowsPerCacheLine<OutType>(cache_line_size), reduced_rows_,
        nnz_per_row_.data());
    return num_threads_;
  }
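
  // Example of threaded use (illustrative sketch; the SpinBarrier constructor
  // shown and the thread-launching code are assumptions, not defined in this
  // file):
  //
  //   int threads = matrix.PrepareForThreads(4);
  //   csrblocksparse::SpinBarrier barrier(threads);
  //   // Launch |threads| workers; worker t computes only its own row range:
  //   //   matrix.SpMM_bias(rhs, bias, &out, /*relu=*/false, /*tid=*/t,
  //   //                    &barrier);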

  // Computes and stores the |rhs_indices_| from the |col_deltas_|.
  void ComputeRHSIndices() {
    std::vector<int> cumulative_deltas = CumulativeColDeltas();
    std::vector<DeltaType> rhs_indices(cumulative_deltas.size() +
                                       reduced_rows_);
    int total_indices = 0;
    int delta_index = 0;
    for (int r = 0; r < reduced_rows_; ++r) {
      for (int n = 0; n < nnz_per_row_[r]; ++n, ++delta_index) {
        rhs_indices[total_indices++] =
            cumulative_deltas[delta_index] / block_width_;
      }
    }
    rhs_indices_ = CacheAlignedVector<DeltaType>(rhs_indices);
  }
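
  // Worked example (illustrative): with 4x4 blocks and a float rhs, byte
  // deltas {16, 32, 64} accumulate to raw column positions {4, 12, 28};
  // dividing by block_width_ = 4 gives |rhs_indices_| = {1, 3, 7}, the
  // absolute block-column index of each stored block.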

  // Computes and stores the |col_deltas_| from the |rhs_indices_|.
  void ComputeColDeltas() {
    std::vector<int> col_deltas(rhs_indices_.size());
    int prev_index = 0;
    for (int i = 0; i < rhs_indices_.size(); ++i) {
      int offset = rhs_indices_[i] - prev_index;
      prev_index = rhs_indices_[i];
      col_deltas[i] = offset * block_width_ * sizeof(RhsType);
    }
    col_deltas_ = CacheAlignedVector<DeltaType>(col_deltas);
  }

  // Computes and returns the inclusive prefix sum of the deltas, i.e. absolute
  // positions.
  std::vector<int> CumulativeColDeltas() const {
    std::vector<int> cum_col_deltas(col_deltas_.size());
    for (int i = 0; i < col_deltas_.size(); ++i) {
      cum_col_deltas[i] = col_deltas_[i] / sizeof(RhsType);
      if (i > 0) cum_col_deltas[i] += cum_col_deltas[i - 1];
    }
    return cum_col_deltas;
  }

 private:
  constexpr std::size_t FixedParameterSize() const {
    return sizeof(int)      // rows
           + sizeof(int)    // cols
           + sizeof(int)    // reduced_rows
           + sizeof(int)    // reduced_cols
           + sizeof(int)    // block_width
           + sizeof(int)    // block_height
           + sizeof(float)  // sparsity
           + sizeof(int)    // col_multiple
           + sizeof(int)    // num_threads_
           + sizeof(int)    // weights_.size()
           + sizeof(int)    // col_deltas_.size()
           + sizeof(int);   // nnz_per_row_.size()
  }

  // Possible block sizes are only those that are supported by the computation;
  // the default is 1x1, and the other options are 4x4 and 16x1.
  template <typename InputType>
  void DetermineBlockSize(const MaskedSparseMatrix<InputType>& masked_matrix) {
    const std::vector<std::pair<int, int>> kPreferredOrder = {{4, 4}};
    int rows = masked_matrix.rows();
    int cols = masked_matrix.cols();
    for (const auto& block_size : kPreferredOrder) {
      int block_height, block_width;
      std::tie(block_height, block_width) = block_size;
      if (cols % block_width != 0) continue;

      int reduced_rows = (rows + block_height - 1) / block_height;
      int reduced_cols = cols / block_width;

      // For each possible block, confirm that it is either all 0s or all 1s.
      bool all_same = true;
      const auto& mask = masked_matrix.mask();
      for (int r = 0; r < reduced_rows; ++r) {
        for (int c = 0; c < reduced_cols; ++c) {
          int val = mask[r * block_height * cols + c * block_width];
          for (int i = 0; i < block_height; ++i) {
            for (int j = 0; j < block_width; ++j) {
              int index = (r * block_height + i) * cols + c * block_width + j;
              if (index < masked_matrix.mask().size()) {
                all_same &= (masked_matrix.mask()[index] == val);
              }
            }
          }
        }
      }

      // If this block configuration is possible, accept it.
      if (all_same) {
        block_height_ = block_height;
        block_width_ = block_width;
        return;
      }
    }

    // No large blocks were found, default to 1x1.
    block_height_ = 1;
    block_width_ = 1;
  }

  // CSR descriptors are for the reduced matrix, weights is the full matrix.
  template <typename InputType>
  void MakeColumnsMultiple(const std::vector<int>& row_offsets,
                           std::vector<int>* reduced_mask,
                           std::vector<InputType>* weights) {
    if (col_multiple_ > 0) {
      // Make sure each row has a number of columns that is a multiple of
      // |col_multiple|.
      for (int r = 1; r < row_offsets.size(); ++r) {
        int num_row = row_offsets[r] - row_offsets[r - 1];
        int num_needed = col_multiple_ - num_row % col_multiple_;
        if (num_needed < col_multiple_) {
          // Find gaps in the columns where we can insert a column of 0 weights.
          int num_added = 0;
          for (int c = 0; c < reduced_cols_; ++c) {
            if ((*reduced_mask)[(r - 1) * reduced_cols_ + c] == 0) {
              (*reduced_mask)[(r - 1) * reduced_cols_ + c] = 1;
              // Zero out the weights that correspond to this block.
              for (int i = 0; i < block_height_; ++i) {
                for (int j = 0; j < block_width_; ++j) {
                  (*weights)[((r - 1) * block_height_ + i) * cols_ +
                             block_width_ * c + j] = InputType(0.f);
                }
              }
              num_added++;
            }
            if (num_added == num_needed) break;
          }
        }
      }
    }
  }

  // Given the final dense mask and weights, convert to the compressed
  // block CSR representation.
  template <typename InputType>
  void MaskAndWeightsToCsr(const std::vector<int>& mask,
                           const std::vector<InputType>& weights,
                           std::vector<int>* nnz_per_row,
                           std::vector<int>* col_indices,
                           std::vector<WeightType>* weights_csr) {
    std::vector<int> row_offsets = {0};
    int nnz = 0;
    // Standard CSR format.
    if (block_width_ == 1 && block_height_ == 1) {
      for (int r = 0; r < rows_; ++r) {
        for (int c = 0; c < cols_; ++c) {
          if (mask[r * cols_ + c] == 1) {
            nnz++;
            col_indices->push_back(c);
            weights_csr->push_back(WeightType(weights[r * cols_ + c]));
          }
        }
        row_offsets.push_back(nnz);
      }
    } else if (block_width_ == 4 && block_height_ == 4) {
      // Weights are stored contiguously for each block in this case.
      for (int r = 0; r < reduced_rows_; ++r) {
        for (int c = 0; c < reduced_cols_; ++c) {
          if (mask[r * reduced_cols_ + c] == 1) {
            col_indices->push_back(c);
            nnz++;
            for (int i = 0; i < block_height_; ++i) {
              for (int j = 0; j < block_width_; ++j) {
                int row_index = (block_height_ * r + i) * cols_;
                int w_index = row_index + block_width_ * c + j;
                WeightType weight = w_index < weights.size()
                                        ? WeightType(weights[w_index])
                                        : WeightType(0.0f);
                weights_csr->push_back(weight);
              }
            }
          }
        }
        row_offsets.push_back(nnz);
      }
    }
    for (int i = 1; i < row_offsets.size(); ++i)
      nnz_per_row->push_back(row_offsets[i] - row_offsets[i - 1]);
  }
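
  // Worked example for the 1x1 case (illustrative): a mask row of
  // {0, 1, 1, 0, 1} contributes col_indices {1, 2, 4}, pushes the three
  // corresponding weights, and adds 3 to |nnz_per_row| for that row.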

  // Returns the number of block rows per cache line. This is the minimum unit
  // into which the calculation is broken for threads.
  template <typename OutType>
  int ReducedRowsPerCacheLine(int override_cache_line_size = -1) const {
    int line_size = kCacheLineSize;
    if (override_cache_line_size >= 1) line_size = override_cache_line_size;
    return std::max<int>(line_size / (block_height_ * sizeof(OutType)), 1);
  }
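
  // For example, with the default 64-byte cache line, 4x4 blocks and a 4-byte
  // OutType, work is handed out in units of 64 / (4 * 4) = 4 block rows,
  // i.e. 16 raw output rows per unit.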

  int col_multiple_;
  int rows_;
  int cols_;
  int reduced_rows_;
  int reduced_cols_;
  float sparsity_;
  int block_width_;
  int block_height_;
  int num_threads_;
  std::string name_;

  CacheAlignedVector<WeightType> weights_;
  CacheAlignedVector<DeltaType> col_deltas_;
  CacheAlignedVector<int> nnz_per_row_;
  // |thread_bounds_| and |rhs_indices_| don't need to be serialized as they
  // are always recalculated from serialized data.
  CacheAlignedVector<DeltaType> rhs_indices_;
  Matmul<WeightType, RhsType> matmul_;
  ThreadBounds thread_bounds_;

  static constexpr int kCacheLineSize = 64;
};

// Converts a sparse matrix represented with (|mask|, |weights|, |rows|,
// |cols|) into the CSR format, and returns that as a serialized string.
template <typename MaskType>
std::string ConvertDenseToSparseRepresentation_Int16Deltas(
    const std::vector<MaskType>& mask, const std::vector<float>& weights,
    const int rows, const int cols) {
  MaskedSparseMatrix<float> masked_weights(rows, cols, mask.data(),
                                           weights.data());
  CsrBlockSparseMatrix<csrblocksparse::bfloat16, float, int16_t>
      sparse_masked_weights(masked_weights);
  std::string buffer;
  sparse_masked_weights.WriteToFlatBuffer(&buffer);
  return buffer;
}
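
// Example (illustrative):
//
//   std::vector<int> mask = ...;       // rows * cols entries of 0/1
//   std::vector<float> weights = ...;  // rows * cols dense weights
//   std::string serialized =
//       csrblocksparse::ConvertDenseToSparseRepresentation_Int16Deltas(
//           mask, weights, rows, cols);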

}  // namespace csrblocksparse