Ethan Shen
Initial commit
dda1539
raw
history blame
No virus
3.35 kB
import pickle
import sys
import torch
class NGram():
    """Count-based n-gram model backed by a nested mapping.

    ``corpus`` maps a context to a mapping of ``next token -> count``.
    As the lookup in ``prob`` shows, bigram corpora are keyed by the bare
    previous token, while every other model type is keyed by the context
    exactly as it is passed in.
    """

    # Number of context tokens ``prob`` requires for each known model type.
    _CONTEXT_LENGTHS = {
        "bigram": 1,
        "trigram": 2,
        "fourgram": 3,
        "fivegram": 4,
        "sixgram": 5,
        "sevengram": 6,
    }

    def __init__(self, corpus, corpus_counts, type):
        """
        Args:
            corpus (dict): context -> {next token: count}
            corpus_counts: stored as ``self.counts``; not read by this class
            type (str): model type name, "bigram" through "sevengram"
        """
        self.corpus = corpus
        self.counts = corpus_counts
        self.type = type

    def prob(self, key, next):
        """
        Relative frequency of `next` among the continuations recorded for `key`.

        Args:
            key (tuple): tuple of token ID's forming prior
            next (int): token ID whose probability is requested
        Returns:
            count(key, next) / total count for key, or -1 when the key is not
            in the corpus or its recorded counts sum to zero.
        Raises:
            ValueError: if len(key) does not match the context length of a
                known model type (types not listed in `_CONTEXT_LENGTHS` are
                not length-checked).
        """
        expected = self._CONTEXT_LENGTHS.get(self.type)
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if expected is not None and len(key) != expected:
            raise ValueError(
                f"{self.type} model expects a key of length {expected}, "
                f"got length {len(key)}"
            )
        if self.type == "bigram":
            # Bigram corpora are keyed by the single previous token itself.
            key = key[0]

        if key not in self.corpus:
            return -1
        continuations = self.corpus[key]
        total = sum(continuations.values())
        if total == 0:
            # No observations for this context: report it like an unseen key
            # rather than dividing by zero.
            return -1
        return continuations.get(next, 0) / total

    def ntd(self, key, vocab_size=32000):
        """
        Full next-token distribution for `key`.

        Args:
            key (tuple): tuple of token ID's forming prior. For bigram models
                the bare token ID is accepted as well as a 1-tuple.
            vocab_size (int): length of the returned tensor
        Returns:
            prob_tensor (torch.Tensor): (vocab_size, ) of full next token
            probabilities (all zeros when the recorded counts sum to zero),
            or None when the key is not in the corpus.
        """
        corpus = self.corpus
        if (
            key not in corpus
            and self.type == "bigram"
            and isinstance(key, tuple)
            and len(key) == 1
        ):
            # Mirror `prob`: bigram corpora are keyed by the bare token.
            key = key[0]
        if key not in corpus:
            return None

        continuations = corpus[key]
        prob_tensor = torch.zeros(vocab_size)
        total = sum(continuations.values())
        if not continuations or total == 0:
            return prob_tensor

        token_ids = torch.tensor(list(continuations.keys()), dtype=torch.long)
        # Divide in float64, then cast, so the values match what per-element
        # Python float division would have stored.
        counts = torch.tensor(list(continuations.values()), dtype=torch.float64)
        prob_tensor[token_ids] = (counts / total).to(prob_tensor.dtype)
        return prob_tensor
def make_models(ckpt_path, bigram, trigram, fourgram, fivegram, sixgram, sevengram):
    """
    Loads and returns a list of the requested n-gram models, ordered from
    bigram to sevengram. Only the models whose flag is truthy are loaded.

    Expected corpus files inside `ckpt_path`:
        bigram    -> b_d_final.pkl
        trigram   -> t_d_final.pkl
        fourgram  -> fo_d_final.pkl
        fivegram  -> fi_d_final.pkl
        sixgram   -> si_d_final.pkl
        sevengram -> se_d_final.pkl

    Args:
        ckpt_path (str): Location of ngram models
        bigram-sevengram: Which models to load
    Returns:
        List of n-gram models
    Raises:
        FileNotFoundError: if a requested corpus file does not exist
    """
    checkpoints = (
        ("bigram", "b_d_final.pkl", bigram),
        ("trigram", "t_d_final.pkl", trigram),
        ("fourgram", "fo_d_final.pkl", fourgram),
        ("fivegram", "fi_d_final.pkl", fivegram),
        ("sixgram", "si_d_final.pkl", sixgram),
        ("sevengram", "se_d_final.pkl", sevengram),
    )

    models = []
    for model_type, file_name, wanted in checkpoints:
        if not wanted:
            continue
        print(f"Making {model_type}...")
        # NOTE(security): pickle.load can execute arbitrary code from the
        # file; only point ckpt_path at checkpoints from a trusted source.
        with open(f"{ckpt_path}/{file_name}", "rb") as f:
            corpus = pickle.load(f)
        models.append(NGram(corpus, None, model_type))
        # Shallow size of the top-level container only (sys.getsizeof does
        # not follow references); now reported for every model, sevengram
        # included.
        print(sys.getsizeof(corpus))
    return models