Spaces:
Running
Running
import torch | |
# bert results | |
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator | |
import sys, yaml, os | |
# print( os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling')) | |
# sys.path.insert(0, 'diffusion_lm/transformers/examples/pytorch/language-modeling') | |
# sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling')) | |
# from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise | |
def load_models(modality, mode, model_name_or_path, emb_dim, file, extra_args=None):
    """Load the embedding table and the id-to-token vocabulary of a checkpoint.

    Reads ``<file>/vocab.json`` (a JSON object mapping token -> id) and
    ``<file>/random_emb.torch`` (a ``torch.nn.Embedding`` state dict).

    Args:
        modality: Dataset family. ``'synth'`` and ``'book'`` are not
            supported by this build (their loaders are disabled).
        mode: Embedding scheme; must be one of ``'random'``, ``'random1'``,
            ``'random_up_proj'`` or ``'glove'``.
        model_name_or_path: Unused; kept for interface compatibility.
        emb_dim: Width of the embedding table to instantiate.
        file: Checkpoint directory holding ``vocab.json`` and
            ``random_emb.torch``.
        extra_args: Optional namespace; when its ``use_bert_tokenizer``
            attribute equals ``'yes'`` the (disabled) BERT path is selected.

    Returns:
        ``(model, tokenizer)`` where ``model`` is the loaded
        ``torch.nn.Embedding`` and ``tokenizer`` is a dict mapping id -> token.

    Raises:
        ValueError: ``mode`` is not one of the supported schemes.
        NotImplementedError: the synthetic or BERT-tokenizer path is requested.
        FileNotFoundError: the vocabulary or embedding file is missing.
    """
    import json
    import os

    # Directory that earlier revisions hard-coded in place of ``file``. It is
    # consulted only when the requested checkpoint directory lacks the file,
    # so setups that relied on the override keep working.
    legacy_dir = '/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e'

    def _resolve(filename):
        """Return the path of *filename*, preferring ``file`` over the legacy dir."""
        requested = '{}/{}'.format(file, filename)
        if os.path.exists(requested):
            return requested
        legacy = '{}/{}'.format(legacy_dir, filename)
        if os.path.exists(legacy):
            print(f'{requested} not found, falling back to {legacy}')
            return legacy
        # Neither exists: report the error against the path the caller asked for.
        return requested

    if mode not in ['random', 'random1', 'random_up_proj', 'glove']:
        raise ValueError(f'unsupported mode for load_models: {mode!r}')

    if modality == 'synth':
        # The synthetic-data loader is commented out in this file; fail loudly
        # instead of hitting an UnboundLocalError on the return statement.
        raise NotImplementedError("modality 'synth' is disabled in load_models")

    if modality == 'book' or (extra_args is not None and extra_args.use_bert_tokenizer == 'yes'):
        # Same situation: the BERT tokenizer branch was stubbed out.
        raise NotImplementedError('the BERT-tokenizer path is disabled in load_models')

    path_save_tokenizer = _resolve('vocab.json')
    print(f'loading from {path_save_tokenizer}')
    with open(path_save_tokenizer, 'r') as f:
        vocab = json.load(f)
    print(len(vocab))
    # vocab.json stores token -> id; decoding needs the inverse mapping.
    tokenizer = {v: k for k, v in vocab.items()}

    model = torch.nn.Embedding(len(tokenizer), emb_dim)
    path_save = _resolve('random_emb.torch')
    # map_location='cpu' lets a checkpoint saved on GPU load on a CPU-only
    # host; load_state_dict copies into the module's existing parameters.
    model.load_state_dict(torch.load(path_save, map_location='cpu'))

    return model, tokenizer
def load_tokenizer(modality, mode, model_name_or_path):
    """Return the tokenizer / id-to-token mapping for a checkpoint.

    Args:
        modality: ``'synth'`` builds the vocabulary from the synthetic dataset
            config, ``'book'`` loads the ``bert-base-uncased`` tokenizer, and
            any other value reads ``<model_name_or_path>/vocab.json``.
        mode: Embedding scheme; must be one of ``'random'``,
            ``'random_up_proj'`` or ``'glove'``.
            NOTE(review): ``load_models`` also accepts ``'random1'`` —
            confirm whether it should be accepted here as well.
        model_name_or_path: Checkpoint directory (also inspected for the
            substring ``'synth128'`` to pick the synthetic config).

    Returns:
        A dict mapping id -> token, or a Hugging Face tokenizer for ``'book'``.

    Raises:
        ValueError: ``mode`` is not one of the supported schemes.
    """
    if mode not in ['random', 'random_up_proj', 'glove']:
        # Previously this fell through to an UnboundLocalError on return.
        raise ValueError(f'unsupported mode for load_tokenizer: {mode!r}')

    if modality == 'synth':
        print(model_name_or_path, 'deciding what to load::: ')
        if 'synth128' in model_name_or_path:
            config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k128_trainc20000.yaml'
        else:
            config = 'diffusion_lm/synthetic_data/configs/emnlp2020/experiments/difflm_seed0_m3_k32_trainc20000.yaml'

        import sys
        # The synthetic dataset package is not installed; it is imported from
        # its source directory.
        sys.path.insert(0, 'diffusion_lm/synthetic_data/rnns-stacks')
        from dataset import Dataset as SynthDataset

        # An explicit Loader is mandatory from PyYAML 6 on; FullLoader matches
        # what the bare yaml.load() call used on PyYAML 5.x. The context
        # manager closes the config file, which the old code leaked.
        with open(config) as config_file:
            args_synth = yaml.load(config_file, Loader=yaml.FullLoader)

        dataset = SynthDataset(args_synth)
        tokenizer = {v: k for k, v in dataset.vocab.items()}
    elif modality == 'book':
        tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    else:
        import json
        path_save_tokenizer = '{}/vocab.json'.format(model_name_or_path)
        with open(path_save_tokenizer, 'r') as f:
            vocab = json.load(f)
        # vocab.json stores token -> id; invert it for decoding.
        tokenizer = {v: k for k, v in vocab.items()}

    return tokenizer
def rounding_func(mode, text_emb_lst, model, tokenizer, emb_scale_factor=1.0):
    """Round continuous embeddings to their nearest vocabulary tokens.

    For every sample, each embedding row is matched against the rows of
    ``model.weight`` by L2 distance, and the closest row's index is looked up
    in ``tokenizer``.

    Args:
        mode: Embedding scheme. Only ``'random'``, ``'random_up_proj'`` and
            ``'glove'`` are decoded; any other value yields an empty list.
        text_emb_lst: Iterable of samples. Each sample is a tensor or anything
            ``torch.tensor`` accepts; inputs with more than two dimensions are
            flattened to ``(-1, last_dim)``.
        model: Object whose ``weight`` is the embedding table (one row per
            vocabulary id — rows are indexed with the keys of ``tokenizer``).
        tokenizer: Mapping from row index to token string.
        emb_scale_factor: Unused; kept for interface compatibility.

    Returns:
        A list with one space-joined string of decoded tokens per sample.
    """
    decoded_out_lst = []
    if mode not in ['random', 'random_up_proj', 'glove']:
        # Unsupported modes have always produced an empty result, not an error.
        return decoded_out_lst

    down_proj_emb = model.weight  # input_embs

    def get_knn(down_proj_emb, text_emb, dist='cos'):
        """Return (scores, indices) of the best-matching table rows per query row."""
        if dist == 'cos':
            # Despite the name this is an unnormalised dot product.
            adjacency = down_proj_emb @ text_emb.transpose(1, 0).to(down_proj_emb.device)
        elif dist == 'l2':
            # Broadcasting (V, 1, D) - (1, N, D) gives the same values as the
            # explicit expand() calls it replaces.
            diff = down_proj_emb.unsqueeze(1) - text_emb.unsqueeze(0)
            adjacency = -torch.norm(diff, dim=-1)
        else:
            raise ValueError(f'unknown distance: {dist!r}')
        # Cap k at the table size so vocabularies with fewer than 6 entries
        # do not make topk raise; only the first row is consumed by the caller.
        k = min(6, adjacency.size(0))
        topk_out = torch.topk(adjacency, k=k, dim=0)
        return topk_out.values, topk_out.indices

    dist = 'l2'
    # Decoding only needs indices, so skip building autograd graphs through
    # the (trainable) embedding weights.
    with torch.no_grad():
        for text_emb in text_emb_lst:
            if isinstance(text_emb, torch.Tensor):
                # torch.tensor(tensor) warns and copies; detach() is enough
                # because the data is only read.
                text_emb = text_emb.detach()
            else:
                text_emb = torch.tensor(text_emb)
            if len(text_emb.shape) > 2:
                # reshape (unlike view) also handles non-contiguous inputs.
                text_emb = text_emb.reshape(-1, text_emb.size(-1))
            val, indices = get_knn(down_proj_emb,
                                   text_emb.to(down_proj_emb.device), dist=dist)
            # Row 0 of the top-k indices is the single nearest token per position.
            decoded_out = " ".join([tokenizer[i] for i in indices[0].tolist()])
            decoded_out_lst.append(decoded_out)

    return decoded_out_lst