File size: 3,346 Bytes
dda1539
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import pickle
import sys

import torch

class NGram:
  """Count-based n-gram model over token IDs.

  ``corpus`` maps a prior (the preceding token IDs) to a dict of
  ``{next_token_id: count}``. Judging from ``prob``, bigram priors appear to be
  keyed by a bare token ID while longer priors are keyed by tuples — TODO
  confirm against the code that builds the pickled corpora.
  """

  # Expected prior length for each supported model type. Types not listed
  # here skip the length check.
  _PRIOR_LEN = {
      "bigram": 1,
      "trigram": 2,
      "fourgram": 3,
      "fivegram": 4,
      "sixgram": 5,
      "sevengram": 6,
  }

  def __init__(self, corpus, corpus_counts, type):
    """
    Args:
      corpus (dict): prior -> {next_token_id: count}.
      corpus_counts: stored as ``self.counts``; not used by this class.
      type (str): model type, e.g. "bigram" ... "sevengram".
    """
    self.corpus = corpus
    self.counts = corpus_counts
    self.type = type

  def _successors(self, key):
    """Return the {next_token: count} dict stored for ``key``, or None.

    The key is tried as given first. For bigram models a 1-tuple prior is
    then also tried as its single element, so ``ntd`` accepts the same key
    form that ``prob`` does.
    """
    if key in self.corpus:
      return self.corpus[key]
    if (self.type == "bigram" and isinstance(key, tuple) and len(key) == 1
        and key[0] in self.corpus):
      return self.corpus[key[0]]
    return None

  def prob(self, key, next):
    """Return P(next | key) estimated from raw counts.

    Args:
      key (tuple): tuple of token IDs forming the prior; its length must
        match the model type (1 for bigram ... 6 for sevengram).
      next (int): ID of the candidate next token.
    Returns:
      float: count(key, next) / sum of counts after ``key`` (0.0 if ``next``
      never followed ``key``), or -1 if ``key`` is not in the corpus or has
      no counted continuations.
    Raises:
      ValueError: if ``len(key)`` does not match the model type.
    """
    expected = self._PRIOR_LEN.get(self.type)
    if expected is not None and len(key) != expected:
      raise ValueError(
          f"{self.type} model expects a prior of length {expected}, "
          f"got {len(key)}")
    if self.type == "bigram":
      # Bigram corpora are indexed by the single prior token, not a 1-tuple.
      key = key[0]

    successors = self._successors(key)
    if not successors:
      return -1
    total = sum(successors.values())
    if total == 0:
      # No counted continuations: treat like an unseen prior instead of
      # dividing by zero.
      return -1
    return successors.get(next, 0) / total

  def ntd(self, key, vocab_size=32000):
    """Return the full next-token probability vector for ``key``.

    Args:
      key: prior in the form the corpus is keyed by (tuple of token IDs);
        for bigram models a 1-tuple is also accepted, as in ``prob``.
      vocab_size (int): length of the returned tensor; every stored
        next-token ID must be a valid index into it.
    Returns:
      prob_tensor (torch.Tensor): (vocab_size, ) of count-normalized next
        token probabilities (all zeros if the prior has no counted
        continuations), or None if ``key`` is not in the corpus.
    """
    successors = self._successors(key)
    if successors is None:
      return None
    prob_tensor = torch.zeros(vocab_size)
    total = sum(successors.values())
    if successors and total:
      # One vectorized scatter instead of a Python-level write per token.
      # Dividing in float64 (like Python's ``/``) before casting keeps the
      # values identical to per-element assignment.
      ids = torch.tensor(list(successors.keys()), dtype=torch.long)
      counts = torch.tensor(list(successors.values()), dtype=torch.float64)
      prob_tensor[ids] = (counts / total).to(prob_tensor.dtype)
    return prob_tensor

def make_models(ckpt_path, bigram, trigram, fourgram, fivegram, sixgram, sevengram):
  """
  Load the requested bigram..sevengram models from pickled corpora.

  Models are returned in increasing n order, containing only those whose
  flag is truthy. Expected files in ``ckpt_path``: b_d_final.pkl,
  t_d_final.pkl, fo_d_final.pkl, fi_d_final.pkl, si_d_final.pkl,
  se_d_final.pkl (bigram through sevengram).

  Args:
    ckpt_path (str): directory containing the pickled n-gram corpora.
    bigram-sevengram (bool): which models to load.
  Returns:
    list[NGram]: the loaded models, in bigram -> sevengram order.
  """
  # (load flag, model type, pickle file name), in increasing n order.
  specs = (
      (bigram, "bigram", "b_d_final.pkl"),
      (trigram, "trigram", "t_d_final.pkl"),
      (fourgram, "fourgram", "fo_d_final.pkl"),
      (fivegram, "fivegram", "fi_d_final.pkl"),
      (sixgram, "sixgram", "si_d_final.pkl"),
      (sevengram, "sevengram", "se_d_final.pkl"),
  )

  models = []
  for wanted, ngram_type, filename in specs:
    if not wanted:
      continue
    print(f"Making {ngram_type}...")
    # NOTE: pickle.load can execute arbitrary code; only point ckpt_path at
    # trusted checkpoint files.
    with open(f"{ckpt_path}/{filename}", "rb") as f:
      corpus = pickle.load(f)
    models.append(NGram(corpus, None, ngram_type))
    # sys.getsizeof is shallow: it reports only the top-level dict object,
    # not the nested successor dicts it references.
    print(sys.getsizeof(corpus))

  return models