|
#ifndef LM_QUANTIZE_H |
|
#define LM_QUANTIZE_H |
|
|
|
#include "lm/blank.hh" |
|
#include "lm/config.hh" |
|
#include "lm/max_order.hh" |
|
#include "lm/model_type.hh" |
|
#include "util/bit_packing.hh" |
|
|
|
#include <algorithm> |
|
#include <vector> |
|
|
|
#include <stdint.h> |
|
|
|
#include <iostream> |
|
|
|
namespace lm { |
|
namespace ngram { |
|
|
|
struct Config; |
|
class BinaryFormat; |
|
|
|
|
|
class DontQuantize { |
|
public: |
|
static const ModelType kModelTypeAdd = static_cast<ModelType>(0); |
|
static void UpdateConfigFromBinary(const BinaryFormat &, uint64_t, Config &) {} |
|
static uint64_t Size(uint8_t , const Config &) { return 0; } |
|
static uint8_t MiddleBits(const Config &) { return 63; } |
|
static uint8_t LongestBits(const Config &) { return 31; } |
|
|
|
class MiddlePointer { |
|
public: |
|
MiddlePointer(const DontQuantize & , unsigned char , util::BitAddress address) : address_(address) {} |
|
|
|
MiddlePointer() : address_(NULL, 0) {} |
|
|
|
bool Found() const { |
|
return address_.base != NULL; |
|
} |
|
|
|
float Prob() const { |
|
return util::ReadNonPositiveFloat31(address_.base, address_.offset); |
|
} |
|
|
|
float Backoff() const { |
|
return util::ReadFloat32(address_.base, address_.offset + 31); |
|
} |
|
|
|
float Rest() const { return Prob(); } |
|
|
|
void Write(float prob, float backoff) { |
|
util::WriteNonPositiveFloat31(address_.base, address_.offset, prob); |
|
util::WriteFloat32(address_.base, address_.offset + 31, backoff); |
|
} |
|
|
|
private: |
|
util::BitAddress address_; |
|
}; |
|
|
|
class LongestPointer { |
|
public: |
|
explicit LongestPointer(const DontQuantize &, util::BitAddress address) : address_(address) {} |
|
|
|
LongestPointer() : address_(NULL, 0) {} |
|
|
|
bool Found() const { |
|
return address_.base != NULL; |
|
} |
|
|
|
float Prob() const { |
|
return util::ReadNonPositiveFloat31(address_.base, address_.offset); |
|
} |
|
|
|
void Write(float prob) { |
|
util::WriteNonPositiveFloat31(address_.base, address_.offset, prob); |
|
} |
|
|
|
private: |
|
util::BitAddress address_; |
|
}; |
|
|
|
DontQuantize() {} |
|
|
|
void SetupMemory(void * , unsigned char , const Config & ) {} |
|
|
|
static const bool kTrain = false; |
|
|
|
void Train(uint8_t , std::vector<float> &, std::vector<float> &) {} |
|
void TrainProb(uint8_t, std::vector<float> &) {} |
|
|
|
void FinishedLoading(const Config &) {} |
|
}; |
|
|
|
class SeparatelyQuantize { |
|
private: |
|
class Bins { |
|
public: |
|
|
|
Bins() {} |
|
|
|
Bins(uint8_t bits, float *begin) : begin_(begin), end_(begin_ + (1ULL << bits)), bits_(bits), mask_((1ULL << bits) - 1) {} |
|
|
|
float *Populate() { return begin_; } |
|
|
|
uint64_t EncodeProb(float value) const { |
|
return Encode(value, 0); |
|
} |
|
|
|
uint64_t EncodeBackoff(float value) const { |
|
if (value == 0.0) { |
|
return HasExtension(value) ? kExtensionQuant : kNoExtensionQuant; |
|
} |
|
return Encode(value, 2); |
|
} |
|
|
|
float Decode(std::size_t off) const { return begin_[off]; } |
|
|
|
uint8_t Bits() const { return bits_; } |
|
|
|
uint64_t Mask() const { return mask_; } |
|
|
|
private: |
|
uint64_t Encode(float value, size_t reserved) const { |
|
const float *above = std::lower_bound(static_cast<const float*>(begin_) + reserved, end_, value); |
|
if (above == begin_ + reserved) return reserved; |
|
if (above == end_) return end_ - begin_ - 1; |
|
return above - begin_ - (value - *(above - 1) < *above - value); |
|
} |
|
|
|
float *begin_; |
|
const float *end_; |
|
uint8_t bits_; |
|
uint64_t mask_; |
|
}; |
|
|
|
public: |
|
static const ModelType kModelTypeAdd = kQuantAdd; |
|
|
|
static void UpdateConfigFromBinary(const BinaryFormat &file, uint64_t offset, Config &config); |
|
|
|
static uint64_t Size(uint8_t order, const Config &config) { |
|
uint64_t longest_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.prob_bits)) * sizeof(float); |
|
uint64_t middle_table = (static_cast<uint64_t>(1) << static_cast<uint64_t>(config.backoff_bits)) * sizeof(float) + longest_table; |
|
|
|
return (order - 2) * middle_table + longest_table + 8; |
|
} |
|
|
|
static uint8_t MiddleBits(const Config &config) { return config.prob_bits + config.backoff_bits; } |
|
static uint8_t LongestBits(const Config &config) { return config.prob_bits; } |
|
|
|
class MiddlePointer { |
|
public: |
|
MiddlePointer(const SeparatelyQuantize &quant, unsigned char order_minus_2, const util::BitAddress &address) : bins_(quant.GetTables(order_minus_2)), address_(address) {} |
|
|
|
MiddlePointer() : address_(NULL, 0) {} |
|
|
|
bool Found() const { return address_.base != NULL; } |
|
|
|
float Prob() const { |
|
return ProbBins().Decode(util::ReadInt25(address_.base, address_.offset + BackoffBins().Bits(), ProbBins().Bits(), ProbBins().Mask())); |
|
} |
|
|
|
float Backoff() const { |
|
return BackoffBins().Decode(util::ReadInt25(address_.base, address_.offset, BackoffBins().Bits(), BackoffBins().Mask())); |
|
} |
|
|
|
float Rest() const { return Prob(); } |
|
|
|
void Write(float prob, float backoff) const { |
|
uint64_t prob_encoded = ProbBins().EncodeProb(prob); |
|
uint64_t backoff_encoded = BackoffBins().EncodeBackoff(backoff); |
|
#if BYTE_ORDER == LITTLE_ENDIAN |
|
prob_encoded <<= BackoffBins().Bits(); |
|
#elif BYTE_ORDER == BIG_ENDIAN |
|
backoff_encoded <<= ProbBins().Bits(); |
|
#endif |
|
util::WriteInt57(address_.base, address_.offset, ProbBins().Bits() + BackoffBins().Bits(), |
|
prob_encoded | backoff_encoded); |
|
} |
|
|
|
private: |
|
const Bins &ProbBins() const { return bins_[0]; } |
|
const Bins &BackoffBins() const { return bins_[1]; } |
|
const Bins *bins_; |
|
|
|
util::BitAddress address_; |
|
}; |
|
|
|
class LongestPointer { |
|
public: |
|
LongestPointer(const SeparatelyQuantize &quant, const util::BitAddress &address) : table_(&quant.LongestTable()), address_(address) {} |
|
|
|
LongestPointer() : address_(NULL, 0) {} |
|
|
|
bool Found() const { return address_.base != NULL; } |
|
|
|
void Write(float prob) const { |
|
util::WriteInt25(address_.base, address_.offset, table_->Bits(), table_->EncodeProb(prob)); |
|
} |
|
|
|
float Prob() const { |
|
return table_->Decode(util::ReadInt25(address_.base, address_.offset, table_->Bits(), table_->Mask())); |
|
} |
|
|
|
private: |
|
const Bins *table_; |
|
util::BitAddress address_; |
|
}; |
|
|
|
SeparatelyQuantize() {} |
|
|
|
void SetupMemory(void *start, unsigned char order, const Config &config); |
|
|
|
static const bool kTrain = true; |
|
|
|
void Train(uint8_t order, std::vector<float> &prob, std::vector<float> &backoff); |
|
|
|
void TrainProb(uint8_t order, std::vector<float> &prob); |
|
|
|
void FinishedLoading(const Config &config); |
|
|
|
const Bins *GetTables(unsigned char order_minus_2) const { return tables_[order_minus_2]; } |
|
|
|
const Bins &LongestTable() const { return longest_; } |
|
|
|
private: |
|
Bins tables_[KENLM_MAX_ORDER - 1][2]; |
|
|
|
Bins longest_; |
|
|
|
uint8_t *actual_base_; |
|
|
|
uint8_t prob_bits_, backoff_bits_; |
|
}; |
|
|
|
} |
|
} |
|
|
|
#endif |
|
|