File size: 3,586 Bytes
1ce325b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 |
#include "merge_vocab.hh"
#include "../enumerate_vocab.hh"
#include "universal_vocab.hh"
#include "../lm_exception.hh"
#include "../vocab.hh"
#include "../../util/file_piece.hh"
#include <queue>
#include <string>
#include <iostream>
#include <vector>
namespace lm {
namespace interpolate {
namespace {
// Sequential reader over one model's vocabulary file.
//
// The file holds NUL-delimited words: "<unk>" first (checked by the
// constructor), then the remaining words in non-decreasing order of
// ngram::detail::HashForVocab (checked by operator++). At any time the reader
// exposes the current word, its hash, its position within the file, and the
// number of the model the file belongs to.
class VocabFileReader {
  public:
    // Reads from a duplicate of `fd` (via util::DupOrThrow) and positions the
    // reader on the first word after "<unk>".
    // NOTE(review): `offset` is not used by the constructor definition.
    explicit VocabFileReader(const int fd, size_t model_num, uint64_t offset = 0);

    // Moves to the next word, or into the exhausted state at end of file.
    VocabFileReader &operator++();

    // Hash of the current word; operator* is shorthand for Value().
    uint64_t Value() const { return hash_value_; }
    uint64_t operator*() const { return hash_value_; }

    // Text of the current word.
    // NOTE(review): presumably refers into file_piece_'s buffer, so it may be
    // invalidated by the next operator++ -- confirm before storing it.
    StringPiece Word() const { return word_; }

    // Zero-based position of the current word within this file; "<unk>" is 0.
    WordIndex CurrentIndex() const { return current_index_; }

    // The model number given to the constructor.
    size_t ModelNum() const { return model_num_; }

    // True until end of file has been reached.
    operator bool() const { return !eof_; }

  private:
    uint64_t hash_value_;         // hash of word_ (0 while positioned on "<unk>")
    WordIndex current_index_;     // position of word_ within the file
    bool eof_;                    // set once ReadLine reports end of file
    size_t model_num_;            // which model this file belongs to
    StringPiece word_;            // current word
    util::FilePiece file_piece_;  // reader over the duplicated descriptor
};
// Opens a private duplicate of `fd`, requires the first NUL-delimited word to be
// "<unk>", then advances once so the reader sits on the first real word (or is
// already exhausted when "<unk>" is the only entry).
// NOTE(review): `offset` is accepted but currently ignored.
VocabFileReader::VocabFileReader(const int fd, const size_t model_num, uint64_t offset)
  : hash_value_(0),
    current_index_(0),
    eof_(false),
    model_num_(model_num),
    file_piece_(util::DupOrThrow(fd)) {
  // "<unk>" must come first; it keeps local index 0 and hash value 0.
  word_ = file_piece_.ReadLine('\0');
  UTIL_THROW_IF(word_ != "<unk>", FormatLoadException,
                "Vocabulary words are in the wrong place.");
  // Step past "<unk>" so Value()/Word()/CurrentIndex() describe the first real word.
  operator++();
}
// Advances to the next NUL-delimited word.
//
// At end of file the reader is marked exhausted (operator bool becomes false)
// and the previous hash/index are left untouched. Otherwise the new word is
// hashed, the hash is checked to be no smaller than the previous one (the
// k-way merge in MergeVocab requires every file to be sorted by hash), and the
// per-file word index is incremented.
//
// Throws FormatLoadException if the file is not sorted by hash.
VocabFileReader &VocabFileReader::operator++() {
  // Once exhausted, stay exhausted instead of asking FilePiece to read past EOF again.
  if (eof_) return *this;
  try {
    word_ = file_piece_.ReadLine('\0');
  } catch (const util::EndOfFileException &) {
    eof_ = true;
    return *this;
  }
  uint64_t prev_hash_value = hash_value_;
  hash_value_ = ngram::detail::HashForVocab(word_.data(), word_.size());
  // Equal hashes are tolerated; only a decrease means the file is out of order.
  UTIL_THROW_IF(hash_value_ < prev_hash_value, FormatLoadException,
      "Vocabulary words are not sorted by hash value."
      << " model_num: " << model_num_
      << " prev hash: " << prev_hash_value
      << " new hash: " << hash_value_);
  ++current_index_;
  return *this;
}
// Ordering for the merge heap. std::priority_queue keeps the "largest" element
// on top, so comparing with '>' on the current hash turns it into a min-heap:
// the reader whose current word has the smallest hash is always on top.
class CompareFiles {
  public:
    // Const-callable, as expected of a Compare function object.
    bool operator()(const VocabFileReader* x,
                    const VocabFileReader* y) const
    { return x->Value() > y->Value(); }
};
// Fixed-capacity storage for one VocabFileReader per input file, constructed in
// place.
// NOTE(review): MergeVocab keeps pointers to these elements in its heap, which
// relies on util::FixedArray never relocating its storage -- confirm.
class Readers : public util::FixedArray<VocabFileReader> {
  public:
    // `number` is the capacity (number of vocabulary files). Explicit so a bare
    // integer never converts into a Readers implicitly.
    explicit Readers(std::size_t number) : util::FixedArray<VocabFileReader>(number) {}

    // Builds a reader for model `i` over descriptor `fd` in the next free slot,
    // then records the new element with the base array.
    void push_back(int fd, std::size_t i) {
      new(end()) VocabFileReader(fd, i);
      Constructed();
    }
};
} // namespace
// Builds a universal vocabulary from per-model vocabulary files.
//
// Each descriptor in `files` must hold "<unk>" followed by the remaining words
// sorted by ngram::detail::HashForVocab (VocabFileReader enforces both). The
// files are merged with a min-heap keyed on hash. Words with equal hashes,
// across models, are treated as the same word and get the same global index.
//
// For every (model, local index) pair, `vocab` records the global index.
// `enumerate` receives each distinct word once, in increasing global-index
// order, starting with "<unk>" at 0.
//
// Returns the number of global indices assigned (including "<unk>").
WordIndex MergeVocab(util::FixedArray<int> &files, UniversalVocab &vocab, EnumerateVocab &enumerate) {
  typedef std::priority_queue<VocabFileReader*, std::vector<VocabFileReader*>, CompareFiles> HeapType;
  HeapType heap;
  Readers readers(files.size());
  for (size_t i = 0; i < files.size(); ++i) {
    readers.push_back(files[i], i);
    // "<unk>" is local index 0 in every model and maps to global index 0.
    vocab.InsertUniversalIdx(i, 0, 0);
    // A file holding only "<unk>" is already exhausted. Keeping it out of the
    // heap avoids re-processing its stale "<unk>" state and reading past EOF.
    if (readers.back()) {
      heap.push(&readers.back());
    }
  }
  // "<unk>" occupies global index 0 with an implied hash of 0, matching the
  // readers' initial hash value.
  // NOTE(review): a real word hashing to exactly 0 would be merged into "<unk>";
  // presumably HashForVocab never yields 0 for real words -- confirm.
  uint64_t prev_hash_value = 0;
  WordIndex global_index = 0;
  enumerate.Add(0, "<unk>");
  while (!heap.empty()) {
    VocabFileReader *top_vocab_file = heap.top();
    // A new hash starts a new global word. An equal hash continues the word
    // just emitted, because the min-heap yields equal hashes consecutively.
    if (top_vocab_file->Value() != prev_hash_value) {
      enumerate.Add(++global_index, top_vocab_file->Word());
    }
    vocab.InsertUniversalIdx(top_vocab_file->ModelNum(),
                             top_vocab_file->CurrentIndex(),
                             global_index);
    prev_hash_value = top_vocab_file->Value();
    heap.pop();
    // Re-insert the reader only while it still has words.
    if (++(*top_vocab_file)) {
      heap.push(top_vocab_file);
    }
  }
  return global_index + 1;
}
} // namespace interpolate
} // namespace lm
|