# astra/src/evaluate_embeddings.py
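"""Extract sequence embeddings from a trained BERT checkpoint and pickle them.

Runs either the pre-training split or the test split through a frozen model
and saves the hidden state at position 0 of each sequence as its embedding.
Expected workspace layout (inferred from the paths below): inputs under
<workspace>/check/pretraining/, checkpoints under <workspace>/output/, and
embeddings written to <workspace>/check/embeddings/.

Usage (flags per the argparse setup below):
    python evaluate_embeddings.py -workspace_name <ws> -epoch <n> [-pretrain] [-masked_pred]
"""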
import argparse
import pickle
from itertools import combinations

import torch
import tqdm
from torch.utils.data import DataLoader

from bert import BERT  # needed so torch.load can unpickle the saved model
from dataset import TokenizerDataset
from vocab import Vocab
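
# Enumerate every subset of a set as an {index: subset} dict. Not called in
# the live script; it was previously used to build the optional-task label
# space during pre-training setup.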
def generate_subset(s):
    subsets = []
    for r in range(len(s) + 1):
        combinations_result = combinations(s, r)
        if r == 1:
            subsets.extend([item] for sublist in combinations_result for item in sublist)
        else:
            subsets.extend(list(sublist) for sublist in combinations_result)
    subsets_dict = {i: subset for i, subset in enumerate(subsets)}
    return subsets_dict

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-workspace_name', type=str, default=None)
    parser.add_argument('-seq_len', type=int, default=100, help="maximum sequence length")
    # argparse's type=bool treats any non-empty string as True, so both
    # switches are plain store_true flags instead.
    parser.add_argument('-pretrain', action='store_true',
                        help="embed the pre-training split instead of the test split")
    parser.add_argument('-masked_pred', action='store_true',
                        help="load the masked-prediction checkpoint instead of the encoder one")
    parser.add_argument('-epoch', type=str, default=None, help="epoch suffix of the checkpoint to load")
    options = parser.parse_args()
    folder_path = options.workspace_name + "/" if options.workspace_name else ""
    vocab_path = f"{folder_path}check/pretraining/vocab.txt"
    print("Loading Vocab", vocab_path)
    vocab_obj = Vocab(vocab_path)
    vocab_obj.load_vocab()
    print("Vocab Size: ", len(vocab_obj.vocab))
    # Pick the checkpoint written by the corresponding training objective.
    if options.masked_pred:
        str_code = "masked_prediction"
        output_name = f"{folder_path}output/bert_trained.seq_model.ep{options.epoch}"
    else:
        str_code = "masked"
        output_name = f"{folder_path}output/bert_trained.seq_encoder.model.ep{options.epoch}"

    # Datasets and embeddings live under the workspace's check/ subfolder.
    folder_path = folder_path + "check/"
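    # Build the dataloader for the split being embedded; TokenizerDataset
    # pairs each sequence file with its label file.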
    if options.pretrain:
        pretrain_file = f"{folder_path}pretraining/pretrain.txt"
        pretrain_label = f"{folder_path}pretraining/pretrain_opt.pkl"
        embedding_file_path = f"{folder_path}embeddings/pretrain_embeddings_{str_code}_{options.epoch}.pkl"
        print("Loading Pretrain Dataset ", pretrain_file)
        pretrain_dataset = TokenizerDataset(pretrain_file, pretrain_label, vocab_obj, seq_len=options.seq_len)
        print("Creating Dataloader")
        pretrain_data_loader = DataLoader(pretrain_dataset, batch_size=32, num_workers=4)
    else:
        val_file = f"{folder_path}pretraining/test.txt"
        val_label = f"{folder_path}pretraining/test_opt.txt"
        embedding_file_path = f"{folder_path}embeddings/test_embeddings_{str_code}_{options.epoch}.pkl"
        print("Loading Validation Dataset ", val_file)
        val_dataset = TokenizerDataset(val_file, val_label, vocab_obj, seq_len=options.seq_len)
        print("Creating Dataloader")
        val_data_loader = DataLoader(val_dataset, batch_size=32, num_workers=4)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    print("Load Pre-trained BERT model...")
    print(output_name)
    bert = torch.load(output_name, map_location=device)
    bert.eval()  # inference only: disable dropout
    # Freeze every parameter; this script extracts embeddings, it does not train.
    for param in bert.parameters():
        param.requires_grad = False
    if options.pretrain:
        print("Pretrain-embeddings....")
        data_iter = tqdm.tqdm(enumerate(pretrain_data_loader),
                              desc="pre-train",
                              total=len(pretrain_data_loader),
                              bar_format="{l_bar}{r_bar}")
        pretrain_embeddings = []
        for i, data in data_iter:
            # Move the whole batch dict to the target device.
            data = {key: value.to(device) for key, value in data.items()}
            hrep = bert(data["bert_input"], data["segment_label"])
            # hrep[:, 0] is the hidden state at position 0 of each sequence,
            # used here as the sequence embedding.
            embeddings = [h for h in hrep[:, 0].cpu().detach().numpy()]
            pretrain_embeddings.extend(embeddings)
        with open(embedding_file_path, "wb") as f:
            pickle.dump(pretrain_embeddings, f)
    else:
        print("Validation-embeddings....")
        data_iter = tqdm.tqdm(enumerate(val_data_loader),
                              desc="validation",
                              total=len(val_data_loader),
                              bar_format="{l_bar}{r_bar}")
        val_embeddings = []
        for i, data in data_iter:
            data = {key: value.to(device) for key, value in data.items()}
            hrep = bert(data["bert_input"], data["segment_label"])
            embeddings = [h for h in hrep[:, 0].cpu().detach().numpy()]
            val_embeddings.extend(embeddings)
        with open(embedding_file_path, "wb") as f:
            pickle.dump(val_embeddings, f)
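
    # A hypothetical downstream consumer can read the embeddings back with:
    #     with open(embedding_file_path, "rb") as f:
    #         embeddings = pickle.load(f)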