File size: 4,123 Bytes
1ce325b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
#ifndef LM_FILTER_VOCAB_H
#define LM_FILTER_VOCAB_H
// Vocabulary-based filters for language models.
#include "../../util/multi_intersection.hh"
#include "../../util/string_piece.hh"
#include "../../util/string_piece_hash.hh"
#include "../../util/tokenize_piece.hh"
#include <boost/noncopyable.hpp>
#include <boost/range/iterator_range.hpp>
#include <boost/unordered/unordered_map.hpp>
#include <boost/unordered/unordered_set.hpp>
#include <cassert>
#include <cstddef>
#include <iosfwd>
#include <string>
#include <vector>
namespace lm {
namespace vocab {
// Declared here; the definition is not part of this header.
// NOTE(review): judging by the name and signature, this presumably reads the
// words of one vocabulary from `in` and inserts them into `out` -- confirm
// against the definition.
void ReadSingle(std::istream &in, boost::unordered_set<std::string> &out);
// Read one sentence vocabulary per line. Return the number of sentences.
// Declared here; the definition is not part of this header.
// NOTE(review): `out` presumably maps each word to the numbers of the
// sentences (lines) it appeared in -- confirm against the definition.
unsigned int ReadMultiple(std::istream &in, boost::unordered_map<std::string, std::vector<unsigned int> > &out);
/* Reports whether a token looks like a special tag such as <s> or <UNK>.
 * Anything that begins with '<' and ends with '>' qualifies.  Most tokenizers
 * split those characters away from real words, and only a single token is
 * inspected, so real words should not be matched.
 */
inline bool IsTag(const StringPiece &value) {
  // Tokens handed over by the parser are expected to be non-empty.
  assert(!value.empty());
  const char *const first = value.data();
  const char *const last = first + (value.size() - 1);
  // The last character is only examined when the first one already matched.
  return *first == '<' && *last == '>';
}
// Filter backed by one vocabulary: an n-gram passes when every token that is
// not a tag (see IsTag) is present in the vocabulary.
class Single {
  public:
    typedef boost::unordered_set<std::string> Words;

    // Keeps a reference to vocab; the caller must keep it alive for as long
    // as this filter is used.
    explicit Single(const Words &vocab) : vocab_(vocab) {}

    // Returns true when all non-tag tokens in [begin, end) are in the
    // vocabulary.  An n-gram with no non-tag tokens passes.
    template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
      Iterator token = begin;
      while (token != end) {
        const bool known = IsTag(*token) || FindStringPiece(vocab_, *token) != vocab_.end();
        if (!known) {
          return false;
        }
        ++token;
      }
      return true;
    }

  private:
    const Words &vocab_;
};
class Union {
public:
typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
explicit Union(const Words &vocabs) : vocabs_(vocabs) {}
template <class Iterator> bool PassNGram(const Iterator &begin, const Iterator &end) {
sets_.clear();
for (Iterator i(begin); i != end; ++i) {
if (IsTag(*i)) continue;
Words::const_iterator found(FindStringPiece(vocabs_, *i));
if (vocabs_.end() == found) return false;
sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
}
return (sets_.empty() || util::FirstIntersection(sets_));
}
private:
const Words &vocabs_;
std::vector<boost::iterator_range<const unsigned int*> > sets_;
};
class Multiple {
public:
typedef boost::unordered_map<std::string, std::vector<unsigned int> > Words;
Multiple(const Words &vocabs) : vocabs_(vocabs) {}
private:
// Callback from AllIntersection that does AddNGram.
template <class Output> class Callback {
public:
Callback(Output &out, const StringPiece &line) : out_(out), line_(line) {}
void operator()(unsigned int index) {
out_.SingleAddNGram(index, line_);
}
private:
Output &out_;
const StringPiece &line_;
};
public:
template <class Iterator, class Output> void AddNGram(const Iterator &begin, const Iterator &end, const StringPiece &line, Output &output) {
sets_.clear();
for (Iterator i(begin); i != end; ++i) {
if (IsTag(*i)) continue;
Words::const_iterator found(FindStringPiece(vocabs_, *i));
if (vocabs_.end() == found) return;
sets_.push_back(boost::iterator_range<const unsigned int*>(&*found->second.begin(), &*found->second.end()));
}
if (sets_.empty()) {
output.AddNGram(line);
return;
}
Callback<Output> cb(output, line);
util::AllIntersection(sets_, cb);
}
template <class Output> void AddNGram(const StringPiece &ngram, const StringPiece &line, Output &output) {
AddNGram(util::TokenIter<util::SingleCharacter, true>(ngram, ' '), util::TokenIter<util::SingleCharacter, true>::end(), line, output);
}
void Flush() const {}
private:
const Words &vocabs_;
std::vector<boost::iterator_range<const unsigned int*> > sets_;
};
} // namespace vocab
} // namespace lm
#endif // LM_FILTER_VOCAB_H
|