Spaces:
Runtime error
Runtime error
// Copyright 2021 Google LLC | |
// | |
// Licensed under the Apache License, Version 2.0 (the "License"); | |
// you may not use this file except in compliance with the License. | |
// You may obtain a copy of the License at | |
// | |
// http://www.apache.org/licenses/LICENSE-2.0 | |
// | |
// Unless required by applicable law or agreed to in writing, software | |
// distributed under the License is distributed on an "AS IS" BASIS, | |
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
// See the License for the specific language governing permissions and | |
// limitations under the License. | |
namespace csrblocksparse { | |
void ThreadBounds::PrepareForThreads(int block_width, int block_height, | |
int num_threads, | |
int reduced_rows_per_cache_row, | |
int reduced_rows, const int* nnz_per_row) { | |
CHECK_GT(num_threads, 0); | |
block_width_ = block_width; | |
block_height_ = block_height; | |
ComputeThreadSplitPoints(num_threads, reduced_rows_per_cache_row, | |
reduced_rows, nnz_per_row); | |
weight_starts_.clear(); | |
rhs_indices_starts_.clear(); | |
bias_starts_.clear(); | |
weight_starts_.reserve(row_starts_.size()); | |
rhs_indices_starts_.reserve(row_starts_.size()); | |
bias_starts_.reserve(row_starts_.size()); | |
// Compute the start indices of each of the types, given what we know about | |
// padding, and number of |nnz_per_row|. | |
int weight_index = 0; | |
int rhs_indices_index = 0; | |
int bias_index = 0; | |
int row = 0; | |
for (int start : row_starts_) { | |
while (row < start) { | |
weight_index += nnz_per_row[row] * block_width_ * block_height_; | |
rhs_indices_index += nnz_per_row[row]; | |
bias_index += block_height_; | |
++row; | |
} | |
weight_starts_.push_back(weight_index); | |
rhs_indices_starts_.push_back(rhs_indices_index); | |
bias_starts_.push_back(bias_index); | |
} | |
} | |
// Computes the block row (reduced) index of the start of each thread. | |
void ThreadBounds::ComputeThreadSplitPoints(int num_threads, | |
int reduced_rows_per_cache_row, | |
int reduced_rows, | |
const int* nnz_per_row) { | |
row_starts_.assign(/*n=*/1, /*val=*/0); | |
// Break the rule if the matrix is too small to allow one per thread, which | |
// occurs only during tests. | |
if (reduced_rows_per_cache_row * num_threads > reduced_rows) | |
reduced_rows_per_cache_row = std::max(reduced_rows / num_threads, 1); | |
int cache_rows = (reduced_rows + reduced_rows_per_cache_row - 1) / | |
reduced_rows_per_cache_row; | |
// Compute exclusive prefix sum of the amount of work per row. | |
std::vector<int> work_upto_row(cache_rows + 1, 0); | |
int extra_row_work = 2 * reduced_rows_per_cache_row; | |
for (int i = 0; i < cache_rows; ++i) { | |
int new_nnz = 0; | |
for (int j = 0; j < reduced_rows_per_cache_row; ++j) { | |
// if |reduced_rows_per_cache_row| isn't an exact multiple of the | |
// matrix size, then we need to be careful here. | |
int index = i * reduced_rows_per_cache_row + j; | |
if (index < reduced_rows) new_nnz += nnz_per_row[index]; | |
} | |
work_upto_row[i + 1] = new_nnz + extra_row_work + work_upto_row[i]; | |
} | |
int total_work = work_upto_row.back(); | |
// Find the split point point based on assigned approximately equal amount | |
// of work for each thread. | |
int prev_split = 0; | |
for (int i = 1; i <= num_threads; ++i) { | |
int split = std::distance( | |
work_upto_row.begin(), | |
std::lower_bound(work_upto_row.begin(), work_upto_row.end(), | |
i * total_work / num_threads)); | |
int split_row = split * reduced_rows_per_cache_row; | |
if (i == num_threads) { | |
split_row = reduced_rows; | |
} | |
VLOG(2) << "tid=" << i - 1 << " num rows=" << split_row - row_starts_.back() | |
<< " work=" << work_upto_row[split] - work_upto_row[prev_split]; | |
row_starts_.push_back(split_row); | |
prev_split = split; | |
} | |
VLOG(2) << "total rows=" << reduced_rows << " total work=" << total_work; | |
} | |
} // namespace csrblocksparse | |