/*
 * Copyright 2021 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef LYRA_CODEC_SPARSE_MATMUL_OS_COOP_THREADS_H_
#define LYRA_CODEC_SPARSE_MATMUL_OS_COOP_THREADS_H_

#include <atomic>
#include <thread>  // NOLINT
#include <vector>

#define _COOP_THREADS_USE_STD_THREAD 1

#include "absl/memory/memory.h"
#include "glog/logging.h"

namespace csrblocksparse {

// A re-usable barrier. Keeps threads in extremely tight sync without
// relinquishing control.  All memory writes _before_ this barrier are visible
// to all threads _after_ this barrier.  Similar in spirit to
// pthread_barrier.  If you expect arrival times at this barrier to vary
// by more than microseconds, this is probably not the right synchronization
// primitive for you.  If |num_threads| exceeds the number of physical threads
// that can run simultaneously, then using this is certainly a bad idea
// (although it should still be correct).
//
// Callers MUST NOT call barrier() from more than |num_threads| threads; doing
// so results in undefined behavior.
class SpinBarrier {
 public:
  explicit SpinBarrier(int num_threads)
      : num_threads_(num_threads), threads_at_barrier_(0), barrier_step_(0) {}

  void barrier();

 private:
  const int num_threads_;
  std::atomic<int32_t> threads_at_barrier_;
  std::atomic<uint32_t> barrier_step_;  // unsigned to make overflow defined.
};
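
// Illustrative sketch (not part of this header): keeping worker threads in
// lock step across phases of a shared computation. PhaseOne/PhaseTwo are
// hypothetical placeholders for the caller's work.
//
//   void Worker(SpinBarrier* barrier, int tid) {
//     PhaseOne(tid);        // Writes made here...
//     barrier->barrier();   // ...are visible to every thread after this line.
//     PhaseTwo(tid);
//   }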

// Producer-consumer API using the same underlying mechanism as SpinBarrier.
// This class is intended to allow >=1 producers to produce data for >=1
// consumers, without blocking the producers.
// The consumer will block if it is ready before all the producer(s) have
// produced.
// WARNING: By design this lock does not work on its own: some other barrier
// must prevent any producer from producing again, and any consumer from
// consuming again, until all consumers have consumed. In practice, any loop
// that uses ProducerConsumer must contain at least two consume() calls per
// thread (on different instances) for the lock to work correctly.
class ProducerConsumer {
 public:
  ProducerConsumer(int num_producers, int num_consumers)
      : num_producers_(num_producers),
        num_consumers_(num_consumers),
        producers_ready_(0),
        consumers_passed_(0) {}

  // Indicates that the data produced by this thread is ready. Does NOT block.
  // NOTE that between this call to produce() and looping back to call
  // produce() again on the same ProducerConsumer, some other lock must exist
  // that depends on all consumers having called consume(). One such candidate
  // would be a call to SpinBarrier above by all producers and consumers.
  // Another candidate would be a separate ProducerConsumer object in which
  // these producers consume some data produced by the threads that consume
  // the data produced here. E.g.:
  // tid      0       1       2       3
  // action 1 produce produce consume consume (on ProducerConsumer 1)
  // action 2 consume consume produce produce (on ProducerConsumer 2)
  // action 3 produce produce consume consume (on ProducerConsumer 3)
  // action 4 consume consume produce produce (on ProducerConsumer 4)
  // loop back to action 1.
  // NOTE: It is inadequate to loop back after action 2, as thread 0 could loop
  // back and consume again on PC2 while thread 1 is still completing its call
  // to consume. It is likewise inadequate to loop back after action 3, for the
  // same reason (though TSan doesn't seem to pick this up).
  inline void produce() {
    producers_ready_.fetch_add(1, std::memory_order_acq_rel);
  }

  // Waits if necessary for all producers to have produced before proceeding.
  // The ProducerConsumer cannot be reused until all consumers have consumed.
  // See detailed comment and example on produce().
  inline void consume() {
    // We can't do anything until all the producers have produced.
    while (producers_ready_.load(std::memory_order_acquire) < num_producers_) {
#if defined __aarch64__ || defined __arm__
      asm volatile("yield\n" ::: "memory");
#else
      // No pause for x86! The pause instruction on Skylake takes 141 clock
      // cycles, which on an AVX2-down-clocked CPU approaches 70 ns.
#endif
    }
    // NOTE: It is tempting to move this fetch_add to before the wait loop to
    // reduce contention for the memory location, but that would break the lock,
    // as then the last to arrive could zero out the producers_ready before the
    // other consumers have noticed that all producers have produced.
    // With the fetch_add after the wait loop, we are guaranteed that all
    // producers have produced AND all consumers have noticed that they have
    // produced before we zero out the counters.
    int consumers = consumers_passed_.fetch_add(1, std::memory_order_acq_rel);
    if (consumers == num_consumers_ - 1) {
      // The last consumer to pass has to reset everything for the next time.
      producers_ready_.store(0, std::memory_order_relaxed);
      consumers_passed_.store(0, std::memory_order_relaxed);
    }
  }
  int num_producers() const { return num_producers_; }
  int num_consumers() const { return num_consumers_; }

 private:
  const int num_producers_;
  const int num_consumers_;
  std::atomic<int32_t> producers_ready_;
  std::atomic<int32_t> consumers_passed_;
};
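
// Illustrative sketch (not part of this header): the four-instance pattern
// from the comment on produce(), written out for four threads. pc1..pc4 are
// hypothetical instances, each constructed with 2 producers and 2 consumers.
// Data is written before each produce() and read after each consume().
//
//   ProducerConsumer pc1(2, 2), pc2(2, 2), pc3(2, 2), pc4(2, 2);
//   // Threads 0 and 1, per loop iteration:
//   pc1.produce();  pc2.consume();  pc3.produce();  pc4.consume();
//   // Threads 2 and 3, per loop iteration:
//   pc1.consume();  pc2.produce();  pc3.consume();  pc4.produce();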

// We define Thread here, so we can easily change its type later.

using Thread = std::thread;
using ThreadId = std::thread::id;

// Creates (|num_threads|-1) threads and executes a total of |num_threads|
// copies of |func| (executes one on the calling thread).
//
// Useful for long-running func bodies that are intended to run in lock step.
// A possible use case for this style of parallelism over a thread pool is when
// we want tight control over which memory is resident in the L2 cache of a
// processor.  With a pool we have no control over which thread gets assigned
// which portion of the computation, resulting in L2 thrashing.  With this
// breakdown we can make sure each thread only accesses a specific L2-sized
// portion of memory.
//
// func's signature must be (SpinBarrier*, int thread_id, ...);
template <class Function, class... Args>
void LaunchOnThreadsWithBarrier(int num_threads, Function&& func,
                                Args&&... args) {
  SpinBarrier spin_barrier(num_threads);

  std::vector<std::unique_ptr<Thread>> threads;
  threads.reserve(num_threads);
  for (int tid = 1; tid < num_threads; ++tid) {
    auto f = [&, tid]() { func(&spin_barrier, tid, args...); };

    threads.emplace_back(absl::make_unique<Thread>(f));
#ifndef _COOP_THREADS_USE_STD_THREAD
    CHECK_OK(threads.back()->Start());
#endif
  }

  const int kLocalTid = 0;
  func(&spin_barrier, kLocalTid, args...);

  for (auto& thread : threads) {
#ifdef _COOP_THREADS_USE_STD_THREAD
    thread->join();
#else
    CHECK_OK(thread->Join());
#endif
  }
}
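
// Illustrative sketch (not part of this header): running a hypothetical
// DoWork function on 4 threads; trailing arguments are forwarded to every
// copy of the function.
//
//   void DoWork(SpinBarrier* barrier, int tid, int num_chunks) {
//     // Process chunk |tid| of |num_chunks|, calling barrier->barrier()
//     // between phases that must stay in lock step.
//   }
//
//   LaunchOnThreadsWithBarrier(/*num_threads=*/4, DoWork, /*num_chunks=*/4);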

}  // namespace csrblocksparse

#endif  // LYRA_CODEC_SPARSE_MATMUL_OS_COOP_THREADS_H_