#include "model.hh"
#include "../util/file_stream.hh"
#include "../util/file.hh"
#include "../util/file_piece.hh"
#include "../util/usage.hh"
#include "../util/thread_pool.hh"
#include <boost/range/iterator_range.hpp>
#include <boost/program_options.hpp>
#include <iostream>
#include <stdint.h>
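
// Benchmark harness for KenLM.  It runs in one of two modes: -v converts
// whitespace-separated text on stdin to fixed-width vocabulary ids, and -q
// streams those ids back through the model to time queries.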
namespace {
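// Read whitespace-separated words from fd_in, look each one up in the model's
// vocabulary, and write the ids to stdout as raw fixed-width integers.  Each
// newline is encoded as the id of </s>.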
template <class Model, class Width> void ConvertToBytes(const Model &model, int fd_in) {
  util::FilePiece in(fd_in);
  util::FileStream out(1);
  Width width;
  StringPiece word;
  const Width end_sentence = (Width)model.GetVocabulary().EndSentence();
  while (true) {
    while (in.ReadWordSameLine(word)) {
      width = (Width)model.GetVocabulary().Index(word);
      out.write(&width, sizeof(Width));
    }
    if (!in.ReadLineOrEOF(word)) break;
    out.write(&end_sentence, sizeof(Width));
  }
}
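
// Thread pool worker: scores one contiguous block of vocabulary ids.  Queries
// alternate between two State buffers so the output state of each call feeds
// the next call without copying; seeing </s> resets the context to the
// begin-sentence state.  state_[2] handles the odd-length tail.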
template <class Model, class Width> class Worker {
  public:
    explicit Worker(const Model &model, double &add_total) : model_(model), total_(0.0), add_total_(add_total) {}

    // Destructors happen in the main thread, so there's no race for add_total_.
    ~Worker() { add_total_ += total_; }

    typedef boost::iterator_range<Width *> Request;

    void operator()(Request request) {
      const lm::ngram::State *const begin_state = &model_.BeginSentenceState();
      const lm::ngram::State *next_state = begin_state;
      const Width kEOS = model_.GetVocabulary().EndSentence();
      float sum = 0.0;
      // Do even stuff first.
      const Width *even_end = request.begin() + (request.size() & ~1);
      // Alternating states
      const Width *i;
      for (i = request.begin(); i != even_end;) {
        sum += model_.FullScore(*next_state, *i, state_[1]).prob;
        next_state = (*i++ == kEOS) ? begin_state : &state_[1];
        sum += model_.FullScore(*next_state, *i, state_[0]).prob;
        next_state = (*i++ == kEOS) ? begin_state : &state_[0];
      }
      // Odd corner case.
      if (request.size() & 1) {
        sum += model_.FullScore(*next_state, *i, state_[2]).prob;
        next_state = (*i++ == kEOS) ? begin_state : &state_[2];
      }
      total_ += sum;
    }

  private:
    const Model &model_;

    double total_;
    double &add_total_;

    lm::ngram::State state_[3];
};
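
// Settings parsed from the command line, shared by both modes.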
struct Config {
  int fd_in;
  std::size_t threads;
  std::size_t buf_per_thread;
  bool query;
};
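
// Timed query mode: read fixed-width vocabulary ids from config.fd_in and fan
// them out to a pool of Worker threads.  Buffers are recycled through the pool
// so the hot loop reuses the same backing memory instead of allocating.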
template <class Model, class Width> void QueryFromBytes(const Model &model, const Config &config) {
  util::FileStream out(1);
  out << "Threads: " << config.threads << '\n';
  const Width kEOS = model.GetVocabulary().EndSentence();
  double total = 0.0;
  // Number of items to have in queue in addition to everything in flight.
  const std::size_t kInQueue = 3;
  std::size_t total_queue = config.threads + kInQueue;
  std::vector<Width> backing(config.buf_per_thread * total_queue);
  double loaded_cpu;
  double loaded_wall;
  uint64_t queries = 0;
  {
    util::RecyclingThreadPool<Worker<Model, Width> > pool(total_queue, config.threads, Worker<Model, Width>(model, total), boost::iterator_range<Width *>((Width*)0, (Width*)0));
    for (std::size_t i = 0; i < total_queue; ++i) {
      pool.PopulateRecycling(boost::iterator_range<Width *>(&backing[i * config.buf_per_thread], &backing[i * config.buf_per_thread]));
    }

    loaded_cpu = util::CPUTime();
    loaded_wall = util::WallTime();
    out << "To Load, CPU: " << loaded_cpu << " Wall: " << loaded_wall << '\n';
    boost::iterator_range<Width *> overhang((Width*)0, (Width*)0);
    while (true) {
      boost::iterator_range<Width *> buf = pool.Consume();
      std::memmove(buf.begin(), overhang.begin(), overhang.size() * sizeof(Width));
      std::size_t got = util::ReadOrEOF(config.fd_in, buf.begin() + overhang.size(), (config.buf_per_thread - overhang.size()) * sizeof(Width));
      if (!got && overhang.empty()) break;
      UTIL_THROW_IF2(got % sizeof(Width), "File size not a multiple of vocab id size " << sizeof(Width));
      Width *read_end = buf.begin() + overhang.size() + got / sizeof(Width);
      Width *last_eos;
      for (last_eos = read_end - 1; ; --last_eos) {
        UTIL_THROW_IF2(last_eos <= buf.begin(), "Encountered a sentence longer than the buffer size of " << config.buf_per_thread << " words.  Rerun with increased buffer size.  TODO: adaptable buffer");
        if (*last_eos == kEOS) break;
      }
      buf = boost::iterator_range<Width*>(buf.begin(), last_eos + 1);
      overhang = boost::iterator_range<Width*>(last_eos + 1, read_end);
      queries += buf.size();
      pool.Produce(buf);
    }
  } // Drain pool.

  double after_cpu = util::CPUTime();
  double after_wall = util::WallTime();
  util::FileStream(2, 70) << "Probability sum: " << total << '\n';
  out << "Queries: " << queries << '\n';
  out << "Excluding load, CPU: " << (after_cpu - loaded_cpu) << " Wall: " << (after_wall - loaded_wall) << '\n';
  double cpu_per_entry = ((after_cpu - loaded_cpu) / static_cast<double>(queries));
  double wall_per_entry = ((after_wall - loaded_wall) / static_cast<double>(queries));
  out << "Seconds per query excluding load, CPU: " << cpu_per_entry << " Wall: " << wall_per_entry << '\n';
  out << "Queries per second excluding load, CPU: " << (1.0/cpu_per_entry) << " Wall: " << (1.0/wall_per_entry) << '\n';
  out << "RSSMax: " << util::RSSMax() << '\n';
}
template <class Model, class Width> void DispatchFunction(const Model &model, const Config &config) {
  if (config.query) {
    QueryFromBytes<Model, Width>(model, config);
  } else {
    ConvertToBytes<Model, Width>(model, config.fd_in);
  }
}
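
// Pick the narrowest unsigned integer type that can hold every vocabulary id.
// Bound() is one past the largest id, so e.g. bound <= 256 fits in uint8_t.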
template <class Model> void DispatchWidth(const char *file, const Config &config) {
  lm::ngram::Config model_config;
  model_config.load_method = util::READ;
  Model model(file, model_config);
  uint64_t bound = model.GetVocabulary().Bound();
  if (bound <= 256) {
    DispatchFunction<Model, uint8_t>(model, config);
  } else if (bound <= 65536) {
    DispatchFunction<Model, uint16_t>(model, config);
  } else if (bound <= (1ULL << 32)) {
    DispatchFunction<Model, uint32_t>(model, config);
  } else {
    DispatchFunction<Model, uint64_t>(model, config);
  }
}
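
// Instantiate the right Model template for the data structure recorded in the
// binary file's header.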
void Dispatch(const char *file, const Config &config) {
  using namespace lm::ngram;
  lm::ngram::ModelType model_type;
  if (lm::ngram::RecognizeBinary(file, model_type)) {
    switch (model_type) {
      case PROBING:
        DispatchWidth<lm::ngram::ProbingModel>(file, config);
        break;
      case REST_PROBING:
        DispatchWidth<lm::ngram::RestProbingModel>(file, config);
        break;
      case TRIE:
        DispatchWidth<lm::ngram::TrieModel>(file, config);
        break;
      case QUANT_TRIE:
        DispatchWidth<lm::ngram::QuantTrieModel>(file, config);
        break;
      case ARRAY_TRIE:
        DispatchWidth<lm::ngram::ArrayTrieModel>(file, config);
        break;
      case QUANT_ARRAY_TRIE:
        DispatchWidth<lm::ngram::QuantArrayTrieModel>(file, config);
        break;
      default:
        UTIL_THROW(util::Exception, "Unrecognized kenlm model type " << model_type);
    }
  } else {
    UTIL_THROW(util::Exception, "Binarize before running benchmarks.");
  }
}
} // namespace
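
// Parse options, require exactly one of -v/-q, and dispatch on model type.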
int main(int argc, char *argv[]) {
  try {
    Config config;
    config.fd_in = 0;
    std::string model;
    namespace po = boost::program_options;
    po::options_description options("Benchmark options");
    options.add_options()
      ("help,h", po::bool_switch(), "Show help message")
      ("model,m", po::value<std::string>(&model)->required(), "Model to query or convert vocab ids")
      ("threads,t", po::value<std::size_t>(&config.threads)->default_value(boost::thread::hardware_concurrency()), "Threads to use (querying only; TODO vocab conversion)")
      ("buffer,b", po::value<std::size_t>(&config.buf_per_thread)->default_value(4096), "Number of words to buffer per task.")
      ("vocab,v", po::bool_switch(), "Convert strings to vocab ids")
      ("query,q", po::bool_switch(), "Query from vocab ids");
    po::variables_map vm;
    po::store(po::parse_command_line(argc, argv, options), vm);
    if (argc == 1 || vm["help"].as<bool>()) {
      std::cerr << "Benchmark program for KenLM. Intended usage:\n"
        << "#Convert text to vocabulary ids offline. These ids are tied to a model.\n"
        << argv[0] << " -v -m $model <$text >$text.vocab\n"
        << "#Ensure files are in RAM.\n"
        << "cat $text.vocab $model >/dev/null\n"
        << "#Timed query against the model.\n"
        << argv[0] << " -q -m $model <$text.vocab\n";
      return 0;
    }
    po::notify(vm);
    if (!(vm["vocab"].as<bool>() ^ vm["query"].as<bool>())) {
      std::cerr << "Specify exactly one of -v (vocab conversion) or -q (query)." << std::endl;
      return 1;
    }
    config.query = vm["query"].as<bool>();
    if (!config.threads) {
      std::cerr << "Specify a non-zero number of threads with -t." << std::endl;
      return 1;
    }
    Dispatch(model.c_str(), config);
  } catch (const std::exception &e) {
    std::cerr << e.what() << std::endl;
    return 1;
  }
  return 0;
}