|
#include "merge_vocab.hh" |
|
|
|
#include "../enumerate_vocab.hh" |
|
#include "universal_vocab.hh" |
|
#include "../lm_exception.hh" |
|
#include "../vocab.hh" |
|
#include "../../util/file_piece.hh" |
|
|
|
#include <queue> |
|
#include <string> |
|
#include <iostream> |
|
#include <vector> |
|
|
|
namespace lm { |
|
namespace interpolate { |
|
namespace { |
|
|
|
class VocabFileReader { |
|
public: |
|
explicit VocabFileReader(const int fd, size_t model_num, uint64_t offset = 0); |
|
|
|
VocabFileReader &operator++(); |
|
operator bool() const { return !eof_; } |
|
uint64_t operator*() const { return Value(); } |
|
|
|
uint64_t Value() const { return hash_value_; } |
|
size_t ModelNum() const { return model_num_; } |
|
WordIndex CurrentIndex() const { return current_index_; } |
|
|
|
StringPiece Word() const { return word_; } |
|
|
|
private: |
|
uint64_t hash_value_; |
|
WordIndex current_index_; |
|
bool eof_; |
|
size_t model_num_; |
|
StringPiece word_; |
|
util::FilePiece file_piece_; |
|
}; |
|
|
|
// Wrap the result of util::DupOrThrow(fd) in a FilePiece, require the first
// NUL-terminated entry to be "<unk>", then advance once so the reader starts
// out positioned on the entry that follows it (or flagged as at end of file
// when nothing follows).
//
// Throws FormatLoadException if the first entry is not "<unk>".
// The offset argument is accepted but not used here.
VocabFileReader::VocabFileReader(const int fd, const size_t model_num, uint64_t offset)
  : hash_value_(0),
    current_index_(0),
    eof_(false),
    model_num_(model_num),
    file_piece_(util::DupOrThrow(fd)) {
  // word_ is assigned in the body because it is read through file_piece_.
  word_ = file_piece_.ReadLine('\0');
  UTIL_THROW_IF(word_ != "<unk>", FormatLoadException, "Vocabulary words are in the wrong place.");

  // Step past <unk> to set up the initial value.
  operator++();
}
|
|
|
// Advance to the next NUL-terminated word.
//
// On end of file the reader is flagged as exhausted (operator bool turns
// false) and the stored hash and index are left untouched.  Otherwise the new
// word's hash becomes the current value and the index is incremented.
//
// Throws FormatLoadException when the new hash is smaller than the previous
// one.
VocabFileReader &VocabFileReader::operator++() {
  try {
    word_ = file_piece_.ReadLine('\0');
  } catch (const util::EndOfFileException &) {
    // Running out of input is how iteration normally finishes.  The exception
    // object is not inspected, so it is caught by const reference and unnamed.
    eof_ = true;
    return *this;
  }

  const uint64_t prev_hash_value = hash_value_;
  hash_value_ = ngram::detail::HashForVocab(word_.data(), word_.size());

  // Hashes must not decrease from one word to the next; equal hashes pass.
  UTIL_THROW_IF(hash_value_ < prev_hash_value, FormatLoadException,
                ": word index not monotonically increasing."
                << " model_num: " << model_num_
                << " prev hash: " << prev_hash_value
                << " new hash: " << hash_value_);

  ++current_index_;
  return *this;
}
|
|
|
class CompareFiles { |
|
public: |
|
bool operator()(const VocabFileReader* x, |
|
const VocabFileReader* y) |
|
{ return x->Value() > y->Value(); } |
|
}; |
|
|
|
// util::FixedArray of VocabFileReader with a push_back that constructs each
// reader directly in the array's storage.
class Readers : public util::FixedArray<VocabFileReader> {
  public:
    // number is forwarded to the FixedArray constructor.  explicit so that a
    // bare count never converts into a Readers implicitly.
    explicit Readers(std::size_t number) : util::FixedArray<VocabFileReader>(number) {}

    // Build VocabFileReader(fd, i) with placement new at end(), then call
    // Constructed().  If the reader's constructor throws, Constructed() is
    // not reached.
    // NOTE(review): presumably Constructed() is what makes the new element
    // part of the array -- confirm against util::FixedArray.
    void push_back(int fd, std::size_t i) {
      new(end()) VocabFileReader(fd, i);
      Constructed();
    }
};
|
|
|
} |
|
|
|
// Merge the vocabularies of several models into one universal numbering.
//
// files:     one file descriptor per model; files[i] is read as model i.
// vocab:     receives (model, model word index) -> universal index mappings.
// enumerate: receives every universal index together with its word, starting
//            with "<unk>" as index 0.
//
// Returns the number of universal indices handed out, <unk> included.
//
// Each reader yields hashes in non-decreasing order (VocabFileReader throws
// otherwise), so a k-way merge keyed on the hash visits entries with equal
// hashes consecutively; such a run is given a single universal index.
WordIndex MergeVocab(util::FixedArray<int> &files, UniversalVocab &vocab, EnumerateVocab &enumerate) {
  typedef std::priority_queue<VocabFileReader*, std::vector<VocabFileReader*>, CompareFiles> HeapType;
  HeapType heap;
  Readers readers(files.size());
  for (size_t i = 0; i < files.size(); ++i) {
    readers.push_back(files[i], i);
    // A vocabulary containing only <unk> is already at end of file once it
    // has been constructed.  It has no current word, so keep it off the heap.
    if (readers.back()) {
      heap.push(&readers.back());
    }

    // Index 0 of every model is <unk>, which is universal index 0.
    vocab.InsertUniversalIdx(i, 0, 0);
  }

  // Hash of the most recently numbered word.  It only means something once
  // have_prev is set, so a first word that happens to hash to the initial
  // value is still treated as a new word.
  uint64_t prev_hash_value = 0;
  bool have_prev = false;

  // global_index starts at <unk>, which is 0.
  WordIndex global_index = 0;

  enumerate.Add(0, "<unk>");
  while (!heap.empty()) {
    VocabFileReader* top_vocab_file = heap.top();
    const uint64_t top_hash = top_vocab_file->Value();

    // A hash different from the previous one starts a new universal word.
    if (!have_prev || top_hash != prev_hash_value) {
      enumerate.Add(++global_index, top_vocab_file->Word());
    }
    vocab.InsertUniversalIdx(top_vocab_file->ModelNum(),
                             top_vocab_file->CurrentIndex(),
                             global_index);

    prev_hash_value = top_hash;
    have_prev = true;

    // Advance this reader and put it back unless its file has run out.
    heap.pop();
    if (++(*top_vocab_file)) {
      heap.push(top_vocab_file);
    }
  }
  return global_index + 1;
}
|
|
|
} |
|
} |
|
|
|
|