File size: 10,010 Bytes
1ce325b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
#include "merge_probabilities.hh"
#include "../common/ngram_stream.hh"
#include "bounded_sequence_encoding.hh"
#include "interpolate_info.hh"

#include <algorithm>
#include <limits>
#include <numeric>

namespace lm {
namespace interpolate {

/**
 * Builds the BoundedSequenceEncoding used to write the "from" values of
 * n-grams of the given order.
 *
 * Each component model contributes one bound: the smaller of `order` and
 * that model's own entry in info.orders.
 *
 * @param info Interpolation metadata; only info.orders is read here
 * @param order The n-gram order the encoder is being built for
 * @return an encoder constructed from the per-model bounds
 */
BoundedSequenceEncoding MakeEncoder(const InterpolateInfo &info, uint8_t order) {
  const std::size_t model_count = info.orders.size();
  util::FixedArray<uint8_t> capped(model_count);
  for (std::size_t model = 0; model < model_count; ++model) {
    const uint8_t model_order = info.orders[model];
    // Cap the bound at the order currently being written.
    capped.push_back(model_order < order ? model_order : order);
  }
  return BoundedSequenceEncoding(capped.begin(), capped.end());
}

namespace {

/**
 * A simple wrapper class that holds information needed to read and write
 * the ngrams of a particular order. This class has the memory needed to
 * buffer the data needed for the recursive process of computing the
 * probabilities and "from" values for each component model.
 *
 * "From" values indicate, for each model, what order (as an index, so -1)
 * was backed off to in order to arrive at a probability. For example, if a
 * 5-gram model (order index 4) backed off twice, we would write a 2.
 */
class NGramHandler {
public:
  NGramHandler(uint8_t order, const InterpolateInfo &ifo,
               util::FixedArray<util::stream::ChainPositions> &models_by_order)
      : info(ifo),
        encoder(MakeEncoder(info, order)),
        out_record(order, encoder.EncodedLength()) {
    std::size_t count_has_order = 0;
    for (std::size_t i = 0; i < models_by_order.size(); ++i) {
      count_has_order += (models_by_order[i].size() >= order);
    }
    inputs_.Init(count_has_order);
    for (std::size_t i = 0; i < models_by_order.size(); ++i) {
      if (models_by_order[i].size() < order)
        continue;
      inputs_.push_back(models_by_order[i][order - 1]);
      if (inputs_.back()) {
        active_.resize(active_.size() + 1);
        active_.back().model = i;
        active_.back().stream = &inputs_.back();
      }
    }

    // have to init outside since NGramStreams doesn't forward to
    // GenericStreams ctor given a ChainPositions

    probs.Init(info.Models());
    from.Init(info.Models());
    for (std::size_t i = 0; i < info.Models(); ++i) {
      probs.push_back(0.0);
      from.push_back(0);
    }
  }

  struct StreamIndex {
    NGramStream<ProbBackoff> *stream;
    NGramStream<ProbBackoff> &Stream() { return *stream; }
    std::size_t model;
  };

  std::size_t ActiveSize() const {
    return active_.size();
  }

  /**
   * @return the input stream for a particular model that corresponds to
   * this ngram order
   */
  StreamIndex &operator[](std::size_t idx) {
    return active_[idx];
  }

  void erase(std::size_t idx) {
    active_.erase(active_.begin() + idx);
  }

  const InterpolateInfo &info;
  BoundedSequenceEncoding encoder;
  PartialProbGamma out_record;
  util::FixedArray<float> probs;
  util::FixedArray<uint8_t> from;

private:
  std::vector<StreamIndex> active_;
  NGramStreams<ProbBackoff> inputs_;
};

/**
 * Fixed-capacity array of NGramHandler objects, intended to be indexed by
 * order - 1.
 */
class NGramHandlers : public util::FixedArray<NGramHandler> {
public:
  /**
   * @param num Capacity, i.e. how many handlers may be pushed
   */
  explicit NGramHandlers(std::size_t num) : util::FixedArray<NGramHandler>(num) {}

  /**
   * Constructs a new NGramHandler in place in the next free slot,
   * forwarding all three arguments to its constructor.
   */
  void push_back(
      std::size_t order, const InterpolateInfo &info,
      util::FixedArray<util::stream::ChainPositions> &models_by_order) {
    void *slot = end();
    new (slot) NGramHandler(order, info, models_by_order);
    // Tell the base array that the slot now holds a live object.
    Constructed();
  }
};

/**
 * Recursive worker that emits probability and "from" values for every
 * n-gram ending in a given suffix.
 *
 * The order handled by a call is the suffix length + 1. The suffix may be
 * empty (suffix_begin == suffix_end == NULL), in which case unigrams are
 * processed and the UNK values act as the fallback.
 *
 * Each n-gram is compared as [new word, suffix...]: the words after the
 * first must equal the suffix, and among the matching streams the one with
 * the smallest first word is emitted next.
 *
 * @param handlers The full collection of handlers, indexed by order - 1
 * @param suffix_begin Start of the suffix words
 * @param suffix_end One past the end of the suffix words
 * @param fallback_probs Per-model probabilities to use when a model does
 *  not contain the n-gram (that is, the values computed for the suffix)
 * @param fallback_from Per-model "from" values that accompany
 *  fallback_probs
 * @param combined_fallback The interpolated probability of the suffix,
 *  stored as the lower-order probability of every record emitted here
 * @param outputs The output streams, one for each order
 */
void HandleSuffix(NGramHandlers &handlers, WordIndex *suffix_begin,
                  WordIndex *suffix_end,
                  const util::FixedArray<float> &fallback_probs,
                  const util::FixedArray<uint8_t> &fallback_from,
                  float combined_fallback,
                  util::stream::Streams &outputs) {
  uint8_t order = std::distance(suffix_begin, suffix_end) + 1;
  // Past the highest order being produced: nothing to do.
  if (order > outputs.size()) return;

  util::stream::Stream &output = outputs[order - 1];
  NGramHandler &handler = handlers[order - 1];
  PartialProbGamma &record = handler.out_record;

  for (;;) {
    // Scan the active streams for the smallest n-gram carrying our suffix.
    // TODO: priority queue driven.
    WordIndex *smallest = NULL;
    for (std::size_t idx = 0; idx < handler.ActiveSize(); ++idx) {
      WordIndex *words = handler[idx].Stream()->begin();
      const bool suffix_matches = std::equal(suffix_begin, suffix_end, words + 1);
      // Streams are ranked by their first word only; the rest is the suffix.
      if (suffix_matches && (!smallest || *words < *smallest)) {
        smallest = words;
      }
    }

    // Every n-gram of this order with this suffix has been consumed.
    if (!smallest) return;

    // Write the chosen n-gram's words into the next output slot.
    record.ReBase(output.Get());
    std::copy(smallest, smallest + order, record.begin());

    // Start from the backed-off values; models that do contain this n-gram
    // overwrite their slot below.
    std::copy(fallback_probs.begin(), fallback_probs.end(), handler.probs.begin());
    std::copy(fallback_from.begin(), fallback_from.end(), handler.from.begin());

    std::size_t idx = 0;
    while (idx < handler.ActiveSize()) {
      NGramHandler::StreamIndex &entry = handler[idx];
      if (!std::equal(record.begin(), record.end(), entry.Stream()->begin())) {
        ++idx;
        continue;
      }
      handler.probs[entry.model] = handler.info.lambdas[entry.model] * entry.Stream()->Value().prob;
      handler.from[entry.model] = order - 1;
      if (++entry.Stream()) {
        ++idx;
      } else {
        // Stream exhausted: drop it. The next entry slides into idx, so the
        // position must not advance.
        handler.erase(idx);
      }
    }

    record.Prob() = std::accumulate(handler.probs.begin(), handler.probs.end(), 0.0);
    record.LowerProb() = combined_fallback;
    handler.encoder.Encode(handler.from.begin(), record.FromBegin());

    // This n-gram now serves as the suffix for the next higher order.
    HandleSuffix(handlers, record.begin(), record.end(),
                 handler.probs, handler.from, record.Prob(), outputs);
    // consume the output
    ++output;
  }
}

/**
 * Kicks off the recursion that computes the probabilities and "from"
 * values for every n-gram order.
 *
 * The UNK token is expected at the front of each unigram input stream. It
 * is consumed first and written to the unigram output, and its per-model
 * values become the fallback for the unigram pass. Each unigram then
 * serves as the fallback for the bigrams that extend it, and so on.
 *
 * Once the recursion finishes, every input stream must be exhausted;
 * otherwise an exception is raised through UTIL_THROW_IF2. Each output
 * stream is poisoned as its order passes that check.
 *
 * @param handlers One handler per order, indexed by order - 1
 * @param outputs One output stream per order, indexed by order - 1
 */
void HandleNGrams(NGramHandlers &handlers, util::stream::Streams &outputs) {
  NGramHandler &unigrams = handlers[0];
  const std::size_t models = unigrams.info.Models();

  PartialProbGamma unk_record(1, 0);
  // First: populate the unk probabilities by reading the first unigram
  // from each stream
  util::FixedArray<float> unk_probs(models);

  // start by populating the ngram id from the first stream
  lm::NGram<ProbBackoff> ngram = *unigrams[0].Stream();
  unk_record.ReBase(outputs[0].Get());
  std::copy(ngram.begin(), ngram.end(), unk_record.begin());
  unk_record.Prob() = 0;

  // then populate the probabilities into unk_probs while accumulating the
  // weighted model probabilities into the unk record
  //
  // note that from doesn't need to be set for unigrams
  assert(unigrams.ActiveSize() == models);

  // The model index and the position in the active list are tracked
  // separately. They start out equal (every model is active), but erasing
  // an exhausted stream shifts later entries down by one. Reusing a single
  // counter for both would then pair a model with the wrong lambda and
  // eventually index past the end of the active list.
  std::size_t active = 0;
  for (std::size_t model = 0; model < models; ++model) {
    NGramHandler::StreamIndex &entry = unigrams[active];
    ngram = *entry.Stream();
    unk_probs.push_back(unigrams.info.lambdas[model] * ngram.Value().prob);
    unk_record.Prob() += unk_probs[model];
    assert(*ngram.begin() == kUNK);
    if (++entry.Stream()) {
      ++active;
    } else {
      // This model's unigram stream held only UNK; the next model's entry
      // now sits at the same position.
      unigrams.erase(active);
    }
  }
  float unk_combined = unk_record.Prob();
  unk_record.LowerProb() = unk_combined;
  // flush the unk output record
  ++outputs[0];

  // Then, begin outputting everything in lexicographic order: first we'll
  // get the unigram then the first bigram with that context, then the
  // first trigram with that bigram context, etc., until we exhaust all of
  // the ngrams, then all of the (n-1)grams, etc.
  //
  // This function is the "root" of this recursive process.
  util::FixedArray<uint8_t> unk_from(models);
  for (std::size_t i = 0; i < models; ++i) {
    unk_from.push_back(0);
  }

  // the two nulls are to encode that our "fallback" word is the "0-gram"
  // case, e.g. we "backed off" to UNK
  // TODO: stop generating vocab ids and LowerProb for unigrams.
  HandleSuffix(handlers, NULL, NULL, unk_probs, unk_from, unk_combined, outputs);

  // Verify we reached the end.  And poison!
  for (std::size_t i = 0; i < handlers.size(); ++i) {
    UTIL_THROW_IF2(handlers[i].ActiveSize(),
                     "MergeProbabilities did not exhaust all ngram streams");
    outputs[i].Poison();
  }
}
} // namespace

/**
 * Entry point of the merge: builds one NGramHandler per output order
 * (orders 1 through output_pos.size()), opens the output streams and runs
 * the merge over them.
 *
 * @param output_pos Chain positions of the output streams, one per order
 */
void MergeProbabilities::Run(const util::stream::ChainPositions &output_pos) {
  const std::size_t max_order = output_pos.size();

  NGramHandlers handlers(max_order);
  for (std::size_t order = 1; order <= max_order; ++order) {
    handlers.push_back(order, info_, models_by_order_);
  }

  util::stream::Streams outputs(output_pos);
  HandleNGrams(handlers, outputs);
}

}} // namespaces