|
#ifndef LM_MODEL_H |
|
#define LM_MODEL_H |
|
|
|
#include "bhiksha.hh" |
|
#include "binary_format.hh" |
|
#include "config.hh" |
|
#include "facade.hh" |
|
#include "quantize.hh" |
|
#include "search_hashed.hh" |
|
#include "search_trie.hh" |
|
#include "state.hh" |
|
#include "value.hh" |
|
#include "vocab.hh" |
|
#include "weights.hh" |
|
|
|
#include "../util/murmur_hash.hh" |
|
|
|
#include <algorithm> |
|
#include <vector> |
|
#include <cstring> |
|
|
|
namespace util { class FilePiece; } |
|
|
|
namespace lm { |
|
namespace ngram { |
|
namespace detail { |
|
|
|
|
|
|
|
// Language model assembled from a pluggable n-gram lookup structure (`Search`)
// and a word-to-index vocabulary (`VocabularyT`).  It derives from
// base::ModelFacade and passes itself as the facade's first template argument
// (CRTP); the named combinations are declared further down via LM_NAME_MODEL.
//
// NOTE(review): the vocabulary parameter is spelled VocabularyT, presumably so
// it cannot collide with a `Vocabulary` name inside base::ModelFacade —
// confirm in facade.hh.
template <class Search, class VocabularyT> class GenericModel : public base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> {
  private:
    // Shorthand for the facade base class.
    typedef base::ModelFacade<GenericModel<Search, VocabularyT>, State, VocabularyT> P;

  public:
    // Model type tag for this Search/VocabularyT instantiation.  Declared
    // without an initializer, so its definition lives out of line.
    static const ModelType kModelType;

    // Version number forwarded from the search structure (Search::kVersion).
    // NOTE(review): presumably the binary-file format version — confirm.
    static const unsigned int kVersion = Search::kVersion;

    // Memory needed for a model with the given n-gram `counts` under `config`.
    // NOTE(review): presumably a byte count, with counts[i] being the number
    // of (i+1)-grams — confirm against the definition in the .cc.
    static uint64_t Size(const std::vector<uint64_t> &counts, const Config &config = Config());

    // Loads the model stored at path `file`, configured by `config`.
    // explicit: a bare path must never be converted into a model implicitly.
    explicit GenericModel(const char *file, const Config &config = Config());

    // Scores `new_word` following the context held in `in_state` and writes
    // the state reached after appending `new_word` into `out_state`.
    FullScoreReturn FullScore(const State &in_state, const WordIndex new_word, State &out_state) const;

    // Same as FullScore, but the context is supplied as the half-open word
    // range [context_rbegin, context_rend) instead of a State.
    // NOTE(review): the rbegin/rend naming suggests the most recent word comes
    // first in the range — confirm before building it.
    FullScoreReturn FullScoreForgotState(const WordIndex *context_rbegin, const WordIndex *context_rend, const WordIndex new_word, State &out_state) const;

    // Writes into `out_state` the state for the context range
    // [context_rbegin, context_rend) (same ordering as FullScoreForgotState).
    void GetState(const WordIndex *context_rbegin, const WordIndex *context_rend, State &out_state) const;

    // Left-extension query: adds the words [add_rbegin, add_rend) on the left
    // of the n-gram identified by `extend_pointer` / `extend_length`.
    // `backoff_in` supplies incoming backoff values; `backoff_out` and
    // `next_use` receive outputs.
    // NOTE(review): whether the returned scores are absolute or relative to an
    // earlier query is decided in the .cc — confirm before combining them.
    FullScoreReturn ExtendLeft(
        const WordIndex *add_rbegin, const WordIndex *add_rend,
        const float *backoff_in,
        uint64_t extend_pointer,
        unsigned char extend_length,
        float *backoff_out,
        unsigned char &next_use) const;

    // Returns InternalUnRest(pointers_begin, pointers_end, first_length) when
    // Search::kDifferentRest is true and 0 otherwise, so search types with
    // kDifferentRest == false never call InternalUnRest.
    // The `0.0f` literal keeps the conditional float-typed; a double `0.0`
    // operand made the whole expression double and narrowed it on return.
    float UnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const {
      return Search::kDifferentRest ? InternalUnRest(pointers_begin, pointers_end, first_length) : 0.0f;
    }

  private:
    // Scores `new_word` against the context range [context_rbegin,
    // context_rend); per its name, the backoff contribution is excluded.
    FullScoreReturn ScoreExceptBackoff(const WordIndex *const context_rbegin, const WordIndex *const context_rend, const WordIndex new_word, State &out_state) const;

    // Helper that resumes a lookup from `node`; the parameter name implies a
    // starting order of starting_order_minus_2 + 2.  Fills `backoff_out`,
    // `next_use` and `ret`.
    void ResumeScore(const WordIndex *context_rbegin, const WordIndex *const context_rend, unsigned char starting_order_minus_2, typename Search::Node &node, float *backoff_out, unsigned char &next_use, FullScoreReturn &ret) const;

    // NOTE(review): presumably lays out vocab_/search_ over the memory at
    // `start`, sized from `counts` and `config` — confirm in the .cc.
    void SetupMemory(void *start, const std::vector<uint64_t> &counts, const Config &config);

    // Builds the model from ARPA input read through descriptor `fd`.
    // NOTE(review): `file` is presumably the path, kept for messages.
    void InitializeFromARPA(int fd, const char *file, const Config &config);

    // Rest-cost conversion behind UnRest; only reached when
    // Search::kDifferentRest is true.
    float InternalUnRest(const uint64_t *pointers_begin, const uint64_t *pointers_end, unsigned char first_length) const;

    // Non-static members are constructed in declaration order, so this order
    // (backing_, then vocab_, then search_) must be respected by any
    // constructor initializer list in the .cc.
    BinaryFormat backing_;

    VocabularyT vocab_;

    Search search_;
};
|
|
|
} |
|
|
|
|
|
|
|
// Expands to a single comma.  This lets a template argument list that contains
// commas be passed as ONE argument to a function-like macro such as
// LM_NAME_MODEL below (a literal comma would split the macro argument).
#define LM_COMMA() ,

// LM_NAME_MODEL(name, from) defines `class name` deriving publicly from
// `from`, with a constructor that forwards (file, config) to `from`.  The
// expansion already ends in `};`.
//
// The constructor is explicit to match GenericModel's explicit constructor:
// without it, every named model would accept an implicit `const char *` ->
// model conversion that silently runs the base constructor on that path.
#define LM_NAME_MODEL(name, from)\
class name : public from {\
  public:\
    explicit name(const char *file, const Config &config = Config()) : from(file, config) {}\
};
|
|
|
// Hashed ("probing") models: n-grams live in detail::HashedSearch and words are
// mapped through ProbingVocabulary.  The two differ only in the value type
// stored per n-gram (BackoffValue vs RestValue).
LM_NAME_MODEL(ProbingModel,
              detail::GenericModel<detail::HashedSearch<BackoffValue> LM_COMMA() ProbingVocabulary>);
LM_NAME_MODEL(RestProbingModel,
              detail::GenericModel<detail::HashedSearch<RestValue> LM_COMMA() ProbingVocabulary>);

// Trie models over a SortedVocabulary.  trie::TrieSearch takes two policies:
// the first is DontQuantize or SeparatelyQuantize (quantize.hh), the second is
// trie::DontBhiksha or trie::ArrayBhiksha (bhiksha.hh).  LM_COMMA() keeps each
// template argument list a single macro argument.
LM_NAME_MODEL(TrieModel,
              detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
LM_NAME_MODEL(ArrayTrieModel,
              detail::GenericModel<trie::TrieSearch<DontQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);
LM_NAME_MODEL(QuantTrieModel,
              detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::DontBhiksha> LM_COMMA() SortedVocabulary>);
LM_NAME_MODEL(QuantArrayTrieModel,
              detail::GenericModel<trie::TrieSearch<SeparatelyQuantize LM_COMMA() trie::ArrayBhiksha> LM_COMMA() SortedVocabulary>);
|
|
|
|
|
// Default choices for callers that do not need a specific data structure:
// the hashed model (GenericModel over HashedSearch<BackoffValue>) and the
// vocabulary type it is built on.
typedef ::lm::ngram::ProbingModel Model;
typedef ::lm::ngram::ProbingVocabulary Vocabulary;
|
|
|
|
|
|
|
|
|
// Loads the model at `file_name` (configured by `config`) and returns it
// through the base::Model interface.
// NOTE(review): judging by its name, `if_arpa` (default PROBING) selects the
// ModelType to build when the input is ARPA text; the returned raw pointer is
// presumably owned by the caller — confirm both in model.cc.
base::Model *LoadVirtual(const char *file_name,
                         const Config &config = Config(),
                         ModelType if_arpa = PROBING);
|
|
|
} |
|
} |
|
|
|
#endif |
|
|