# from PIL import Image
# import blobfile as bf
# from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoConfig,
    AutoTokenizer,
    default_data_collator,
    PreTrainedTokenizerFast,
    PreTrainedTokenizer,
)
# from datasets import load_dataset
import sys, os
import torch
# sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
# from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise
from collections import Counter, defaultdict
from functools import partial
from itertools import chain


def load_data_text(
    *,
    data_dir,
    batch_size,
    image_size,
    class_cond=False,
    deterministic=False,
    data_args=None,
    task_mode="roc",
    model=None,
    padding_mode="block",
    split="train",
    load_vocab=None,
):
    """
    For a text dataset, create a generator over (seqs, kwargs) pairs.

    This mirrors the image-loading API: each batch is a float tensor of
    embedded token sequences, and the kwargs dict contains zero or more
    keys, each of which maps to a batched Tensor of its own (e.g.
    "input_ids", plus "src_ids"/"src_mask" for conditional generation).

    :param data_dir: a dataset directory.
    :param batch_size: the batch size of each returned pair.
    :param image_size: sequences are laid out as image_size**2 tokens.
    :param class_cond: if True, include a "y" key in returned dicts for class
                       label. If classes are not available and this is true, an
                       exception will be raised.
    :param deterministic: if True, yield results in a deterministic order.
    """
print("hello loading text data. ") | |
if data_args.experiment.startswith("random") and model is None: | |
model = None | |
# elif data_args.experiment.startswith('random') and model is not None: | |
# print('loading initialized random embeddings. ') | |
if task_mode == "roc" or task_mode == "roc-aug": | |
pass | |
# training_data, model = get_corpus_rocstory(data_args, model, image_size, | |
# padding_mode=padding_mode, split=split, | |
# load_vocab=load_vocab) | |
elif task_mode == "simple-wiki": | |
pass | |
# training_data, model = get_corpus_rocstory(data_args, model, image_size, | |
# padding_mode=padding_mode, split=split, | |
# load_vocab=load_vocab) | |
elif task_mode == "e2e-tgt": | |
print("hello loading e2e-tgt. ") | |
training_data, model = get_corpus_rocstory( | |
data_args, | |
model, | |
image_size, | |
padding_mode=padding_mode, | |
split=split, | |
load_vocab=load_vocab, | |
) | |
# elif task_mode == 'yelp': | |
# print('hello loading yelp ') | |
# training_data, model = get_corpus_rocstory(data_args, model, image_size, | |
# padding_mode=padding_mode, split=split, | |
# load_vocab=load_vocab) | |
# elif task_mode == 'commonGen' or task_mode == 'commonGen-aug': | |
# print('hello loading common-gen ') | |
# training_data, model = get_corpus_rocstory(data_args, model, image_size, | |
# padding_mode=padding_mode, split=split, | |
# load_vocab=load_vocab) | |
# elif task_mode == 'e2e': | |
# training_data, model = get_corpus_rocstory(data_args, model, image_size, | |
# padding_mode=padding_mode, split=split, | |
# load_vocab=load_vocab) | |
# elif task_mode == 'book': | |
# tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') | |
# training_data, model = get_corpus_book(data_args, tokenizer, model, image_size, | |
# padding_mode=padding_mode, split=split,) | |
if ( | |
data_args.modality | |
in ["roc-aug", "roc", "book", "yelp", "commonGen", "commonGen-aug"] | |
and data_args.cache_mode == "no" | |
): | |
pass # dataset = TextDataset_NoCache( | |
# training_data, | |
# image_size, | |
# data_args, | |
# model_arch=data_args.model_arch, | |
# model_emb=model | |
# ) | |
else: | |
dataset = TextDataset( | |
training_data, | |
image_size, | |
data_args, | |
model_arch=data_args.model_arch, | |
) | |
if deterministic: | |
pass # data_loader = DataLoader( | |
# dataset, | |
# batch_size=batch_size, # 20, | |
# drop_last=True, | |
# shuffle=False, | |
# num_workers=1, | |
# ) | |
else: | |
data_loader = DataLoader( | |
dataset, | |
batch_size=batch_size, # 20, | |
drop_last=True, | |
shuffle=True, | |
num_workers=1, | |
) | |
while True: | |
yield from data_loader | |
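
# A minimal usage sketch (illustrative; `args` is assumed to be an
# argparse-style namespace carrying the fields referenced above, such as
# experiment, modality, model_arch, and cache_mode):
#
#     data = load_data_text(
#         data_dir="", batch_size=64, image_size=8,
#         data_args=args, task_mode="e2e-tgt",
#         padding_mode="pad", split="train",
#     )
#     batch, cond = next(data)  # batch: embedded sequences; cond["input_ids"]: token ids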


def helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, seqlen, data_args):
    result_train_lst = []
    group_lst = defaultdict(list)
    with torch.no_grad():
        # map tokens to vocab ids, wrapping targets in START (0) / END (1)
        for src_ids, input_ids in sentence_lst:
            tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids]
            tokenized_src = [vocab_dict.get(x, vocab_dict["UNK"]) for x in src_ids]
            input_ids = [0] + tokenized_ + [1]
            group_lst["word_ids"].append(input_ids)
            group_lst["src_ids"].append(tokenized_src)
        print(group_lst["word_ids"][:2])

        print("padding mode is pad")
        max_length = seqlen
        group_lst["word_ids"] = _collate_batch_helper(
            group_lst["word_ids"], vocab_dict["PAD"], max_length
        )
        max_src_length = max([len(xx) for xx in group_lst["src_ids"]])
        print(max_src_length, seqlen)
        max_src_length = min(seqlen, max_src_length)
        group_lst["src_ids"], group_lst["src_mask"] = _collate_batch_helper(
            group_lst["src_ids"], vocab_dict["PAD"], max_src_length, return_mask=True
        )

        for input_ids, src_ids, src_mask in zip(
            group_lst["word_ids"], group_lst["src_ids"], group_lst["src_mask"]
        ):
            if data_args.experiment.startswith("random"):
                hidden_state = model(torch.tensor(input_ids))
            elif data_args.experiment == "gpt2_pre_compress":
                input_ids2 = torch.tensor(input_ids).to(model.device)
                input_embs = model.transformer.wte(input_ids2)  # input_embs
                hidden_state = model.down_proj(input_embs)
                hidden_state = hidden_state * data_args.emb_scale_factor
            result_train_lst.append(
                {
                    "input_ids": input_ids,
                    "hidden_states": hidden_state.cpu().tolist(),
                    "src_ids": src_ids,
                    "src_mask": src_mask,
                }
            )
    return result_train_lst
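
# Each record in the returned list has this shape (illustrative values):
#     {"input_ids": [0, 17, 42, 1, 3, ...],   # START..END, then PAD to seqlen
#      "hidden_states": [[...], ...],         # one embedding per token
#      "src_ids": [...], "src_mask": [1, 1, ...]}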


def helper_tokenize_stream(
    sentence_lst,
    vocab_dict,
    model,
    seqlen,
    data_args,
    padding_mode,
):
    import psutil

    # Process.memory_info is expressed in bytes, so convert to megabytes
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
    from datasets import Dataset as Dataset2

    raw_datasets = Dataset2.from_dict({"text": sentence_lst})
    print(raw_datasets)
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

    def tokenize_function(examples):
        if isinstance(vocab_dict, dict):
            input_ids = [
                [0] + [vocab_dict.get(x, vocab_dict["UNK"]) for x in seq] + [1]
                for seq in examples["text"]
            ]
        elif isinstance(vocab_dict, PreTrainedTokenizerFast):
            examples["text"] = [" ".join(seq) for seq in examples["text"]]
            input_ids = vocab_dict(examples["text"], add_special_tokens=True)[
                "input_ids"
            ]
        result_dict = {"input_ids": input_ids}
        # clm input could be much much longer than block_size
        return result_dict

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=["text"],
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )
    print(tokenized_datasets)
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

    if padding_mode == "block":
        block_size = seqlen

        def group_texts(examples):
            # concatenate every column, then split into block_size chunks,
            # dropping the remainder that does not fill a full block
            concatenated_examples = {
                k: list(chain(*examples[k])) for k in examples.keys()
            }
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            if total_length >= block_size:
                total_length = (total_length // block_size) * block_size
            result = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
            result["labels"] = result["input_ids"].copy()
            return result

        lm_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
            desc=f"Grouping texts in chunks of {block_size}",
        )
    else:

        def pad_function(group_lst):
            max_length = seqlen
            if isinstance(vocab_dict, dict):
                group_lst["input_ids"] = _collate_batch_helper(
                    group_lst["input_ids"], vocab_dict["PAD"], max_length
                )
            else:
                group_lst["input_ids"] = _collate_batch_helper(
                    group_lst["input_ids"], vocab_dict.pad_token_id, max_length
                )
            return group_lst

        # Process.memory_info is expressed in bytes, so convert to megabytes
        print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
        lm_datasets = tokenized_datasets.map(
            pad_function,
            batched=True,
            num_proc=1,
            desc="padding",
        )

    print(lm_datasets, "padded dataset")
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
    import datasets

    raw_datasets = datasets.DatasetDict()
    raw_datasets["train"] = lm_datasets
    print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")
    return raw_datasets
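
# Block grouping in a nutshell (illustrative): with block_size=4, the
# tokenized rows [[0, 5, 9, 1], [0, 7, 1]] concatenate to
# [0, 5, 9, 1, 0, 7, 1] and yield a single chunk [0, 5, 9, 1]; the trailing
# remainder [0, 7, 1] is dropped because it does not fill a full block.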


def helper_tokenize_encode(
    sentence_lst,
    vocab_dict,
    model,
    seqlen,
    data_args,
    padding_mode,
):
    result_train_lst = []
    group_lst = defaultdict(list)
    with torch.no_grad():
        for input_ids in sentence_lst:
            tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids]
            input_ids = [0] + tokenized_ + [1]
            group_lst["word_ids"].append(input_ids)
        print(group_lst["word_ids"][:2])

        if padding_mode == "block":
            print("padding mode is block")
            concatenated_examples = {k: sum(group_lst[k], []) for k in group_lst.keys()}
            total_length = len(concatenated_examples[list(group_lst.keys())[0]])
            block_size = seqlen
            total_length = (total_length // block_size) * block_size
            # Split by chunks of max_len.
            group_lst = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
        elif padding_mode == "pad":
            print("padding mode is pad")
            max_length = seqlen
            group_lst["word_ids"] = _collate_batch_helper(
                group_lst["word_ids"], vocab_dict["PAD"], max_length
            )

        for input_ids in group_lst["word_ids"]:
            if data_args.experiment.startswith("random"):
                hidden_state = model(torch.tensor(input_ids))
            elif data_args.experiment == "gpt2_pre_compress":
                input_ids2 = torch.tensor(input_ids).to(model.device)
                input_embs = model.transformer.wte(input_ids2)  # input_embs
                hidden_state = model.down_proj(input_embs)
                hidden_state = hidden_state * data_args.emb_scale_factor
            elif data_args.experiment == "glove":
                hidden_state = model(torch.tensor(input_ids))
            result_train_lst.append(
                {"input_ids": input_ids, "hidden_states": hidden_state.cpu().tolist()}
            )
    return result_train_lst


def load_glove_model(file_path):
    # Each line of a GloVe text file is "<word> <d floats>".
    print("Loading Glove Model")
    glove_model = {}
    with open(file_path, "r") as f:
        for line in f:
            split_line = line.split()
            word = split_line[0]
            embedding = torch.tensor(np.array(split_line[1:], dtype=np.float64))
            # embedding = np.array(split_line[1:], dtype=np.float64)
            glove_model[word] = embedding
    print(f"{len(glove_model)} words loaded!")
    return glove_model
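
# Illustrative first line of glove.6B.50d.txt (50 floats after the word):
#     the 0.418 0.24968 -0.41242 0.1217 0.34527 ...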


def load_glove(vocab):
    model = torch.nn.Embedding(len(vocab), 50)
    glove_model = load_glove_model("predictability/glove/glove.6B.50d.txt")
    array_lst = []
    count_ = 0
    for word, idx in vocab.items():
        if word in glove_model:
            # cast to float32 so the stack below matches the randn fallbacks
            array_lst.append(glove_model[word].float())
        else:
            count_ += 1
            array_lst.append(torch.randn(50))
    print(f"{count_} out of {len(vocab)} words are missing from GloVe and are randomly initialized. ")
    array_lst = torch.stack(array_lst)
    print(torch.norm(array_lst, dim=-1).mean())
    model.weight.data = array_lst
    return model
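
# Minimal sketch (assumes the GloVe file above exists relative to the
# working directory): build a 50-d embedding table for a toy vocab.
#     vocab = {"START": 0, "END": 1, "UNK": 2, "PAD": 3, "the": 4}
#     emb = load_glove(vocab)  # torch.nn.Embedding(5, 50)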


def get_corpus_rocstory(
    data_args, model, image_size, padding_mode="block", split="train", load_vocab=None
):
    import csv, torch, json
    from spacy.lang.en import English

    if data_args.experiment_mode == "lm":
        if data_args.modality == "roc":
            pass
            # print('loading dataset from ROCStory')
            # nlp = English()
            # tokenizer = nlp.tokenizer
            # sentence_lst = []
            # print(f'loading from {data_args.roc_train}')
            # if split == 'train':
            #     print('loading from the TRAIN set')
            #     path = f'{data_args.roc_train}/roc_train.json'
            # elif split == 'valid':
            #     print('loading from the VALID set')
            #     path = f'{data_args.roc_train}/roc_valid.json'
            # else:
            #     assert False, "invalid split for ROC dataset"
            # with open(path, 'r') as roc_reader:
            #     for row in roc_reader:
            #         sentences = json.loads(row)[0].strip()
            #         word_lst = [x.text for x in tokenizer(sentences)]
            #         sentence_lst.append(word_lst)
            # # with open(data_args.roc_train, 'r') as csvfile:
            # #     roc_reader = csv.reader(csvfile)  # delimiter=' ', quotechar='|')
            # #     for row in roc_reader:
            # #         # tokenize.
            # #         sentences = " ".join(row[2:])
            # #         word_lst = [x.text for x in tokenizer(sentences)]
            # #         sentence_lst.append(word_lst)
            # # sentence_lst = sentence_lst[1:]
            # print(sentence_lst[:2])
        elif data_args.modality == "roc-aug":
            pass
            # print('loading dataset from ROCStory')
            # nlp = English()
            # tokenizer = nlp.tokenizer
            # sentence_lst = []
            # if split == 'train':
            #     print('loading from the TRAIN set')
            #     path_lst = [f'{data_args.roc_train}/roc_train.json']
            #     path_lst.append('diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt')
            #     # path_lst.append('diffusion_lm/improved-diffusion/cache/ar_model_augment_roc.json')
            #     # path_lst.append('diffusion_lm/improved-diffusion/cache/ar_model_augment_roc2.json')
            # elif split == 'valid':
            #     print('loading from the VALID set')
            #     path_lst = [f'{data_args.roc_train}/roc_valid.json']
            # else:
            #     assert False, "invalid split for ROC dataset"
            # print(path_lst)
            # for path in path_lst:
            #     if path.endswith('txt'):
            #         with open(path, 'r') as roc_reader:
            #             for row in roc_reader:
            #                 sentences = row.strip()
            #                 word_lst = [x.text for x in tokenizer(sentences)]
            #                 sentence_lst.append(word_lst)
            #     else:
            #         with open(path, 'r') as roc_reader:
            #             for row in roc_reader:
            #                 sentences = json.loads(row)[0].strip()
            #                 word_lst = [x.text for x in tokenizer(sentences)]
            #                 sentence_lst.append(word_lst)
            # print(sentence_lst[:2], sentence_lst[-2:], 'dataset size=', len(sentence_lst))
        elif data_args.modality == "simple-wiki":
            pass
            # print('loading dataset from simple wikipedia')
            # sentence_lst = []
            # with open(data_args.wiki_train, 'r') as ff:
            #     for row in ff:
            #         word_lst = row.lower().split()
            #         sentence_lst.append(word_lst)
            # print(sentence_lst[:2])
elif data_args.modality == "e2e-tgt": | |
print("loading dataset from simple e2e dataset") | |
sentence_lst = [] | |
nlp = English() | |
tokenizer = nlp.tokenizer | |
if split == "train": | |
print("loading form the TRAIN set") | |
path = ( | |
"/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_train.txt" | |
) | |
# path = f'../{data_args.e2e_train}/src1_train.txt' | |
elif split == "valid": | |
print("loading form the VALID set") | |
path = f"../{data_args.e2e_train}/src1_valid.txt" | |
path = ( | |
"/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_valid.txt" | |
) | |
elif split == "test": | |
print("loading form the TEST set") | |
path = f"../{data_args.e2e_train}/src1_test.txt" | |
path = "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_test.txt" | |
elif split == "debug": | |
print("loading form the DEBUG set") | |
path = data_args.debug_path | |
import json | |
with open(path, "r") as ff: | |
for line in ff: | |
sentence_lst.append(json.loads(line)[0].split(" ")) | |
sentence_lst = sentence_lst + sentence_lst | |
if split in ["train", "valid", "test"]: | |
with open(path, "r") as ff: | |
for row in ff: | |
word_lst = row.split("||")[1] | |
word_lst = [x.text for x in tokenizer(word_lst)] | |
sentence_lst.append(word_lst) | |
print(sentence_lst[:2]) | |
elif data_args.modality == "yelp": | |
print("loading dataset from simple YelpNLG dataset") | |
sentence_lst = [] | |
nlp = English() | |
tokenizer = nlp.tokenizer | |
if split == "train": | |
print("loading form the TRAIN set") | |
path = f"{data_args.yelp_train}/yelpnlg-train.csv" | |
elif split == "valid": | |
print("loading form the VALID set") | |
path = f"{data_args.yelp_train}/yelpnlg-dev.csv" | |
elif split == "test": | |
print("loading form the TEST set") | |
path = f"{data_args.yelp_train}/yelpnlg-test.csv" | |
if split in ["train", "valid", "test"]: | |
with open(path, "r") as csvfile: | |
yelp_reader = csv.reader(csvfile) # delimiter=' ', quotechar='|') | |
for row in yelp_reader: | |
sentences = row[1] | |
word_lst = [x.text for x in tokenizer(sentences)] | |
sentence_lst.append(word_lst) | |
sentence_lst = sentence_lst[1:] | |
print(sentence_lst[:2]) | |
elif data_args.modality == "commonGen": | |
print("loading dataset from simple YelpNLG dataset") | |
sentence_lst = [] | |
nlp = English() | |
tokenizer = nlp.tokenizer | |
if split == "train": | |
print("loading form the TRAIN set") | |
path = f"{data_args.commonGen_train}/commongen.train.jsonl" | |
elif split == "valid": | |
print("loading form the VALID set") | |
path = f"{data_args.commonGen_train}/commongen.dev.jsonl" | |
elif split == "test": | |
print("loading form the TEST set") | |
path = f"{data_args.commonGen_train}/commongen.test.jsonl" | |
if split in ["train", "valid", "test"]: | |
with open(path, "r") as ff: | |
for line in ff: | |
line = json.loads(line) | |
for sentences in line["scene"]: | |
word_lst = [x.text for x in tokenizer(sentences)] | |
sentence_lst.append(word_lst) | |
print(sentence_lst[:2]) | |
elif data_args.modality == "commonGen-aug": | |
print("loading dataset from simple YelpNLG dataset") | |
sentence_lst = [] | |
nlp = English() | |
tokenizer = nlp.tokenizer | |
if split == "train": | |
print("loading form the TRAIN set") | |
path = f"{data_args.commonGen_train}/commongen.train.jsonl" | |
path_lst = [f"{data_args.roc_train}/roc_train.json"] | |
path_lst.append( | |
"diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt" | |
) | |
elif split == "valid": | |
print("loading form the VALID set") | |
path = f"{data_args.commonGen_train}/commongen.dev.jsonl" | |
path_lst = [] | |
elif split == "test": | |
print("loading form the TEST set") | |
path = f"{data_args.commonGen_train}/commongen.test.jsonl" | |
path_lst = [] | |
if split in ["train", "valid", "test"]: | |
with open(path, "r") as ff: | |
for line in ff: | |
line = json.loads(line) | |
for sentences in line["scene"]: | |
word_lst = [x.text for x in tokenizer(sentences)] | |
sentence_lst.append(word_lst) | |
print(sentence_lst[:2]) | |
import itertools | |
for path in path_lst: | |
if path.endswith("txt"): | |
with open(path, "r") as roc_reader: | |
for row in roc_reader: | |
sentences = row.strip() | |
word_lst = [x.text for x in tokenizer(sentences)] | |
spl = [[]] | |
for x, y in itertools.groupby(word_lst, lambda z: z == "."): | |
spl[-1].extend(y) | |
if x: | |
spl.append([]) | |
sentence_lst.extend(spl[:-1]) | |
else: | |
with open(path, "r") as roc_reader: | |
for row in roc_reader: | |
sentences = json.loads(row)[0].strip() | |
word_lst = [x.text for x in tokenizer(sentences)] | |
spl = [[]] | |
for x, y in itertools.groupby(word_lst, lambda z: z == "."): | |
spl[-1].extend(y) | |
if x: | |
spl.append([]) | |
sentence_lst.extend(spl[:-1]) | |
print(sentence_lst[-2:]) | |

        # get tokenizer.
        if load_vocab is None:
            counter = Counter()
            for input_ids in sentence_lst:
                counter.update(input_ids)
if data_args.experiment_mode == "conditional_gen": | |
if data_args.modality == "e2e": | |
print("loading dataset from simple e2e dataset") | |
sentence_lst = [] | |
nlp = English() | |
tokenizer = nlp.tokenizer | |
if split == "train": | |
path = f"{data_args.e2e_train}/src1_train.txt" | |
with open(path, "r") as ff: | |
for row in ff: | |
src_lst, word_lst = row.split("||") | |
word_lst = [x.text for x in tokenizer(word_lst)] | |
src_lst = [x.text for x in tokenizer(src_lst)] | |
sentence_lst.append((src_lst, word_lst)) | |
elif split == "valid": | |
path = f"{data_args.e2e_train}/src1_valid.txt" | |
sentence_lst = read_e2e_files(path, data_args, tokenizer) | |
print(sentence_lst[:2]) | |
# get tokenizer. | |
if load_vocab is None: | |
counter = Counter() | |
for src_ids, input_ids in sentence_lst: | |
counter.update(input_ids) | |
counter.update(src_ids) | |

    if load_vocab is None:
        # reserve special ids, then keep words that appear more than 10 times
        vocab_dict = {"START": 0, "END": 1, "UNK": 2, "PAD": 3}
        for k, v in counter.items():
            if v > 10:
                vocab_dict[k] = len(vocab_dict)
        print(len(counter), len(vocab_dict))
        path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
        print(f"save the vocab to {path_save_vocab}")
        with open(path_save_vocab, "w") as f:
            json.dump(vocab_dict, f)
    else:
        vocab_dict = load_vocab
        path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
        if not os.path.exists(path_save_vocab):
            print(f"save the vocab to {path_save_vocab}")
            if isinstance(vocab_dict, dict):
                with open(path_save_vocab, "w") as f:
                    json.dump(vocab_dict, f)
                assert vocab_dict["START"] == 0
            elif isinstance(vocab_dict, PreTrainedTokenizerFast):
                vocab_dict.save_pretrained(data_args.checkpoint_path)
            else:
                assert False, "invalid type of vocab_dict"

    if model is None and data_args.experiment == "random":
        model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
        print("initializing the random embeddings", model)
        torch.nn.init.normal_(model.weight)
        path_save = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/random_emb.torch"
        print(
            f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
        )
        torch.save(model.state_dict(), path_save)
    # path_save = f'{data_args.checkpoint_path}/random_emb.torch'
    # if not os.path.exists(path_save) and data_args.experiment == 'random':
    #     torch.save(model.state_dict(), path_save)

    # sequence length mirrors the image API: seqlen = image_size ** 2
    if (
        data_args.experiment_mode == "lm"
        and data_args.modality
        in ["roc-aug", "roc", "yelp", "commonGen", "commonGen-aug"]
        and data_args.cache_mode == "no"
    ):
        train_dataset = helper_tokenize_stream(
            sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
        )
        return train_dataset, model
    elif data_args.experiment_mode == "lm":
        result_train_lst = helper_tokenize_encode(
            sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
        )
    elif data_args.experiment_mode == "conditional_gen":
        result_train_lst = helper_tokenize_encode_cond(
            sentence_lst, vocab_dict, model, image_size**2, data_args
        )
    return {"train": result_train_lst}, model


def write_e2e_corr(prompt_lst, file_dict, corr_path):
    print(len(prompt_lst))
    with open(corr_path, "w") as f:
        for x in prompt_lst:
            for line in file_dict[x]:
                print(" ".join(line), file=f)
            print("", file=f)


def write_e2e_src(prompt_lst, corr_path):
    with open(corr_path, "w") as f:
        for x in prompt_lst:
            print(" ".join(x), file=f)
    return


def read_e2e_files(path, args, tokenizer):
    # group references by source: {src tuple: [tgt tuple, ...]}
    file_dict = {}
    with open(path, "r") as f:
        for line in f:
            src_lst, word_lst = line.strip().split("||")
            tgt = tuple([x.text for x in tokenizer(word_lst)])
            src = tuple([x.text for x in tokenizer(src_lst)])
            if src not in file_dict:
                file_dict[src] = []
            file_dict[src].append(tgt)
    temp = "1"
    prompt_text_dict = file_dict
    prompt_text_lst = list(prompt_text_dict.keys())
    gold_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "gold"))
    print("gold dir", gold_dir)
    write_e2e_corr(prompt_text_lst, prompt_text_dict, gold_dir)
    src_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "src"))
    write_e2e_src(prompt_text_lst, src_dir)
    final_lst = [(xx, prompt_text_dict[xx][0]) for xx in prompt_text_lst]
    return final_lst
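
# Input line format is "<src attributes>||<target text>"; an illustrative
# E2E-style line (not taken from the file):
#     name : Alimentum | area : city centre||Alimentum is in the city centre .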


def get_corpus_book(
    data_args,
    tokenizer,
    model,
    image_size,
    padding_mode="block",
    split="train",
):
    max_length = image_size**2
    import os
    # the module-level import is commented out, so import it here
    from datasets import load_dataset

    assert padding_mode == "block"
    raw_datasets = load_dataset("bookcorpus")
    if "validation" not in raw_datasets.keys():
        raw_datasets["validation"] = load_dataset(
            "bookcorpus",
            split="train[:1%]",
        )
        raw_datasets["train"] = load_dataset(
            "bookcorpus",
            split="train[1%:]",
        )
    print(raw_datasets)
    column_names = raw_datasets["train"].column_names

    def tokenize_function(examples):
        output = tokenizer(examples["text"], add_special_tokens=False)
        return output

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=True,
    )
    print(tokenized_datasets)
    block_size = max_length

    # Main data processing function that will concatenate all texts from our
    # dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        return result

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=4,
        load_from_cache_file=True,
        desc=f"Grouping texts in chunks of {block_size}",
    )
    print(lm_datasets)

    if model is None:
        if data_args.training_mode.startswith("e2e"):
            print("since its e2e, initialize a dummy embedding")
            model = torch.nn.Embedding(len(tokenizer), 1)
        else:
            model = torch.nn.Embedding(len(tokenizer), data_args.in_channel)
            print("initializing the random embeddings", model)
            torch.nn.init.normal_(model.weight)
            path_save = f"{data_args.checkpoint_path}/random_emb.torch"
            print(
                f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
            )
            torch.save(model.state_dict(), path_save)

    if split == "train":
        return lm_datasets, model
    else:
        lm_datasets["train"] = lm_datasets["validation"]
        return lm_datasets, model
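
# Minimal sketch (downloads BookCorpus through `datasets`; assumes `args`
# provides preprocessing_num_workers, training_mode, in_channel, checkpoint_path):
#     tok = AutoTokenizer.from_pretrained("bert-base-uncased")
#     lm_data, emb = get_corpus_book(args, tok, model=None, image_size=8)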


class TextDataset(Dataset):
    def __init__(
        self,
        text_datasets,
        resolution,
        data_args,
        model_arch="conv-unet",
        classes=None,
        shard=0,
        num_shards=1,
        eigen_transform=None,
        mapping_func=None,
        model_emb=None,
    ):
        super().__init__()
        self.resolution = resolution
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets["train"])
        self.model_arch = model_arch
        self.data_args = data_args
        print(self.resolution)
        self.eigen_transform = eigen_transform
        self.mapping_func = mapping_func
        self.model_emb = model_emb
        # self.local_images = image_paths[shard:][::num_shards]
        # self.local_classes = None if classes is None else classes[shard:][::num_shards]

    def __len__(self):
        return self.length
    def __getitem__(self, idx):
        # The conv-unet and 1d-unet layouts are disabled here; only the
        # transformer layout in the final branch is active.
        if self.model_arch == "conv-unet":
            pass
            # arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
            #                dtype=np.float32).reshape(self.resolution, self.resolution, -1)
            # # print(self.eigen_transform.shape)
            # if self.eigen_transform is not None:
            #     old_shape = arr.shape
            #     arr = arr.reshape(1, -1) - self.eigen_transform['mean']
            #     arr = arr @ self.eigen_transform['map']
            #     arr = arr.reshape(old_shape)
            # if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
            #     arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
            # out_dict = {}
            # out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
            # # if self.local_classes is not None:
            # #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
            # return np.transpose(arr, [2, 0, 1]), out_dict
        elif self.model_arch == "1d-unet":
            pass
            # arr = np.array(self.text_datasets['train'][idx]['hidden_states'],
            #                dtype=np.float32)  # seqlen, dim
            # if self.eigen_transform is not None:
            #     old_shape = arr.shape
            #     arr = arr.reshape(1, -1) - self.eigen_transform['mean']
            #     arr = arr @ self.eigen_transform['map']
            #     arr = arr.reshape(old_shape)
            # if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0:
            #     arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype)
            # arr = np.transpose(arr, [1, 0])
            # out_dict = {}
            # out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids'])
            # # out_dict['mapping_func'] = self.mapping_func
            # # if self.local_classes is not None:
            # #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
            # return arr, out_dict
        else:
            arr = np.array(
                self.text_datasets["train"][idx]["hidden_states"], dtype=np.float32
            )
            if self.eigen_transform is not None:
                old_shape = arr.shape
                # arr = arr.reshape(1, -1) @ self.eigen_transform
                arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
                arr = arr @ self.eigen_transform["map"]
                arr = arr.reshape(old_shape)
            if (
                hasattr(self.data_args, "noise_level")
                and self.data_args.noise_level > 0
            ):
                arr = arr + self.data_args.noise_level * np.random.randn(
                    *arr.shape
                ).astype(arr.dtype)
            out_dict = {}
            out_dict["input_ids"] = np.array(
                self.text_datasets["train"][idx]["input_ids"]
            )
            # out_dict['mapping_func'] = self.mapping_func
            if self.data_args.experiment_mode == "conditional_gen":
                out_dict["src_ids"] = np.array(
                    self.text_datasets["train"][idx]["src_ids"]
                )
                out_dict["src_mask"] = np.array(
                    self.text_datasets["train"][idx]["src_mask"]
                )
            # if self.local_classes is not None:
            #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
            return arr, out_dict
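
# Minimal sketch: wrap a pre-embedded split and read one item.
#     ds = TextDataset({"train": result_train_lst}, resolution=8,
#                      data_args=args, model_arch="transformer")
#     arr, cond = ds[0]  # arr: (seqlen, dim) float32; cond["input_ids"]: (seqlen,)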


class TextDataset_NoCache(Dataset):
    def __init__(
        self,
        text_datasets,
        resolution,
        data_args,
        model_arch="conv-unet",
        classes=None,
        shard=0,
        num_shards=1,
        eigen_transform=None,
        mapping_func=None,
        model_emb=None,
    ):
        super().__init__()
        self.resolution = resolution
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets["train"])
        self.model_arch = model_arch
        self.data_args = data_args
        print(self.resolution)
        self.eigen_transform = eigen_transform
        self.mapping_func = mapping_func
        self.model_emb = model_emb
        # self.local_images = image_paths[shard:][::num_shards]
        # self.local_classes = None if classes is None else classes[shard:][::num_shards]

    def __len__(self):
        return self.length
    def __getitem__(self, idx):
        # Embed on the fly instead of reading cached hidden states.
        with torch.no_grad():
            input_ids = self.text_datasets["train"][idx]["input_ids"]
            model = self.model_emb
            if self.data_args.experiment.startswith("random"):
                hidden_state = model(torch.tensor(input_ids))
            elif self.data_args.experiment == "gpt2_pre_compress":
                input_ids2 = torch.tensor(input_ids).to(model.device)
                input_embs = model.transformer.wte(input_ids2)  # input_embs
                hidden_state = model.down_proj(input_embs)
                hidden_state = hidden_state * self.data_args.emb_scale_factor

            if self.model_arch == "conv-unet":
                arr = np.array(hidden_state, dtype=np.float32).reshape(
                    self.resolution, self.resolution, -1
                )
                # print(self.eigen_transform.shape)
                if self.eigen_transform is not None:
                    old_shape = arr.shape
                    arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
                    arr = arr @ self.eigen_transform["map"]
                    arr = arr.reshape(old_shape)
                if (
                    hasattr(self.data_args, "noise_level")
                    and self.data_args.noise_level > 0
                ):
                    arr = arr + self.data_args.noise_level * np.random.randn(
                        *arr.shape
                    ).astype(arr.dtype)
                out_dict = {}
                out_dict["input_ids"] = np.array(
                    self.text_datasets["train"][idx]["input_ids"]
                )
                # if self.local_classes is not None:
                #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
                return np.transpose(arr, [2, 0, 1]), out_dict
            elif self.model_arch == "1d-unet":
                arr = np.array(hidden_state, dtype=np.float32)  # seqlen, dim
                if self.eigen_transform is not None:
                    old_shape = arr.shape
                    arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
                    arr = arr @ self.eigen_transform["map"]
                    arr = arr.reshape(old_shape)
                if (
                    hasattr(self.data_args, "noise_level")
                    and self.data_args.noise_level > 0
                ):
                    arr = arr + self.data_args.noise_level * np.random.randn(
                        *arr.shape
                    ).astype(arr.dtype)
                arr = np.transpose(arr, [1, 0])
                out_dict = {}
                out_dict["input_ids"] = np.array(
                    self.text_datasets["train"][idx]["input_ids"]
                )
                # out_dict['mapping_func'] = self.mapping_func
                # if self.local_classes is not None:
                #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
                return arr, out_dict
            else:
                arr = np.array(hidden_state, dtype=np.float32)
                if self.eigen_transform is not None:
                    old_shape = arr.shape
                    # arr = arr.reshape(1, -1) @ self.eigen_transform
                    arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
                    arr = arr @ self.eigen_transform["map"]
                    arr = arr.reshape(old_shape)
                if (
                    hasattr(self.data_args, "noise_level")
                    and self.data_args.noise_level > 0
                ):
                    arr = arr + self.data_args.noise_level * np.random.randn(
                        *arr.shape
                    ).astype(arr.dtype)
                out_dict = {}
                out_dict["input_ids"] = np.array(
                    self.text_datasets["train"][idx]["input_ids"]
                )
                # out_dict['mapping_func'] = self.mapping_func
                if self.data_args.experiment_mode == "conditional_gen":
                    out_dict["src_ids"] = np.array(
                        self.text_datasets["train"][idx]["src_ids"]
                    )
                    out_dict["src_mask"] = np.array(
                        self.text_datasets["train"][idx]["src_mask"]
                    )
                # if self.local_classes is not None:
                #     out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
                return arr, out_dict


def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
    # both buffers start filled with pad_token_id; real positions are then
    # overwritten (ids with the sequence, mask with 1s), so padded mask slots
    # keep the value pad_token_id rather than 0
    result = torch.full(
        [len(examples), max_length], pad_token_id, dtype=torch.int64
    ).tolist()
    mask_ = torch.full(
        [len(examples), max_length], pad_token_id, dtype=torch.int64
    ).tolist()
    for i, example in enumerate(examples):
        curr_len = min(len(example), max_length)
        result[i][:curr_len] = example[:curr_len]
        mask_[i][:curr_len] = [1] * curr_len
    if return_mask:
        return result, mask_
    return result
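
# Illustrative call with PAD=3 and max_length=4:
#     ids, mask = _collate_batch_helper([[0, 5, 1], [0, 7, 8, 9, 1]], 3, 4,
#                                       return_mask=True)
#     ids  == [[0, 5, 1, 3], [0, 7, 8, 9]]   # second sequence is truncated
#     mask == [[1, 1, 1, 3], [1, 1, 1, 1]]   # padded slots keep pad_token_id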


def _torch_collate_batch(examples, pad_token_id, max_length):
    """Collate `examples` into a left-aligned batch, padding with `pad_token_id`."""
    import numpy as np
    import torch

    # Tensorize if necessary.
    if isinstance(examples[0], (list, tuple, np.ndarray)):
        examples = [torch.tensor(e, dtype=torch.long) for e in examples]
    # Create the full tensor and fill it with our data; each example is
    # assumed to be at most max_length long.
    result = examples[0].new_full([len(examples), max_length], pad_token_id)
    for i, example in enumerate(examples):
        result[i, : example.shape[0]] = example
    return result