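# Web-page residue from the file viewer (author, commit message, commit hash,
# viewer labels, file size), commented out so the module parses:
# ndhieunguyen's picture
# feat: remove mpi4py
# 77180e4
# raw
# history blame
# 45.3 kB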
# from PIL import Image
# import blobfile as bf
# from mpi4py import MPI
import numpy as np
from torch.utils.data import DataLoader, Dataset
from transformers import (
AutoModelForCausalLM,
AutoConfig,
AutoTokenizer,
default_data_collator,
PreTrainedTokenizerFast,
PreTrainedTokenizer,
)
# from datasets import load_dataset
import sys, os
import torch
# sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling'))
# from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise
from collections import Counter, defaultdict
from functools import partial
from itertools import chain
def load_data_text(
    *,
    data_dir,
    batch_size,
    image_size,
    class_cond=False,
    deterministic=False,
    data_args=None,
    task_mode="roc",
    model=None,
    padding_mode="block",
    split="train",
    load_vocab=None,
):
    """Yield batches of embedded text forever, for diffusion training.

    Builds the corpus for ``task_mode`` (only ``"e2e-tgt"`` has a loader in
    this build). It is wrapped in ``TextDataset_NoCache`` when
    ``data_args.modality`` is a streaming modality and
    ``data_args.cache_mode == "no"``, otherwise in ``TextDataset``. The
    function then cycles over a ``DataLoader`` indefinitely.

    :param data_dir: unused; kept for interface compatibility.
    :param batch_size: number of examples per yielded batch.
    :param image_size: sequences are ``image_size ** 2`` tokens long.
    :param class_cond: unused; kept for interface compatibility.
    :param deterministic: if True, iterate in a fixed (unshuffled) order.
    :param data_args: namespace read for ``modality``, ``cache_mode`` and
        ``model_arch`` here, and by the corpus loader.
    :param task_mode: which corpus to load.
    :param model: optional embedding model; the loader may create one.
    :param padding_mode: ``"block"`` or ``"pad"``, forwarded to the loader.
    :param split: dataset split forwarded to the loader.
    :param load_vocab: optional pre-built vocabulary forwarded to the loader.
    :raises NotImplementedError: if ``task_mode`` has no loader in this build.
    """
    print("hello loading text data. ")
    if task_mode != "e2e-tgt":
        # The roc / roc-aug / simple-wiki / yelp / commonGen / e2e / book
        # loaders are disabled in this build. Previously these modes fell
        # through and crashed later with an UnboundLocalError.
        raise NotImplementedError(
            f"task_mode={task_mode!r} is not supported by this build of "
            "load_data_text; only 'e2e-tgt' is available."
        )

    print("hello loading e2e-tgt. ")
    training_data, model = get_corpus_rocstory(
        data_args,
        model,
        image_size,
        padding_mode=padding_mode,
        split=split,
        load_vocab=load_vocab,
    )

    streaming_modalities = ["roc-aug", "roc", "book", "yelp", "commonGen", "commonGen-aug"]
    if data_args.modality in streaming_modalities and data_args.cache_mode == "no":
        # Embeddings are computed per item rather than precomputed.
        dataset = TextDataset_NoCache(
            training_data,
            image_size,
            data_args,
            model_arch=data_args.model_arch,
            model_emb=model,
        )
    else:
        dataset = TextDataset(
            training_data,
            image_size,
            data_args,
            model_arch=data_args.model_arch,
        )

    data_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        drop_last=True,
        shuffle=not deterministic,
        num_workers=1,
    )
    while True:
        yield from data_loader
def helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, seqlen, data_args):
    """Tokenize (source, target) word-list pairs and embed the targets.

    Targets are mapped through ``vocab_dict``, with unknown words becoming
    "UNK". They are wrapped in ids 0 and 1, then padded or truncated to
    ``seqlen``. Sources are mapped the same way, without the 0/1 wrapper, and
    padded to the longest source (capped at ``seqlen``). A mask for them comes
    from ``_collate_batch_helper``.

    :param sentence_lst: iterable of ``(src_words, tgt_words)`` pairs.
    :param vocab_dict: word -> id dict that contains "UNK" and "PAD".
    :param model: embedding callable used according to ``data_args.experiment``.
    :param seqlen: target length after padding.
    :param data_args: read for ``experiment`` and, for
        ``"gpt2_pre_compress"``, ``emb_scale_factor``.
    :returns: list of dicts with ``input_ids``, ``hidden_states``,
        ``src_ids`` and ``src_mask``.
    :raises ValueError: if ``data_args.experiment`` has no embedding path.
    """
    result_train_lst = []
    group_lst = defaultdict(list)
    with torch.no_grad():
        for src_words, tgt_words in sentence_lst:
            tokenized_tgt = [vocab_dict.get(x, vocab_dict["UNK"]) for x in tgt_words]
            tokenized_src = [vocab_dict.get(x, vocab_dict["UNK"]) for x in src_words]
            group_lst["word_ids"].append([0] + tokenized_tgt + [1])
            group_lst["src_ids"].append(tokenized_src)
        print(group_lst["word_ids"][:2])
        print("padding mode is pad")
        group_lst["word_ids"] = _collate_batch_helper(
            group_lst["word_ids"], vocab_dict["PAD"], seqlen
        )
        # default=0 keeps an empty corpus from crashing max().
        max_src_length = max((len(xx) for xx in group_lst["src_ids"]), default=0)
        print(max_src_length, seqlen)
        max_src_length = min(seqlen, max_src_length)
        group_lst["src_ids"], group_lst["src_mask"] = _collate_batch_helper(
            group_lst["src_ids"], vocab_dict["PAD"], max_src_length, return_mask=True
        )
        for input_ids, src_ids, src_mask in zip(
            group_lst["word_ids"], group_lst["src_ids"], group_lst["src_mask"]
        ):
            experiment = data_args.experiment
            if experiment.startswith("random") or experiment == "glove":
                # Same embedding path helper_tokenize_encode uses for these modes.
                hidden_state = model(torch.tensor(input_ids))
            elif experiment == "gpt2_pre_compress":
                input_ids2 = torch.tensor(input_ids).to(model.device)
                input_embs = model.transformer.wte(input_ids2)  # input_embs
                hidden_state = model.down_proj(input_embs)
                hidden_state = hidden_state * data_args.emb_scale_factor
            else:
                raise ValueError(f"unsupported data_args.experiment: {experiment!r}")
            result_train_lst.append(
                {
                    "input_ids": input_ids,
                    "hidden_states": hidden_state.cpu().tolist(),
                    "src_ids": src_ids,
                    "src_mask": src_mask,
                }
            )
    return result_train_lst
def helper_tokenize_stream(
    sentence_lst,
    vocab_dict,
    model,
    seqlen,
    data_args,
    padding_mode,
):
    """Tokenize pre-split sentences with HuggingFace ``datasets`` and chunk or pad them.

    :param sentence_lst: list of sentences, each a list of word strings.
    :param vocab_dict: word -> id dict (with "UNK", and "PAD" for padding
        mode), or a HuggingFace tokenizer, fast or slow.
    :param model: unused here; kept for signature parity with the other helpers.
    :param seqlen: block size in ``"block"`` mode, pad length otherwise.
    :param data_args: read for ``preprocessing_num_workers`` and
        ``overwrite_cache`` in ``"block"`` mode.
    :param padding_mode: ``"block"`` concatenates everything and splits it into
        ``seqlen``-sized chunks, adding ``labels``. Any other value pads or
        truncates each example to ``seqlen``.
    :returns: a ``datasets.DatasetDict`` whose ``"train"`` split holds the result.
    :raises TypeError: if ``vocab_dict`` is neither a dict nor a tokenizer.
    """
    import psutil

    def log_ram():
        # Process.memory_info is expressed in bytes, so convert to megabytes
        print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB")

    # Validate up front. Previously an unsupported type (including a slow
    # PreTrainedTokenizer) surfaced as an UnboundLocalError in a map worker.
    if not isinstance(vocab_dict, (dict, PreTrainedTokenizerFast, PreTrainedTokenizer)):
        raise TypeError(f"unsupported vocab_dict type: {type(vocab_dict).__name__}")

    log_ram()
    from datasets import Dataset as Dataset2

    raw_datasets = Dataset2.from_dict({"text": sentence_lst})
    print(raw_datasets)
    log_ram()

    def tokenize_function(examples):
        if isinstance(vocab_dict, dict):
            input_ids = [
                [0] + [vocab_dict.get(x, vocab_dict["UNK"]) for x in seq] + [1]
                for seq in examples["text"]
            ]
        else:
            # HuggingFace tokenizers expect raw strings, so re-join the words.
            texts = [" ".join(seq) for seq in examples["text"]]
            input_ids = vocab_dict(texts, add_special_tokens=True)["input_ids"]
        # clm input could be much much longer than block_size
        return {"input_ids": input_ids}

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=4,
        remove_columns=["text"],
        load_from_cache_file=True,
        desc="Running tokenizer on dataset",
    )
    print(tokenized_datasets)
    log_ram()

    if padding_mode == "block":
        block_size = seqlen

        def group_texts(examples):
            # Concatenate every column, drop the ragged tail, split into blocks.
            concatenated_examples = {
                k: list(chain(*examples[k])) for k in examples.keys()
            }
            total_length = len(concatenated_examples[list(examples.keys())[0]])
            if total_length >= block_size:
                total_length = (total_length // block_size) * block_size
            result = {
                k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
                for k, t in concatenated_examples.items()
            }
            result["labels"] = result["input_ids"].copy()
            return result

        lm_datasets = tokenized_datasets.map(
            group_texts,
            batched=True,
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=not data_args.overwrite_cache,
            desc=f"Grouping texts in chunks of {block_size}",
        )
    else:

        def pad_function(group_lst):
            if isinstance(vocab_dict, dict):
                pad_id = vocab_dict["PAD"]
            else:
                pad_id = vocab_dict.pad_token_id
            group_lst["input_ids"] = _collate_batch_helper(
                group_lst["input_ids"], pad_id, seqlen
            )
            return group_lst

        log_ram()
        lm_datasets = tokenized_datasets.map(
            pad_function,
            batched=True,
            num_proc=1,
            desc="padding",
        )

    print(lm_datasets, "padded dataset")
    log_ram()

    import datasets

    raw_datasets = datasets.DatasetDict()
    raw_datasets["train"] = lm_datasets
    log_ram()
    return raw_datasets
def helper_tokenize_encode(
    sentence_lst,
    vocab_dict,
    model,
    seqlen,
    data_args,
    padding_mode,
):
    """Map word lists to ids, chunk or pad them, and embed each sequence.

    Each sentence is mapped through ``vocab_dict`` (unknown words become
    "UNK") and wrapped in ids 0 and 1.

    - ``"block"``: all sentences are concatenated and cut into ``seqlen``-sized
      chunks, and the trailing partial chunk is dropped.
    - ``"pad"``: each sentence is padded or truncated to ``seqlen`` with "PAD".
    - Any other value leaves the sequences unchanged.

    :param sentence_lst: iterable of word lists.
    :param vocab_dict: word -> id dict containing "UNK" (and "PAD" for ``"pad"``).
    :param model: embedding callable used according to ``data_args.experiment``.
    :param seqlen: chunk or pad length.
    :param data_args: read for ``experiment`` and, for
        ``"gpt2_pre_compress"``, ``emb_scale_factor``.
    :param padding_mode: ``"block"``, ``"pad"``, or anything else (no-op).
    :returns: list of ``{"input_ids": [...], "hidden_states": [[...], ...]}``.
    :raises ValueError: if ``data_args.experiment`` has no embedding path.
    """
    result_train_lst = []
    with torch.no_grad():
        word_ids = []
        for words in sentence_lst:
            tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in words]
            word_ids.append([0] + tokenized_ + [1])
        print(word_ids[:2])

        if padding_mode == "block":
            print("padding mode is block")
            # chain.from_iterable is linear; sum(lists, []) was quadratic.
            concatenated = list(chain.from_iterable(word_ids))
            total_length = (len(concatenated) // seqlen) * seqlen
            # Split by chunks of max_len.
            word_ids = [
                concatenated[i : i + seqlen] for i in range(0, total_length, seqlen)
            ]
        elif padding_mode == "pad":
            print("padding mode is pad")
            word_ids = _collate_batch_helper(word_ids, vocab_dict["PAD"], seqlen)

        for input_ids in word_ids:
            experiment = data_args.experiment
            if experiment.startswith("random") or experiment == "glove":
                hidden_state = model(torch.tensor(input_ids))
            elif experiment == "gpt2_pre_compress":
                input_ids2 = torch.tensor(input_ids).to(model.device)
                input_embs = model.transformer.wte(input_ids2)  # input_embs
                hidden_state = model.down_proj(input_embs)
                hidden_state = hidden_state * data_args.emb_scale_factor
            else:
                raise ValueError(f"unsupported data_args.experiment: {experiment!r}")
            result_train_lst.append(
                {"input_ids": input_ids, "hidden_states": hidden_state.cpu().tolist()}
            )
    return result_train_lst
def load_glove_model(File):
    """Read a GloVe text file into a ``{word: float64 tensor}`` dict.

    Each non-blank line has the form ``word v1 v2 ... vd``. The file is read
    as UTF-8, which GloVe distributions use, instead of the locale default.

    :param File: path to the GloVe ``.txt`` file.
    :returns: dict mapping each word to a 1-D ``torch.float64`` tensor.
    """
    print("Loading Glove Model")
    glove_model = {}
    with open(File, "r", encoding="utf-8") as f:
        for line in f:
            split_line = line.split()
            if not split_line:
                continue  # blank lines previously raised IndexError
            word, *values = split_line
            glove_model[word] = torch.tensor(np.array(values, dtype=np.float64))
    print(f"{len(glove_model)} words loaded!")
    return glove_model
def load_glove(vocab, glove_path="predictability/glove/glove.6B.50d.txt", dim=50):
    """Build a ``torch.nn.Embedding`` initialised from GloVe vectors.

    Row ``idx`` holds the GloVe vector for the word mapped to ``idx`` in
    ``vocab``. Words missing from GloVe get a ``torch.randn(dim)`` vector.
    Rows are placed by index rather than by dict iteration order, and the
    weight keeps the Embedding's own dtype instead of becoming float64.

    :param vocab: word -> index dict whose indices are ``0..len(vocab)-1``.
    :param glove_path: GloVe text file to read (default: 6B, 50-d).
    :param dim: embedding width; must match the file's vector length.
    :returns: the initialised ``torch.nn.Embedding``.
    """
    model = torch.nn.Embedding(len(vocab), dim)
    glove_model = load_glove_model(glove_path)
    weight = torch.empty(len(vocab), dim, dtype=model.weight.dtype)
    count_ = 0
    # Visit words in index order, so the sequence of torch.randn draws matches
    # the previous loop for insertion-ordered vocabularies.
    for word, idx in sorted(vocab.items(), key=lambda item: item[1]):
        if word in glove_model:
            weight[idx] = glove_model[word]
        else:
            count_ += 1
            weight[idx] = torch.randn(dim)
    print(f"{count_} out of {len(vocab)} is initialized. ")
    print(torch.norm(weight, dim=-1).mean())
    model.weight.data = weight
    return model
def get_corpus_rocstory(
    data_args, model, image_size, padding_mode="block", split="train", load_vocab=None
):
    """Load a tokenized corpus, build or reuse a vocabulary, and embed it.

    Despite its name, this function serves several modalities, selected by
    ``data_args.experiment_mode`` ("lm" or "conditional_gen") and
    ``data_args.modality``. Sentences are tokenized with spaCy's English
    tokenizer.

    - Vocabulary: unless ``load_vocab`` is given, it holds START/END/UNK/PAD
      (ids 0-3) plus every word seen more than 10 times.
    - Embedding: when ``model`` is None and ``data_args.experiment == "random"``,
      a normally-initialised ``torch.nn.Embedding(len(vocab),
      data_args.in_channel)`` is created and saved.

    :param data_args: namespace of data and experiment options, read throughout.
    :param model: embedding model, or None (see above).
    :param image_size: sequences are ``image_size ** 2`` tokens long.
    :param padding_mode: "block" or "pad", forwarded to the tokenize helpers.
    :param split: "train", "valid" or "test" ("debug" is also accepted for e2e-tgt).
    :param load_vocab: an existing vocabulary (dict or
        ``PreTrainedTokenizerFast``), or None to build one.
    :returns: ``(dataset, model)``. ``dataset`` is whatever
        ``helper_tokenize_stream`` returns on the lm + cache_mode == "no" path;
        otherwise it is ``{"train": list_of_example_dicts}``.

    NOTE(review): the vocab and random-embedding files are written to
    hard-coded absolute paths under /data0/gonghaisong/..., and the e2e-tgt
    train/valid/test inputs are hard-coded as well. These only resolve on the
    original author's machine. TODO: derive them from data_args.
    """
    import csv, torch, json
    from spacy.lang.en import English

    if data_args.experiment_mode == "lm":
        if data_args.modality == "roc":
            pass
            # NOTE(review): the ROCStories loader is commented out in this
            # build, so `sentence_lst` is never assigned for "roc" and the
            # vocabulary/tokenize steps below fail with UnboundLocalError.
        if data_args.modality == "roc-aug":
            pass
            # NOTE(review): likewise disabled (ROCStories + GPT-J augmented
            # stories); `sentence_lst` stays unassigned.
        elif data_args.modality == "simple-wiki":
            pass
            # NOTE(review): likewise disabled (simple-Wikipedia loader).
        elif data_args.modality == "e2e-tgt":
            print("loading dataset from simple e2e dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                print("loading form the TRAIN set")
                path = (
                    "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_train.txt"
                )
                # path = f'../{data_args.e2e_train}/src1_train.txt'
            elif split == "valid":
                print("loading form the VALID set")
                path = f"../{data_args.e2e_train}/src1_valid.txt"
                # The data_args-based path above is immediately overridden.
                path = (
                    "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_valid.txt"
                )
            elif split == "test":
                print("loading form the TEST set")
                path = f"../{data_args.e2e_train}/src1_test.txt"
                # Likewise overridden by the hard-coded path.
                path = "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_test.txt"
            elif split == "debug":
                print("loading form the DEBUG set")
                path = data_args.debug_path
                import json

                # Each debug line is JSON whose element [0] is a
                # space-separated sentence. The list is then duplicated.
                with open(path, "r") as ff:
                    for line in ff:
                        sentence_lst.append(json.loads(line)[0].split(" "))
                sentence_lst = sentence_lst + sentence_lst
            if split in ["train", "valid", "test"]:
                # Lines are "||"-separated; only field [1] is tokenized.
                with open(path, "r") as ff:
                    for row in ff:
                        word_lst = row.split("||")[1]
                        word_lst = [x.text for x in tokenizer(word_lst)]
                        sentence_lst.append(word_lst)
            print(sentence_lst[:2])
        elif data_args.modality == "yelp":
            print("loading dataset from simple YelpNLG dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                print("loading form the TRAIN set")
                path = f"{data_args.yelp_train}/yelpnlg-train.csv"
            elif split == "valid":
                print("loading form the VALID set")
                path = f"{data_args.yelp_train}/yelpnlg-dev.csv"
            elif split == "test":
                print("loading form the TEST set")
                path = f"{data_args.yelp_train}/yelpnlg-test.csv"
            if split in ["train", "valid", "test"]:
                # Column 1 of each CSV row is tokenized; the first row
                # (presumably a header — confirm) is dropped afterwards.
                with open(path, "r") as csvfile:
                    yelp_reader = csv.reader(csvfile)  # delimiter=' ', quotechar='|')
                    for row in yelp_reader:
                        sentences = row[1]
                        word_lst = [x.text for x in tokenizer(sentences)]
                        sentence_lst.append(word_lst)
                sentence_lst = sentence_lst[1:]
            print(sentence_lst[:2])
        elif data_args.modality == "commonGen":
            # NOTE(review): the log message below says YelpNLG but this
            # branch loads CommonGen (copy-paste).
            print("loading dataset from simple YelpNLG dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                print("loading form the TRAIN set")
                path = f"{data_args.commonGen_train}/commongen.train.jsonl"
            elif split == "valid":
                print("loading form the VALID set")
                path = f"{data_args.commonGen_train}/commongen.dev.jsonl"
            elif split == "test":
                print("loading form the TEST set")
                path = f"{data_args.commonGen_train}/commongen.test.jsonl"
            if split in ["train", "valid", "test"]:
                # JSONL: every string in each record's "scene" list becomes a sentence.
                with open(path, "r") as ff:
                    for line in ff:
                        line = json.loads(line)
                        for sentences in line["scene"]:
                            word_lst = [x.text for x in tokenizer(sentences)]
                            sentence_lst.append(word_lst)
            print(sentence_lst[:2])
        elif data_args.modality == "commonGen-aug":
            # NOTE(review): the log message below says YelpNLG (copy-paste).
            print("loading dataset from simple YelpNLG dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                print("loading form the TRAIN set")
                path = f"{data_args.commonGen_train}/commongen.train.jsonl"
                # Extra ROCStories-style sources, split into sentences below.
                path_lst = [f"{data_args.roc_train}/roc_train.json"]
                path_lst.append(
                    "diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt"
                )
            elif split == "valid":
                print("loading form the VALID set")
                path = f"{data_args.commonGen_train}/commongen.dev.jsonl"
                path_lst = []
            elif split == "test":
                print("loading form the TEST set")
                path = f"{data_args.commonGen_train}/commongen.test.jsonl"
                path_lst = []
            if split in ["train", "valid", "test"]:
                with open(path, "r") as ff:
                    for line in ff:
                        line = json.loads(line)
                        for sentences in line["scene"]:
                            word_lst = [x.text for x in tokenizer(sentences)]
                            sentence_lst.append(word_lst)
            print(sentence_lst[:2])
            import itertools

            for path in path_lst:
                if path.endswith("txt"):
                    # Plain-text file: one story per line.
                    with open(path, "r") as roc_reader:
                        for row in roc_reader:
                            sentences = row.strip()
                            word_lst = [x.text for x in tokenizer(sentences)]
                            # Split tokens into sentences ending at "." tokens.
                            # The trailing segment (empty, or without a final
                            # ".") is dropped by spl[:-1].
                            spl = [[]]
                            for x, y in itertools.groupby(word_lst, lambda z: z == "."):
                                spl[-1].extend(y)
                                if x:
                                    spl.append([])
                            sentence_lst.extend(spl[:-1])
                else:
                    # JSON-lines file: element [0] of each record is the story.
                    with open(path, "r") as roc_reader:
                        for row in roc_reader:
                            sentences = json.loads(row)[0].strip()
                            word_lst = [x.text for x in tokenizer(sentences)]
                            spl = [[]]
                            for x, y in itertools.groupby(word_lst, lambda z: z == "."):
                                spl[-1].extend(y)
                                if x:
                                    spl.append([])
                            sentence_lst.extend(spl[:-1])
            print(sentence_lst[-2:])
        # get tokenizer.
        # Word frequencies, used below to build the vocabulary.
        if load_vocab is None:
            counter = Counter()
            for input_ids in sentence_lst:
                counter.update(input_ids)
    if data_args.experiment_mode == "conditional_gen":
        if data_args.modality == "e2e":
            print("loading dataset from simple e2e dataset")
            sentence_lst = []
            nlp = English()
            tokenizer = nlp.tokenizer
            if split == "train":
                path = f"{data_args.e2e_train}/src1_train.txt"
                # Each line is "<source>||<target>"; both sides are tokenized.
                with open(path, "r") as ff:
                    for row in ff:
                        src_lst, word_lst = row.split("||")
                        word_lst = [x.text for x in tokenizer(word_lst)]
                        src_lst = [x.text for x in tokenizer(src_lst)]
                        sentence_lst.append((src_lst, word_lst))
            elif split == "valid":
                path = f"{data_args.e2e_train}/src1_valid.txt"
                sentence_lst = read_e2e_files(path, data_args, tokenizer)
            print(sentence_lst[:2])
        # get tokenizer.
        # Source and target words both count toward the vocabulary.
        if load_vocab is None:
            counter = Counter()
            for src_ids, input_ids in sentence_lst:
                counter.update(input_ids)
                counter.update(src_ids)
    if load_vocab is None:
        # Special tokens take ids 0-3; words seen more than 10 times follow
        # in counter iteration order.
        vocab_dict = {"START": 0, "END": 1, "UNK": 2, "PAD": 3}
        for k, v in counter.items():
            if v > 10:
                vocab_dict[k] = len(vocab_dict)
        print(len(counter), len(vocab_dict))
        # NOTE(review): hard-coded, machine-specific output path.
        path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
        print(f"save the vocab to {path_save_vocab}")
        with open(path_save_vocab, "w") as f:
            json.dump(vocab_dict, f)
    else:
        vocab_dict = load_vocab
        path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json"
        # Persist a provided vocabulary only if none has been saved yet.
        if not os.path.exists(path_save_vocab):
            print(f"save the vocab to {path_save_vocab}")
            if isinstance(vocab_dict, dict):
                with open(path_save_vocab, "w") as f:
                    json.dump(vocab_dict, f)
                assert vocab_dict["START"] == 0
            elif isinstance(vocab_dict, PreTrainedTokenizerFast):
                # NOTE(review): tokenizers are saved to checkpoint_path, not
                # to path_save_vocab.
                vocab_dict.save_pretrained(data_args.checkpoint_path)
            else:
                assert False, "invalid type of vocab_dict"
    if model is None and data_args.experiment == "random":
        model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel)
        print("initializing the random embeddings", model)
        torch.nn.init.normal_(model.weight)
        # NOTE(review): the log message reports checkpoint_path, but the
        # weights are saved to this hard-coded path instead.
        path_save = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/random_emb.torch"
        print(
            f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
        )
        torch.save(model.state_dict(), path_save)
    # path_save = f'{data_args.checkpoint_path}/random_emb.torch'
    # if not os.path.exists(path_save) and data_args.experiment == 'random':
    #     torch.save(model.state_dict(), path_save)
    if (
        data_args.experiment_mode == "lm"
        and data_args.modality
        in ["roc-aug", "roc", "yelp", "commonGen", "commonGen-aug"]
        and data_args.cache_mode == "no"
    ):
        # Streaming path: ids only, embeddings are computed later per item.
        train_dataset = helper_tokenize_stream(
            sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
        )
        return train_dataset, model
    elif data_args.experiment_mode == "lm":
        result_train_lst = helper_tokenize_encode(
            sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode
        )
    elif data_args.experiment_mode == "conditional_gen":
        result_train_lst = helper_tokenize_encode_cond(
            sentence_lst, vocab_dict, model, image_size**2, data_args
        )
    # NOTE(review): any other experiment_mode leaves result_train_lst unbound.
    return {"train": result_train_lst}, model
def write_e2e_corr(prompt_lst, file_dict, corr_path):
    """Write the reference targets for each prompt to ``corr_path``.

    For every prompt, in order, each of its token sequences in ``file_dict``
    is written as one space-joined line. The prompt's group is followed by an
    empty line. The number of prompts is printed first.
    """
    print(len(prompt_lst))
    with open(corr_path, "w") as out:
        for prompt in prompt_lst:
            out.writelines(" ".join(ref) + "\n" for ref in file_dict[prompt])
            out.write("\n")
def write_e2e_src(prompt_lst, corr_path):
    """Write each source prompt to ``corr_path`` as one space-joined line."""
    with open(corr_path, "w") as out:
        out.writelines(" ".join(prompt) + "\n" for prompt in prompt_lst)
def read_e2e_files(path, args, tokenizer):
    """Group E2E targets by source and write gold/source reference files.

    Each non-blank line of ``path`` is ``"<source>||<target>"``. Both sides
    are tokenized with ``tokenizer`` (token ``.text`` values are kept).
    Blank lines, such as a trailing newline, are skipped; previously they
    crashed the tuple unpacking.

    Side effects: writes ``{args.out_dir}/1_{args.split}_gold`` (all targets
    per source) and ``{args.out_dir}/1_{args.split}_src`` (one source per line).

    :returns: list of ``(src_tokens, first_target_tokens)`` tuples, one per
        distinct source, in first-seen order.
    """
    file_dict = {}
    with open(path, "r") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            src_lst, word_lst = line.split("||")
            tgt = tuple(x.text for x in tokenizer(word_lst))
            src = tuple(x.text for x in tokenizer(src_lst))
            file_dict.setdefault(src, []).append(tgt)

    temp = "1"  # fixed prefix used in the output file names
    prompt_text_lst = list(file_dict.keys())
    gold_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "gold"))
    print("gold dir", gold_dir)
    write_e2e_corr(prompt_text_lst, file_dict, gold_dir)
    src_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "src"))
    write_e2e_src(prompt_text_lst, src_dir)
    return [(src, file_dict[src][0]) for src in prompt_text_lst]
def get_corpus_book(
    data_args,
    tokenizer,
    model,
    image_size,
    padding_mode="block",
    split="train",
):
    """Load BookCorpus, tokenize it, and split it into ``image_size**2``-token blocks.

    If the hub dataset has no validation split, the first 1% of "train"
    becomes "validation" and the rest becomes "train". When ``model`` is None,
    a normally-initialised ``torch.nn.Embedding`` is created and saved to
    ``{data_args.checkpoint_path}/random_emb.torch``. Its width is 1 if
    ``data_args.training_mode`` starts with "e2e", else ``data_args.in_channel``.

    :returns: ``(lm_datasets, model)``. For a non-"train" ``split``, the
        "train" entry is replaced by the validation data.
    """
    # `load_dataset` is not imported at module level (that import is
    # commented out), so this function previously raised NameError.
    from datasets import load_dataset

    max_length = image_size**2
    assert padding_mode == "block"
    raw_datasets = load_dataset("bookcorpus")
    if "validation" not in raw_datasets.keys():
        raw_datasets["validation"] = load_dataset(
            "bookcorpus",
            split="train[:1%]",
        )
        raw_datasets["train"] = load_dataset(
            "bookcorpus",
            split="train[1%:]",
        )
    print(raw_datasets)
    column_names = raw_datasets["train"].column_names

    def tokenize_function(examples):
        output = tokenizer(examples["text"], add_special_tokens=False)
        return output

    tokenized_datasets = raw_datasets.map(
        tokenize_function,
        batched=True,
        num_proc=data_args.preprocessing_num_workers,
        remove_columns=column_names,
        load_from_cache_file=True,
    )
    print(tokenized_datasets)
    block_size = max_length

    # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size.
    def group_texts(examples):
        # Concatenate all texts.
        concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
        total_length = len(concatenated_examples[list(examples.keys())[0]])
        if total_length >= block_size:
            total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        return result

    lm_datasets = tokenized_datasets.map(
        group_texts,
        batched=True,
        num_proc=4,
        load_from_cache_file=True,
        desc=f"Grouping texts in chunks of {block_size}",
    )
    print(lm_datasets)
    if model is None:
        if data_args.training_mode.startswith("e2e"):
            print("since its e2e, initialize a dummy embedding")
            model = torch.nn.Embedding(len(tokenizer), 1)
        else:
            model = torch.nn.Embedding(len(tokenizer), data_args.in_channel)
        print("initializing the random embeddings", model)
        torch.nn.init.normal_(model.weight)
        path_save = f"{data_args.checkpoint_path}/random_emb.torch"
        print(
            f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch"
        )
        torch.save(model.state_dict(), path_save)
    if split == "train":
        return lm_datasets, model
    else:
        lm_datasets["train"] = lm_datasets["validation"]
        return lm_datasets, model
class TextDataset(Dataset):
    """Map-style dataset over precomputed token embeddings.

    ``text_datasets["train"][idx]`` must provide ``"hidden_states"`` (a
    ``seqlen x dim`` nested list) and ``"input_ids"``. In
    ``data_args.experiment_mode == "conditional_gen"`` the default
    architecture branch also reads ``"src_ids"`` and ``"src_mask"``.

    Output array per ``model_arch``:

    - ``"conv-unet"``: ``(dim, resolution, resolution)``; requires
      ``seqlen == resolution ** 2``.
    - ``"1d-unet"``: ``(dim, seqlen)``.
    - anything else: ``(seqlen, dim)``.

    Previously the first two branches were stubbed with ``pass`` and
    ``__getitem__`` returned None.
    """

    def __init__(
        self,
        text_datasets,
        resolution,
        data_args,
        model_arch="conv-unet",
        classes=None,
        shard=0,
        num_shards=1,
        eigen_transform=None,
        mapping_func=None,
        model_emb=None,
    ):
        super().__init__()
        self.resolution = resolution
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets["train"])
        self.model_arch = model_arch
        self.data_args = data_args
        print(self.resolution)
        self.eigen_transform = eigen_transform
        self.mapping_func = mapping_func
        self.model_emb = model_emb

    def __len__(self):
        return self.length

    def _transform(self, arr):
        """Apply the optional eigen transform, then optional Gaussian noise."""
        if self.eigen_transform is not None:
            old_shape = arr.shape
            arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
            arr = arr @ self.eigen_transform["map"]
            arr = arr.reshape(old_shape)
        if hasattr(self.data_args, "noise_level") and self.data_args.noise_level > 0:
            arr = arr + self.data_args.noise_level * np.random.randn(
                *arr.shape
            ).astype(arr.dtype)
        return arr

    def __getitem__(self, idx):
        example = self.text_datasets["train"][idx]
        arr = np.array(example["hidden_states"], dtype=np.float32)
        if self.model_arch == "conv-unet":
            arr = arr.reshape(self.resolution, self.resolution, -1)
        arr = self._transform(arr)
        out_dict = {"input_ids": np.array(example["input_ids"])}

        if self.model_arch == "conv-unet":
            return np.transpose(arr, [2, 0, 1]), out_dict
        if self.model_arch == "1d-unet":
            return np.transpose(arr, [1, 0]), out_dict

        if self.data_args.experiment_mode == "conditional_gen":
            out_dict["src_ids"] = np.array(example["src_ids"])
            out_dict["src_mask"] = np.array(example["src_mask"])
        return arr, out_dict
class TextDataset_NoCache(Dataset):
    """Map-style dataset that embeds ``input_ids`` on the fly with ``model_emb``.

    ``text_datasets["train"][idx]["input_ids"]`` is embedded according to
    ``data_args.experiment``:

    - ``"random*"`` / ``"glove"``: ``model_emb(ids)``.
    - ``"gpt2_pre_compress"``: ``down_proj(transformer.wte(ids)) * data_args.emb_scale_factor``.

    The result is shaped like ``TextDataset``'s output for the same
    ``model_arch``. The gpt2 branch previously referenced an undefined global
    ``data_args`` (NameError), and "glove" left ``hidden_state`` unbound.
    """

    def __init__(
        self,
        text_datasets,
        resolution,
        data_args,
        model_arch="conv-unet",
        classes=None,
        shard=0,
        num_shards=1,
        eigen_transform=None,
        mapping_func=None,
        model_emb=None,
    ):
        super().__init__()
        self.resolution = resolution
        self.text_datasets = text_datasets
        self.length = len(self.text_datasets["train"])
        self.model_arch = model_arch
        self.data_args = data_args
        print(self.resolution)
        self.eigen_transform = eigen_transform
        self.mapping_func = mapping_func
        self.model_emb = model_emb

    def __len__(self):
        return self.length

    def _embed(self, input_ids):
        """Embed ``input_ids`` with ``self.model_emb``; return a float32 numpy array."""
        model = self.model_emb
        experiment = self.data_args.experiment
        with torch.no_grad():
            if experiment.startswith("random") or experiment == "glove":
                hidden_state = model(torch.tensor(input_ids))
            elif experiment == "gpt2_pre_compress":
                input_ids2 = torch.tensor(input_ids).to(model.device)
                input_embs = model.transformer.wte(input_ids2)  # input_embs
                hidden_state = model.down_proj(input_embs)
                hidden_state = hidden_state * self.data_args.emb_scale_factor
            else:
                raise ValueError(f"unsupported data_args.experiment: {experiment!r}")
        # .cpu() so embeddings computed on an accelerator convert cleanly.
        return hidden_state.detach().cpu().numpy().astype(np.float32)

    def _transform(self, arr):
        """Apply the optional eigen transform, then optional Gaussian noise."""
        if self.eigen_transform is not None:
            old_shape = arr.shape
            arr = arr.reshape(1, -1) - self.eigen_transform["mean"]
            arr = arr @ self.eigen_transform["map"]
            arr = arr.reshape(old_shape)
        if hasattr(self.data_args, "noise_level") and self.data_args.noise_level > 0:
            arr = arr + self.data_args.noise_level * np.random.randn(
                *arr.shape
            ).astype(arr.dtype)
        return arr

    def __getitem__(self, idx):
        example = self.text_datasets["train"][idx]
        arr = self._embed(example["input_ids"])
        if self.model_arch == "conv-unet":
            arr = arr.reshape(self.resolution, self.resolution, -1)
        arr = self._transform(arr)
        out_dict = {"input_ids": np.array(example["input_ids"])}

        if self.model_arch == "conv-unet":
            return np.transpose(arr, [2, 0, 1]), out_dict
        if self.model_arch == "1d-unet":
            return np.transpose(arr, [1, 0]), out_dict

        if self.data_args.experiment_mode == "conditional_gen":
            out_dict["src_ids"] = np.array(example["src_ids"])
            out_dict["src_mask"] = np.array(example["src_mask"])
        return arr, out_dict
def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False):
result = torch.full(
[len(examples), max_length], pad_token_id, dtype=torch.int64
).tolist()
mask_ = torch.full(
[len(examples), max_length], pad_token_id, dtype=torch.int64
).tolist()
for i, example in enumerate(examples):
curr_len = min(len(example), max_length)
result[i][:curr_len] = example[:curr_len]
mask_[i][:curr_len] = [1] * curr_len
if return_mask:
return result, mask_
return result
def _torch_collate_batch(examples, pad_token_id, max_length):
"""Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""
import numpy as np
import torch
# Tensorize if necessary.
if isinstance(examples[0], (list, tuple, np.ndarray)):
examples = [torch.tensor(e, dtype=torch.long) for e in examples]
# length_of_first = examples[0].size(0)
# Check if padding is necessary.
# are_tensors_same_length = all(x.size(0) == length_of_first for x in examples)
# if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0):
# return torch.stack(examples, dim=0)
# Creating the full tensor and filling it with our data.
# max_length = max(x.size(0) for x in examples)
# if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0):
# max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of
result = examples[0].new_full([len(examples), max_length], pad_token_id)
for i, example in enumerate(examples):
if True:
result[i, : example.shape[0]] = example
else:
result[i, -example.shape[0] :] = example
return result