|
#include "backoff_reunification.hh" |
|
#include "../common/ngram_stream.hh" |
|
|
|
#define BOOST_TEST_MODULE InterpolateBackoffReunificationTest |
|
#include <boost/test/unit_test.hpp> |
|
|
|
namespace lm { |
|
namespace interpolate { |
|
|
|
namespace { |
|
|
|
|
|
|
|
template <uint8_t N> |
|
struct Gram { |
|
WordIndex ids[N]; |
|
float prob; |
|
float boff; |
|
}; |
|
|
|
template <uint8_t N> |
|
struct Grams { |
|
const static Gram<N> grams[]; |
|
}; |
|
|
|
template <> |
|
const Gram<1> Grams<1>::grams[] |
|
= {{{0}, -0.1f, -0.1f}, {{1}, -0.4f, -0.2f}, {{2}, -0.5f, -0.1f}}; |
|
|
|
template <> |
|
const Gram<2> Grams<2>::grams[] = {{{0, 0}, -0.05f, -0.05f}, |
|
{{1, 0}, -0.05f, -0.02f}, |
|
{{1, 1}, -0.2f, -0.04f}, |
|
{{2, 2}, -0.2f, -0.01f}}; |
|
|
|
template <> |
|
const Gram<3> Grams<3>::grams[] = {{{0, 0, 0}, -0.001f, -0.005f}, |
|
{{1, 0, 0}, -0.001f, -0.002f}, |
|
{{2, 0, 0}, -0.001f, -0.003f}, |
|
{{0, 1, 0}, -0.1f, -0.008f}, |
|
{{1, 1, 0}, -0.1f, -0.09f}, |
|
{{1, 1, 1}, -0.2f, -0.08f}}; |
|
|
|
template <uint8_t N> |
|
class WriteInput { |
|
public: |
|
void Run(const util::stream::ChainPosition &position) { |
|
lm::NGramStream<float> output(position); |
|
|
|
for (std::size_t i = 0; i < sizeof(Grams<N>::grams) / sizeof(Gram<N>); |
|
++i, ++output) { |
|
std::copy(Grams<N>::grams[i].ids, Grams<N>::grams[i].ids + N, |
|
output->begin()); |
|
output->Value() = Grams<N>::grams[i].prob; |
|
} |
|
output.Poison(); |
|
} |
|
}; |
|
|
|
template <uint8_t N> |
|
class WriteBackoffs { |
|
public: |
|
void Run(const util::stream::ChainPosition &position) { |
|
util::stream::Stream output(position); |
|
|
|
for (std::size_t i = 0; i < sizeof(Grams<N>::grams) / sizeof(Gram<N>); |
|
++i, ++output) { |
|
*reinterpret_cast<float *>(output.Get()) = Grams<N>::grams[i].boff; |
|
} |
|
output.Poison(); |
|
} |
|
}; |
|
|
|
template <uint8_t N> |
|
class CheckOutput { |
|
public: |
|
void Run(const util::stream::ChainPosition &position) { |
|
lm::NGramStream<ProbBackoff> stream(position); |
|
|
|
std::size_t i = 0; |
|
for (; stream; ++stream, ++i) { |
|
std::stringstream ss; |
|
for (WordIndex *idx = stream->begin(); idx != stream->end(); ++idx) |
|
ss << "(" << *idx << ")"; |
|
|
|
BOOST_CHECK(std::equal(stream->begin(), stream->end(), Grams<N>::grams[i].ids)); |
|
|
|
|
|
BOOST_CHECK_EQUAL(stream->Value().prob, Grams<N>::grams[i].prob); |
|
|
|
|
|
|
|
|
|
BOOST_CHECK_EQUAL(stream->Value().backoff, Grams<N>::grams[i].boff); |
|
|
|
|
|
|
|
} |
|
BOOST_CHECK_EQUAL(i , sizeof(Grams<N>::grams) / sizeof(Gram<N>)); |
|
|
|
|
|
|
|
|
|
} |
|
}; |
|
} |
|
|
|
BOOST_AUTO_TEST_CASE(BackoffReunificationTest) { |
|
util::stream::ChainConfig config; |
|
config.total_memory = 100; |
|
config.block_count = 1; |
|
|
|
util::stream::Chains prob_chains(3); |
|
config.entry_size = NGram<float>::TotalSize(1); |
|
prob_chains.push_back(config); |
|
prob_chains.back() >> WriteInput<1>(); |
|
|
|
config.entry_size = NGram<float>::TotalSize(2); |
|
prob_chains.push_back(config); |
|
prob_chains.back() >> WriteInput<2>(); |
|
|
|
config.entry_size = NGram<float>::TotalSize(3); |
|
prob_chains.push_back(config); |
|
prob_chains.back() >> WriteInput<3>(); |
|
|
|
util::stream::Chains boff_chains(3); |
|
config.entry_size = sizeof(float); |
|
boff_chains.push_back(config); |
|
boff_chains.back() >> WriteBackoffs<1>(); |
|
|
|
boff_chains.push_back(config); |
|
boff_chains.back() >> WriteBackoffs<2>(); |
|
|
|
boff_chains.push_back(config); |
|
boff_chains.back() >> WriteBackoffs<3>(); |
|
|
|
util::stream::ChainPositions prob_pos(prob_chains); |
|
util::stream::ChainPositions boff_pos(boff_chains); |
|
|
|
util::stream::Chains output_chains(3); |
|
for (std::size_t i = 0; i < 3; ++i) { |
|
config.entry_size = NGram<ProbBackoff>::TotalSize(i + 1); |
|
output_chains.push_back(config); |
|
} |
|
|
|
ReunifyBackoff(prob_pos, boff_pos, output_chains); |
|
|
|
output_chains[0] >> CheckOutput<1>(); |
|
output_chains[1] >> CheckOutput<2>(); |
|
output_chains[2] >> CheckOutput<3>(); |
|
|
|
prob_chains >> util::stream::kRecycle; |
|
boff_chains >> util::stream::kRecycle; |
|
|
|
output_chains.Wait(); |
|
} |
|
} |
|
} |
|
|