Spaces:
Sleeping
Sleeping
from torch.utils.data import DataLoader | |
import torch.nn as nn | |
import torch | |
import numpy | |
import pickle | |
import tqdm | |
from bert import BERT | |
from vocab import Vocab | |
from dataset import TokenizerDataset | |
import argparse | |
from itertools import combinations | |
def generate_subset(s):
    """Enumerate every subset of ``s`` and give each one an integer index.

    Subsets are listed by increasing size, starting with the empty subset.
    Within one size they follow ``itertools.combinations`` order.

    Args:
        s: A sized, re-iterable collection of items. A ``set``/``frozenset`` is
            sorted first when its elements are comparable. Set iteration order
            of strings changes between interpreter runs (hash randomization),
            and without sorting the index -> subset mapping would differ each
            time it is regenerated.

    Returns:
        dict mapping consecutive integer indices (from 0) to subsets, each
        given as a list.
    """
    if isinstance(s, (set, frozenset)):
        try:
            items = sorted(s)
        except TypeError:
            # Elements of mixed/unorderable types: keep iteration order.
            items = list(s)
    else:
        items = list(s)

    subsets = [
        list(combo)
        for size in range(len(items) + 1)
        for combo in combinations(items, size)
    ]
    return dict(enumerate(subsets))
def str2bool(value):
    """Convert a command-line flag value to ``bool``.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is ``True`` for
    every non-empty string, so ``-pretrain False`` used to *enable* the flag.
    This parser accepts the usual spellings instead. The empty string stays
    ``False``, as it was under ``bool("")``.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean
            spelling, so argparse reports a clean usage error.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1"):
        return True
    if normalized in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def extract_cls_embeddings(model, data_loader, device, desc):
    """Run ``model`` over ``data_loader`` and collect the position-0 output.

    Every value of each batch dict is moved to ``device``. The model is then
    called as ``model(batch["bert_input"], batch["segment_label"])`` and
    ``hrep[:, 0]`` is kept for each example.

    Args:
        model: the loaded network. Its output is assumed to be
            (batch, seq, hidden), so ``[:, 0]`` selects the first token
            -- TODO confirm against BERT.forward.
        data_loader: iterable of dict batches containing at least the keys
            ``"bert_input"`` and ``"segment_label"``.
        device: torch device the batch tensors are moved to.
        desc: label shown on the tqdm progress bar.

    Returns:
        list with one numpy array (a row of ``hrep[:, 0]``) per example, in
        loader order.
    """
    embeddings = []
    progress = tqdm.tqdm(data_loader,
                         desc=desc,
                         total=len(data_loader),
                         bar_format="{l_bar}{r_bar}")
    # Pure inference: skip building the autograd graph.
    with torch.no_grad():
        for batch in progress:
            batch = {key: value.to(device) for key, value in batch.items()}
            hrep = model(batch["bert_input"], batch["segment_label"])
            # Extending a list with an ndarray appends one row per example.
            embeddings.extend(hrep[:, 0].detach().cpu().numpy())
    return embeddings


def save_embeddings(embeddings, path):
    """Pickle ``embeddings`` to ``path``, creating the parent directory first.

    The directory is created up front so a long inference run is not lost to
    a missing ``embeddings/`` folder. ``with`` closes the file deterministically.
    """
    from pathlib import Path

    Path(path).parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as handle:
        pickle.dump(embeddings, handle)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('-workspace_name', type=str, default=None)
    parser.add_argument("-seq_len", type=int, default=100, help="maximum sequence length")
    # str2bool rather than type=bool: bool("False") would be True.
    parser.add_argument('-pretrain', type=str2bool, default=False)
    parser.add_argument('-masked_pred', type=str2bool, default=False)
    parser.add_argument('-epoch', type=str, default=None)
    # Previously hard-coded; the defaults keep the old behaviour.
    parser.add_argument('-batch_size', type=int, default=32)
    parser.add_argument('-num_workers', type=int, default=4)
    options = parser.parse_args()
    if options.epoch is None:
        # Without it the checkpoint path would end in "epNone".
        parser.error("-epoch is required to locate the trained model checkpoint")

    folder_path = options.workspace_name + "/" if options.workspace_name else ""

    vocab_path = f"{folder_path}check/pretraining/vocab.txt"
    print("Loading Vocab", vocab_path)
    vocab_obj = Vocab(vocab_path)
    vocab_obj.load_vocab()
    print("Vocab Size: ", len(vocab_obj.vocab))

    # The checkpoint lives under <workspace>/output/ ...
    if options.masked_pred:
        str_code = "masked_prediction"
        output_name = f"{folder_path}output/bert_trained.seq_model.ep{options.epoch}"
    else:
        str_code = "masked"
        output_name = f"{folder_path}output/bert_trained.seq_encoder.model.ep{options.epoch}"

    # ... while datasets and embeddings live under <workspace>/check/.
    folder_path = folder_path + "check/"
    if options.pretrain:
        data_file = f"{folder_path}pretraining/pretrain.txt"
        # NOTE(review): pretrain labels are a .pkl but test labels a .txt;
        # confirm TokenizerDataset accepts both formats.
        label_file = f"{folder_path}pretraining/pretrain_opt.pkl"
        embedding_file_path = f"{folder_path}embeddings/pretrain_embeddings_{str_code}_{options.epoch}.pkl"
        print("Loading Pretrain Dataset ", data_file)
        progress_desc = "pre-train"
        banner = "Pretrain-embeddings...."
    else:
        data_file = f"{folder_path}pretraining/test.txt"
        label_file = f"{folder_path}pretraining/test_opt.txt"
        embedding_file_path = f"{folder_path}embeddings/test_embeddings_{str_code}_{options.epoch}.pkl"
        print("Loading Validation Dataset ", data_file)
        progress_desc = "validation"
        banner = "Validation-embeddings...."
    dataset = TokenizerDataset(data_file, label_file, vocab_obj, seq_len=options.seq_len)
    print("Creating Dataloader")
    # shuffle keeps its default (False), so embeddings stay in dataset order.
    data_loader = DataLoader(dataset, batch_size=options.batch_size, num_workers=options.num_workers)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)
    print("Load Pre-trained BERT model...")
    print(output_name)
    # NOTE(review): torch.load unpickles a whole module object, which can run
    # arbitrary code -- only load trusted checkpoints. On newer PyTorch releases
    # where weights_only defaults to True, loading a full module may need
    # weights_only=False -- confirm against the installed torch version.
    bert = torch.load(output_name, map_location=device)
    for param in bert.parameters():
        param.requires_grad = False
    # torch.load restores whatever train/eval mode was saved; switch to eval so
    # dropout (if the model has any) is disabled while extracting embeddings.
    bert.eval()

    print(banner)
    embeddings = extract_cls_embeddings(bert, data_loader, device, progress_desc)
    save_embeddings(embeddings, embedding_file_path)