|
#include "lm/model.hh" |
|
#include "util/file_stream.hh" |
|
#include "util/file.hh" |
|
#include "util/file_piece.hh" |
|
#include "util/usage.hh" |
|
#include "util/thread_pool.hh" |
|
|
|
#include <boost/range/iterator_range.hpp> |
|
#include <boost/program_options.hpp> |
|
|
|
#include <iostream> |
|
|
|
#include <stdint.h> |
|
|
|
namespace { |
|
|
|
template <class Model, class Width> void ConvertToBytes(const Model &model, int fd_in) { |
|
util::FilePiece in(fd_in); |
|
util::FileStream out(1); |
|
Width width; |
|
StringPiece word; |
|
const Width end_sentence = (Width)model.GetVocabulary().EndSentence(); |
|
while (true) { |
|
while (in.ReadWordSameLine(word)) { |
|
width = (Width)model.GetVocabulary().Index(word); |
|
out.write(&width, sizeof(Width)); |
|
} |
|
if (!in.ReadLineOrEOF(word)) break; |
|
out.write(&end_sentence, sizeof(Width)); |
|
} |
|
} |
|
|
|
template <class Model, class Width> class Worker { |
|
public: |
|
explicit Worker(const Model &model, double &add_total) : model_(model), total_(0.0), add_total_(add_total) {} |
|
|
|
|
|
~Worker() { add_total_ += total_; } |
|
|
|
typedef boost::iterator_range<Width *> Request; |
|
|
|
void operator()(Request request) { |
|
const lm::ngram::State *const begin_state = &model_.BeginSentenceState(); |
|
const lm::ngram::State *next_state = begin_state; |
|
const Width kEOS = model_.GetVocabulary().EndSentence(); |
|
float sum = 0.0; |
|
|
|
const Width *even_end = request.begin() + (request.size() & ~1); |
|
|
|
const Width *i; |
|
for (i = request.begin(); i != even_end;) { |
|
sum += model_.FullScore(*next_state, *i, state_[1]).prob; |
|
next_state = (*i++ == kEOS) ? begin_state : &state_[1]; |
|
sum += model_.FullScore(*next_state, *i, state_[0]).prob; |
|
next_state = (*i++ == kEOS) ? begin_state : &state_[0]; |
|
} |
|
|
|
if (request.size() & 1) { |
|
sum += model_.FullScore(*next_state, *i, state_[2]).prob; |
|
next_state = (*i++ == kEOS) ? begin_state : &state_[2]; |
|
} |
|
total_ += sum; |
|
} |
|
|
|
private: |
|
const Model &model_; |
|
double total_; |
|
double &add_total_; |
|
|
|
lm::ngram::State state_[3]; |
|
}; |
|
|
|
struct Config { |
|
int fd_in; |
|
std::size_t threads; |
|
std::size_t buf_per_thread; |
|
bool query; |
|
}; |
|
|
|
template <class Model, class Width> void QueryFromBytes(const Model &model, const Config &config) { |
|
util::FileStream out(1); |
|
out << "Threads: " << config.threads << '\n'; |
|
const Width kEOS = model.GetVocabulary().EndSentence(); |
|
double total = 0.0; |
|
|
|
const std::size_t kInQueue = 3; |
|
std::size_t total_queue = config.threads + kInQueue; |
|
std::vector<Width> backing(config.buf_per_thread * total_queue); |
|
double loaded_cpu; |
|
double loaded_wall; |
|
uint64_t queries = 0; |
|
{ |
|
util::RecyclingThreadPool<Worker<Model, Width> > pool(total_queue, config.threads, Worker<Model, Width>(model, total), boost::iterator_range<Width *>((Width*)0, (Width*)0)); |
|
|
|
for (std::size_t i = 0; i < total_queue; ++i) { |
|
pool.PopulateRecycling(boost::iterator_range<Width *>(&backing[i * config.buf_per_thread], &backing[i * config.buf_per_thread])); |
|
} |
|
|
|
loaded_cpu = util::CPUTime(); |
|
loaded_wall = util::WallTime(); |
|
out << "To Load, CPU: " << loaded_cpu << " Wall: " << loaded_wall << '\n'; |
|
boost::iterator_range<Width *> overhang((Width*)0, (Width*)0); |
|
while (true) { |
|
boost::iterator_range<Width *> buf = pool.Consume(); |
|
std::memmove(buf.begin(), overhang.begin(), overhang.size() * sizeof(Width)); |
|
std::size_t got = util::ReadOrEOF(config.fd_in, buf.begin() + overhang.size(), (config.buf_per_thread - overhang.size()) * sizeof(Width)); |
|
if (!got && overhang.empty()) break; |
|
UTIL_THROW_IF2(got % sizeof(Width), "File size not a multiple of vocab id size " << sizeof(Width)); |
|
Width *read_end = buf.begin() + overhang.size() + got / sizeof(Width); |
|
Width *last_eos; |
|
for (last_eos = read_end - 1; ; --last_eos) { |
|
UTIL_THROW_IF2(last_eos <= buf.begin(), "Encountered a sentence longer than the buffer size of " << config.buf_per_thread << " words. Rerun with increased buffer size. TODO: adaptable buffer"); |
|
if (*last_eos == kEOS) break; |
|
} |
|
buf = boost::iterator_range<Width*>(buf.begin(), last_eos + 1); |
|
overhang = boost::iterator_range<Width*>(last_eos + 1, read_end); |
|
queries += buf.size(); |
|
pool.Produce(buf); |
|
} |
|
} |
|
|
|
double after_cpu = util::CPUTime(); |
|
double after_wall = util::WallTime(); |
|
util::FileStream(2, 70) << "Probability sum: " << total << '\n'; |
|
out << "Queries: " << queries << '\n'; |
|
out << "Excluding load, CPU: " << (after_cpu - loaded_cpu) << " Wall: " << (after_wall - loaded_wall) << '\n'; |
|
double cpu_per_entry = ((after_cpu - loaded_cpu) / static_cast<double>(queries)); |
|
double wall_per_entry = ((after_wall - loaded_wall) / static_cast<double>(queries)); |
|
out << "Seconds per query excluding load, CPU: " << cpu_per_entry << " Wall: " << wall_per_entry << '\n'; |
|
out << "Queries per second excluding load, CPU: " << (1.0/cpu_per_entry) << " Wall: " << (1.0/wall_per_entry) << '\n'; |
|
out << "RSSMax: " << util::RSSMax() << '\n'; |
|
} |
|
|
|
template <class Model, class Width> void DispatchFunction(const Model &model, const Config &config) { |
|
if (config.query) { |
|
QueryFromBytes<Model, Width>(model, config); |
|
} else { |
|
ConvertToBytes<Model, Width>(model, config.fd_in); |
|
} |
|
} |
|
|
|
template <class Model> void DispatchWidth(const char *file, const Config &config) { |
|
lm::ngram::Config model_config; |
|
model_config.load_method = util::READ; |
|
Model model(file, model_config); |
|
uint64_t bound = model.GetVocabulary().Bound(); |
|
if (bound <= 256) { |
|
DispatchFunction<Model, uint8_t>(model, config); |
|
} else if (bound <= 65536) { |
|
DispatchFunction<Model, uint16_t>(model, config); |
|
} else if (bound <= (1ULL << 32)) { |
|
DispatchFunction<Model, uint32_t>(model, config); |
|
} else { |
|
DispatchFunction<Model, uint64_t>(model, config); |
|
} |
|
} |
|
|
|
void Dispatch(const char *file, const Config &config) { |
|
using namespace lm::ngram; |
|
lm::ngram::ModelType model_type; |
|
if (lm::ngram::RecognizeBinary(file, model_type)) { |
|
switch(model_type) { |
|
case PROBING: |
|
DispatchWidth<lm::ngram::ProbingModel>(file, config); |
|
break; |
|
case REST_PROBING: |
|
DispatchWidth<lm::ngram::RestProbingModel>(file, config); |
|
break; |
|
case TRIE: |
|
DispatchWidth<lm::ngram::TrieModel>(file, config); |
|
break; |
|
case QUANT_TRIE: |
|
DispatchWidth<lm::ngram::QuantTrieModel>(file, config); |
|
break; |
|
case ARRAY_TRIE: |
|
DispatchWidth<lm::ngram::ArrayTrieModel>(file, config); |
|
break; |
|
case QUANT_ARRAY_TRIE: |
|
DispatchWidth<lm::ngram::QuantArrayTrieModel>(file, config); |
|
break; |
|
default: |
|
UTIL_THROW(util::Exception, "Unrecognized kenlm model type " << model_type); |
|
} |
|
} else { |
|
UTIL_THROW(util::Exception, "Binarize before running benchmarks."); |
|
} |
|
} |
|
|
|
} |
|
|
|
int main(int argc, char *argv[]) { |
|
try { |
|
Config config; |
|
config.fd_in = 0; |
|
std::string model; |
|
namespace po = boost::program_options; |
|
po::options_description options("Benchmark options"); |
|
options.add_options() |
|
("help,h", po::bool_switch(), "Show help message") |
|
("model,m", po::value<std::string>(&model)->required(), "Model to query or convert vocab ids") |
|
("threads,t", po::value<std::size_t>(&config.threads)->default_value(boost::thread::hardware_concurrency()), "Threads to use (querying only; TODO vocab conversion)") |
|
("buffer,b", po::value<std::size_t>(&config.buf_per_thread)->default_value(4096), "Number of words to buffer per task.") |
|
("vocab,v", po::bool_switch(), "Convert strings to vocab ids") |
|
("query,q", po::bool_switch(), "Query from vocab ids"); |
|
po::variables_map vm; |
|
po::store(po::parse_command_line(argc, argv, options), vm); |
|
if (argc == 1 || vm["help"].as<bool>()) { |
|
std::cerr << "Benchmark program for KenLM. Intended usage:\n" |
|
<< "#Convert text to vocabulary ids offline. These ids are tied to a model.\n" |
|
<< argv[0] << " -v -m $model <$text >$text.vocab\n" |
|
<< "#Ensure files are in RAM.\n" |
|
<< "cat $text.vocab $model >/dev/null\n" |
|
<< "#Timed query against the model.\n" |
|
<< argv[0] << " -q -m $model <$text.vocab\n"; |
|
return 0; |
|
} |
|
po::notify(vm); |
|
if (!(vm["vocab"].as<bool>() ^ vm["query"].as<bool>())) { |
|
std::cerr << "Specify exactly one of -v (vocab conversion) or -q (query)." << std::endl; |
|
return 0; |
|
} |
|
config.query = vm["query"].as<bool>(); |
|
if (!config.threads) { |
|
std::cerr << "Specify a non-zero number of threads with -t." << std::endl; |
|
} |
|
Dispatch(model.c_str(), config); |
|
} catch (const std::exception &e) { |
|
std::cerr << e.what() << std::endl; |
|
return 1; |
|
} |
|
return 0; |
|
} |
|
|