File size: 4,234 Bytes
d1a84ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "sparse_matmul/compute/thread_bounds.h"

#include <algorithm>
#include <cstdint>
#include <iterator>
#include <vector>

#include "glog/logging.h"

namespace csrblocksparse {

// Records the block dimensions, computes the per-thread row split points, and
// then derives, for every split point, the matching start offsets into the
// weights, the rhs indices and the bias.
//
// |num_threads| must be positive (CHECK-enforced). |nnz_per_row| is read for
// every reduced row below the largest split point in |row_starts_|.
void ThreadBounds::PrepareForThreads(int block_width, int block_height,
                                     int num_threads,
                                     int reduced_rows_per_cache_row,
                                     int reduced_rows, const int* nnz_per_row) {
  CHECK_GT(num_threads, 0);
  block_width_ = block_width;
  block_height_ = block_height;
  ComputeThreadSplitPoints(num_threads, reduced_rows_per_cache_row,
                           reduced_rows, nnz_per_row);

  // Reset each offset table and size it to hold one entry per split point.
  weight_starts_.clear();
  weight_starts_.reserve(row_starts_.size());
  rhs_indices_starts_.clear();
  rhs_indices_starts_.reserve(row_starts_.size());
  bias_starts_.clear();
  bias_starts_.reserve(row_starts_.size());

  // Walk the rows once, keeping running offsets. Each row contributes
  // nnz * block_width * block_height weights, nnz rhs indices and
  // block_height bias entries. The offsets are snapshotted whenever the walk
  // reaches a split point.
  int weight_offset = 0;
  int rhs_indices_offset = 0;
  int bias_offset = 0;
  int next_row = 0;
  for (int split_row : row_starts_) {
    for (; next_row < split_row; ++next_row) {
      weight_offset += nnz_per_row[next_row] * block_width_ * block_height_;
      rhs_indices_offset += nnz_per_row[next_row];
      bias_offset += block_height_;
    }
    weight_starts_.push_back(weight_offset);
    rhs_indices_starts_.push_back(rhs_indices_offset);
    bias_starts_.push_back(bias_offset);
  }
}

// Computes the block row (reduced) index of the start of each thread.
//
// Reduced rows are grouped into "cache rows" of |reduced_rows_per_cache_row|
// rows, and threads are only split on cache-row boundaries. The work of a
// cache row is modelled as the sum of |nnz_per_row| over its rows plus a fixed
// overhead of 2 * |reduced_rows_per_cache_row|. Split points are chosen so
// that each thread gets approximately total_work / |num_threads|.
//
// On return |row_starts_| holds |num_threads| + 1 entries: a leading 0, one
// entry per thread boundary, and a final |reduced_rows|. Every entry is at
// most |reduced_rows|, so callers may index |nnz_per_row| with any row below a
// split point. |nnz_per_row| is read for indices in [0, reduced_rows).
void ThreadBounds::ComputeThreadSplitPoints(int num_threads,
                                            int reduced_rows_per_cache_row,
                                            int reduced_rows,
                                            const int* nnz_per_row) {
  row_starts_.assign(/*n=*/1, /*val=*/0);
  // Break the rule if the matrix is too small to allow one per thread, which
  // occurs only during tests.
  if (reduced_rows_per_cache_row * num_threads > reduced_rows)
    reduced_rows_per_cache_row = std::max(reduced_rows / num_threads, 1);
  const int cache_rows = (reduced_rows + reduced_rows_per_cache_row - 1) /
                         reduced_rows_per_cache_row;

  // Compute exclusive prefix sum of the amount of work per cache row.
  std::vector<int> work_upto_row(cache_rows + 1, 0);
  const int extra_row_work = 2 * reduced_rows_per_cache_row;
  for (int i = 0; i < cache_rows; ++i) {
    // If |reduced_rows_per_cache_row| isn't an exact divisor of the matrix
    // size, the last cache row is partial, so stop at |reduced_rows|.
    const int first_row = i * reduced_rows_per_cache_row;
    const int end_row =
        std::min(first_row + reduced_rows_per_cache_row, reduced_rows);
    int new_nnz = 0;
    for (int row = first_row; row < end_row; ++row) {
      new_nnz += nnz_per_row[row];
    }
    work_upto_row[i + 1] = new_nnz + extra_row_work + work_upto_row[i];
  }
  const int total_work = work_upto_row.back();
  // Find the split points based on assigning an approximately equal amount
  // of work to each thread.
  int prev_split = 0;
  for (int i = 1; i <= num_threads; ++i) {
    // Use a 64-bit intermediate: i * total_work can overflow a 32-bit int for
    // large matrices and thread counts. The quotient is <= total_work, so it
    // always fits back into an int.
    const int target_work = static_cast<int>(
        static_cast<std::int64_t>(i) * total_work / num_threads);
    const int split = static_cast<int>(std::distance(
        work_upto_row.begin(),
        std::lower_bound(work_upto_row.begin(), work_upto_row.end(),
                         target_work)));
    // The last thread always ends at |reduced_rows|. Earlier threads are
    // clamped to it as well: when the final cache row is partial,
    // split * reduced_rows_per_cache_row can exceed the real row count, which
    // would make callers read past the end of |nnz_per_row|.
    const int split_row =
        (i == num_threads)
            ? reduced_rows
            : std::min(split * reduced_rows_per_cache_row, reduced_rows);

    VLOG(2) << "tid=" << i - 1 << " num rows=" << split_row - row_starts_.back()
            << " work=" << work_upto_row[split] - work_upto_row[prev_split];
    row_starts_.push_back(split_row);
    prev_split = split;
  }
  VLOG(2) << "total rows=" << reduced_rows << " total work=" << total_work;
}

}  // namespace csrblocksparse