ndhieunguyen's picture
Add application file
7dd9869
raw
history blame
6.04 kB
import torch
# bert results
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator
import sys, yaml, os
# print( os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
# sys.path.insert(0, 'diffusion_lm/transformers/examples/pytorch/language-modeling')
# sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
# from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise
def load_models(modality, mode, model_name_or_path, emb_dim, file, extra_args=None):
    """Load a saved word-embedding table and its id->token vocabulary.

    Args:
        modality: dataset modality. ``'synth'`` and ``'book'`` are not supported
            in this copy of the code; any other value reads ``{file}/vocab.json``.
        mode: embedding mode; one of ``'random'``, ``'random1'``,
            ``'random_up_proj'``, ``'glove'``.
        model_name_or_path: unused; kept for interface compatibility.
        emb_dim: embedding dimension of the saved ``torch.nn.Embedding``.
        file: directory containing ``vocab.json`` and ``random_emb.torch``.
        extra_args: optional namespace; when ``extra_args.use_bert_tokenizer``
            equals ``'yes'`` the (unsupported) BERT-tokenizer branch is selected.

    Returns:
        ``(model, tokenizer)``: a ``torch.nn.Embedding`` holding the saved weights,
        and a dict mapping token id -> token string (the inverse of ``vocab.json``).

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
        NotImplementedError: for the ``'synth'`` modality or the BERT-tokenizer
            branch, whose loading code is disabled here.
    """
    if mode not in ('random', 'random1', 'random_up_proj', 'glove'):
        raise ValueError(f'load_models: unsupported embedding mode {mode!r}')
    if modality == 'synth':
        # The synthetic-data loader was disabled in this copy; fail loudly
        # instead of returning unbound names.
        raise NotImplementedError("load_models: the 'synth' modality is not supported")
    if modality == 'book' or (extra_args is not None and extra_args.use_bert_tokenizer == 'yes'):
        raise NotImplementedError('load_models: BERT-tokenizer vocabularies are not supported')

    import json

    # Read artifacts from the caller-supplied directory; previously a hard-coded
    # absolute path overrode `file`, so the argument was silently ignored.
    path_save_tokenizer = '{}/vocab.json'.format(file)
    print(f'loading from {path_save_tokenizer}')
    with open(path_save_tokenizer, 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    print(len(vocab))
    tokenizer = {v: k for k, v in vocab.items()}

    model = torch.nn.Embedding(len(tokenizer), emb_dim)
    path_save = '{}/random_emb.torch'.format(file)
    # map_location='cpu' lets GPU-saved weights load on CPU-only hosts;
    # load_state_dict copies them into the (CPU) embedding either way.
    model.load_state_dict(torch.load(path_save, map_location='cpu'))
    return model, tokenizer
def load_tokenizer(modality, mode, model_name_or_path):
    """Load the tokenizer / id->token vocabulary used by a trained model.

    Args:
        modality: ``'synth'`` builds the vocabulary from the synthetic-data
            ``Dataset``; ``'book'`` returns the ``bert-base-uncased`` tokenizer;
            any other value reads ``{model_name_or_path}/vocab.json``.
        mode: embedding mode; one of ``'random'``, ``'random1'``,
            ``'random_up_proj'``, ``'glove'``.
        model_name_or_path: model directory (also inspected for ``'synth128'``
            in the synthetic case).

    Returns:
        For ``'book'``, the ``AutoTokenizer``; otherwise a dict mapping token
        id -> token string (the inverse of the stored vocabulary).

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """
    if mode not in ('random', 'random1', 'random_up_proj', 'glove'):
        raise ValueError(f'load_tokenizer: unsupported embedding mode {mode!r}')

    if modality == 'synth':
        print(model_name_or_path, 'deciding what to load::: ')
        if 'synth128' in model_name_or_path:
            config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k128_trainc20000.yaml'
        else:
            config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k32_trainc20000.yaml'
        synth_dir = 'diffusion_lm/synthetic_data/rnns-stacks'
        # Make the synthetic-data package importable without growing sys.path on every call.
        if synth_dir not in sys.path:
            sys.path.insert(0, synth_dir)
        from dataset import Dataset as SynthDataset
        with open(config) as f:
            # PyYAML >= 6 requires an explicit Loader; FullLoader matches the
            # implicit default of PyYAML 5.1-5.4.
            args_synth = yaml.load(f, Loader=yaml.FullLoader)
        dataset = SynthDataset(args_synth)
        return {v: k for k, v in dataset.vocab.items()}

    if modality == 'book':
        return AutoTokenizer.from_pretrained('bert-base-uncased')

    import json
    path_save_tokenizer = '{}/vocab.json'.format(model_name_or_path)
    with open(path_save_tokenizer, 'r', encoding='utf-8') as f:
        vocab = json.load(f)
    return {v: k for k, v in vocab.items()}
def rounding_func(mode, text_emb_lst, model, tokenizer, emb_scale_factor=1.0):
    """Decode continuous embeddings to text by nearest-neighbour rounding.

    Each element of ``text_emb_lst`` is converted to a tensor and reshaped to
    ``(-1, emb_dim)`` using its last dimension. Every row is mapped to the row
    of ``model.weight`` at the smallest L2 distance, and the corresponding
    tokens are joined with single spaces.

    Args:
        mode: embedding mode; only ``'random'``, ``'random1'``,
            ``'random_up_proj'`` and ``'glove'`` are decoded.
        text_emb_lst: iterable of array-likes of embeddings whose last
            dimension matches ``model.weight.size(1)``.
        model: module whose ``weight`` is the ``(vocab_size, emb_dim)`` table.
        tokenizer: mapping from token id to token string.
        emb_scale_factor: unused; kept for interface compatibility.

    Returns:
        One decoded string per element of ``text_emb_lst``, or an empty list
        when ``mode`` is not supported.
    """
    decoded_out_lst = []
    if mode not in ('random', 'random1', 'random_up_proj', 'glove'):
        return decoded_out_lst

    down_proj_emb = model.weight  # (vocab_size, emb_dim) embedding table
    # Only argmin indices are needed, so don't build an autograd graph over the
    # (vocab_size, n_tokens, emb_dim) difference tensor.
    with torch.no_grad():
        for text_emb in text_emb_lst:
            text_emb = torch.as_tensor(text_emb).to(down_proj_emb.device)
            # reshape (unlike view) also accepts non-contiguous inputs.
            text_emb = text_emb.reshape(-1, text_emb.size(-1))
            # dists[v, n] = ||down_proj_emb[v] - text_emb[n]||_2
            dists = torch.norm(down_proj_emb.unsqueeze(1) - text_emb.unsqueeze(0), dim=-1)
            # argmin over the vocabulary; the former topk(k=6) crashed for vocab_size < 6.
            nearest = dists.argmin(dim=0)
            decoded_out_lst.append(" ".join(tokenizer[i] for i in nearest.tolist()))
    return decoded_out_lst