# from PIL import Image # import blobfile as bf # from mpi4py import MPI import numpy as np from torch.utils.data import DataLoader, Dataset from transformers import ( AutoModelForCausalLM, AutoConfig, AutoTokenizer, default_data_collator, PreTrainedTokenizerFast, PreTrainedTokenizer, ) # from datasets import load_dataset import sys, os import torch # sys.path.insert(0, os.path.join(sys.path[0], '../../transformers/examples/pytorch/language-modeling')) # from custom_trainer import GPT2LMHeadModelCompress, BERTModelCompress, AutoEncoderWithNoise from collections import Counter, defaultdict from functools import partial from itertools import chain def load_data_text( *, data_dir, batch_size, image_size, class_cond=False, deterministic=False, data_args=None, task_mode="roc", model=None, padding_mode="block", split="train", load_vocab=None, ): """ For a dataset, create a generator over (images, kwargs) pairs. Each images is an NCHW float tensor, and the kwargs dict contains zero or more keys, each of which map to a batched Tensor of their own. The kwargs dict can be used for class labels, in which case the key is "y" and the values are integer tensors of class labels. :param data_dir: a dataset directory. :param batch_size: the batch size of each returned pair. :param image_size: the size to which images are resized. :param class_cond: if True, include a "y" key in returned dicts for class label. If classes are not available and this is true, an exception will be raised. :param deterministic: if True, yield results in a deterministic order. """ print("hello loading text data. ") if data_args.experiment.startswith("random") and model is None: model = None # elif data_args.experiment.startswith('random') and model is not None: # print('loading initialized random embeddings. ') if task_mode == "roc" or task_mode == "roc-aug": pass # training_data, model = get_corpus_rocstory(data_args, model, image_size, # padding_mode=padding_mode, split=split, # load_vocab=load_vocab) elif task_mode == "simple-wiki": pass # training_data, model = get_corpus_rocstory(data_args, model, image_size, # padding_mode=padding_mode, split=split, # load_vocab=load_vocab) elif task_mode == "e2e-tgt": print("hello loading e2e-tgt. ") training_data, model = get_corpus_rocstory( data_args, model, image_size, padding_mode=padding_mode, split=split, load_vocab=load_vocab, ) # elif task_mode == 'yelp': # print('hello loading yelp ') # training_data, model = get_corpus_rocstory(data_args, model, image_size, # padding_mode=padding_mode, split=split, # load_vocab=load_vocab) # elif task_mode == 'commonGen' or task_mode == 'commonGen-aug': # print('hello loading common-gen ') # training_data, model = get_corpus_rocstory(data_args, model, image_size, # padding_mode=padding_mode, split=split, # load_vocab=load_vocab) # elif task_mode == 'e2e': # training_data, model = get_corpus_rocstory(data_args, model, image_size, # padding_mode=padding_mode, split=split, # load_vocab=load_vocab) # elif task_mode == 'book': # tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased') # training_data, model = get_corpus_book(data_args, tokenizer, model, image_size, # padding_mode=padding_mode, split=split,) if ( data_args.modality in ["roc-aug", "roc", "book", "yelp", "commonGen", "commonGen-aug"] and data_args.cache_mode == "no" ): pass # dataset = TextDataset_NoCache( # training_data, # image_size, # data_args, # model_arch=data_args.model_arch, # model_emb=model # ) else: dataset = TextDataset( training_data, image_size, data_args, model_arch=data_args.model_arch, ) if deterministic: pass # data_loader = DataLoader( # dataset, # batch_size=batch_size, # 20, # drop_last=True, # shuffle=False, # num_workers=1, # ) else: data_loader = DataLoader( dataset, batch_size=batch_size, # 20, drop_last=True, shuffle=True, num_workers=1, ) while True: yield from data_loader def helper_tokenize_encode_cond(sentence_lst, vocab_dict, model, seqlen, data_args): result_train_lst = [] group_lst = defaultdict(list) with torch.no_grad(): for src_ids, input_ids in sentence_lst: tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids] tokenized_src = [vocab_dict.get(x, vocab_dict["UNK"]) for x in src_ids] input_ids = [0] + tokenized_ + [1] group_lst["word_ids"].append(input_ids) group_lst["src_ids"].append(tokenized_src) print(group_lst["word_ids"][:2]) print("padding mode is pad") max_length = seqlen group_lst["word_ids"] = _collate_batch_helper( group_lst["word_ids"], vocab_dict["PAD"], max_length ) max_src_length = max([len(xx) for xx in group_lst["src_ids"]]) print(max_src_length, seqlen) max_src_length = min(seqlen, max_src_length) group_lst["src_ids"], group_lst["src_mask"] = _collate_batch_helper( group_lst["src_ids"], vocab_dict["PAD"], max_src_length, return_mask=True ) for input_ids, src_ids, src_mask in zip( group_lst["word_ids"], group_lst["src_ids"], group_lst["src_mask"] ): if data_args.experiment.startswith("random"): hidden_state = model(torch.tensor(input_ids)) elif data_args.experiment == "gpt2_pre_compress": input_ids2 = torch.tensor(input_ids).to(model.device) input_embs = model.transformer.wte(input_ids2) # input_embs hidden_state = model.down_proj(input_embs) hidden_state = hidden_state * data_args.emb_scale_factor result_train_lst.append( { "input_ids": input_ids, "hidden_states": hidden_state.cpu().tolist(), "src_ids": src_ids, "src_mask": src_mask, } ) return result_train_lst def helper_tokenize_stream( sentence_lst, vocab_dict, model, seqlen, data_args, padding_mode, ): import psutil # Process.memory_info is expressed in bytes, so convert to megabytes print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB") from datasets import Dataset as Dataset2 raw_datasets = Dataset2.from_dict({"text": sentence_lst}) print(raw_datasets) print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB") def tokenize_function(examples): if isinstance(vocab_dict, dict): input_ids = [ [0] + [vocab_dict.get(x, vocab_dict["UNK"]) for x in seq] + [1] for seq in examples["text"] ] elif isinstance(vocab_dict, PreTrainedTokenizerFast): examples["text"] = [" ".join(seq) for seq in examples["text"]] input_ids = vocab_dict(examples["text"], add_special_tokens=True)[ "input_ids" ] result_dict = {"input_ids": input_ids} # clm input could be much much longer than block_size return result_dict tokenized_datasets = raw_datasets.map( tokenize_function, batched=True, num_proc=4, remove_columns=["text"], load_from_cache_file=True, desc="Running tokenizer on dataset", ) print(tokenized_datasets) print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB") if padding_mode == "block": block_size = seqlen def group_texts(examples): concatenated_examples = { k: list(chain(*examples[k])) for k in examples.keys() } total_length = len(concatenated_examples[list(examples.keys())[0]]) if total_length >= block_size: total_length = (total_length // block_size) * block_size result = { k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } result["labels"] = result["input_ids"].copy() return result lm_datasets = tokenized_datasets.map( group_texts, batched=True, num_proc=data_args.preprocessing_num_workers, load_from_cache_file=not data_args.overwrite_cache, desc=f"Grouping texts in chunks of {block_size}", ) else: def pad_function(group_lst): max_length = seqlen if isinstance(vocab_dict, dict): group_lst["input_ids"] = _collate_batch_helper( group_lst["input_ids"], vocab_dict["PAD"], max_length ) else: group_lst["input_ids"] = _collate_batch_helper( group_lst["input_ids"], vocab_dict.pad_token_id, max_length ) return group_lst # Process.memory_info is expressed in bytes, so convert to megabytes print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB") lm_datasets = tokenized_datasets.map( pad_function, batched=True, num_proc=1, desc=f"padding", ) print(lm_datasets, "padded dataset") print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB") import datasets raw_datasets = datasets.DatasetDict() raw_datasets["train"] = lm_datasets print(f"RAM used: {psutil.Process().memory_info().rss / (1024 * 1024):.2f} MB") return raw_datasets def helper_tokenize_encode( sentence_lst, vocab_dict, model, seqlen, data_args, padding_mode, ): result_train_lst = [] group_lst = defaultdict(list) with torch.no_grad(): for input_ids in sentence_lst: tokenized_ = [vocab_dict.get(x, vocab_dict["UNK"]) for x in input_ids] input_ids = [0] + tokenized_ + [1] group_lst["word_ids"].append(input_ids) print(group_lst["word_ids"][:2]) if padding_mode == "block": print("padding mode is block") concatenated_examples = {k: sum(group_lst[k], []) for k in group_lst.keys()} total_length = len(concatenated_examples[list(group_lst.keys())[0]]) block_size = seqlen total_length = (total_length // block_size) * block_size # Split by chunks of max_len. group_lst = { k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } elif padding_mode == "pad": print("padding mode is pad") max_length = seqlen group_lst["word_ids"] = _collate_batch_helper( group_lst["word_ids"], vocab_dict["PAD"], max_length ) for input_ids in group_lst["word_ids"]: if data_args.experiment.startswith("random"): hidden_state = model(torch.tensor(input_ids)) elif data_args.experiment == "gpt2_pre_compress": input_ids2 = torch.tensor(input_ids).to(model.device) input_embs = model.transformer.wte(input_ids2) # input_embs hidden_state = model.down_proj(input_embs) hidden_state = hidden_state * data_args.emb_scale_factor elif data_args.experiment == "glove": hidden_state = model(torch.tensor(input_ids)) result_train_lst.append( {"input_ids": input_ids, "hidden_states": hidden_state.cpu().tolist()} ) return result_train_lst def load_glove_model(File): print("Loading Glove Model") glove_model = {} with open(File, "r") as f: for line in f: split_line = line.split() word = split_line[0] embedding = torch.tensor(np.array(split_line[1:], dtype=np.float64)) # embedding = np.array(split_line[1:], dtype=np.float64) glove_model[word] = embedding print(f"{len(glove_model)} words loaded!") return glove_model def load_glove(vocab): model = torch.nn.Embedding(len(vocab), 50) glove_model = load_glove_model("predictability/glove/glove.6B.50d.txt") array_lst = [] count_ = 0 for word, idx in vocab.items(): if word in glove_model: array_lst.append(glove_model[word]) else: count_ += 1 array_lst.append(torch.randn(50)) print(f"{count_} out of {len(vocab)} is initialized. ") array_lst = torch.stack(array_lst) print(torch.norm(array_lst, dim=-1).mean()) model.weight.data = array_lst return model def get_corpus_rocstory( data_args, model, image_size, padding_mode="block", split="train", load_vocab=None ): import csv, torch, json from spacy.lang.en import English if data_args.experiment_mode == "lm": if data_args.modality == "roc": pass # print('loading dataset from ROCStory') # nlp = English() # tokenizer = nlp.tokenizer # sentence_lst = [] # print(f'loading from {data_args.roc_train}') # if split == 'train': # print('loading form the TRAIN set') # path = f'{data_args.roc_train}/roc_train.json' # elif split == 'valid': # print('loading form the VALID set') # path = f'{data_args.roc_train}/roc_valid.json' # else: # assert False, "invalid split for ROC dataset" # with open(path, 'r') as roc_reader: # for row in roc_reader: # sentences = json.loads(row)[0].strip() # word_lst = [x.text for x in tokenizer(sentences)] # sentence_lst.append(word_lst) # # with open(data_args.roc_train, 'r') as csvfile: # # roc_reader = csv.reader(csvfile) #delimiter=' ', quotechar='|') # # for row in roc_reader: # # # tokenize. # # sentences = " ".join(row[2:]) # # word_lst = [x.text for x in tokenizer(sentences)] # # sentence_lst.append(word_lst) # # sentence_lst = sentence_lst[1:] # print(sentence_lst[:2]) if data_args.modality == "roc-aug": pass # print('loading dataset from ROCStory') # nlp = English() # tokenizer = nlp.tokenizer # sentence_lst = [] # if split == 'train': # print('loading form the TRAIN set') # path_lst = [f'{data_args.roc_train}/roc_train.json'] # path_lst.append('diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt') # # path_lst.append('diffusion_lm/improved-diffusion/cache/ar_model_augment_roc.json') # # path_lst.append('diffusion_lm/improved-diffusion/cache/ar_model_augment_roc2.json') # elif split == 'valid': # print('loading form the VALID set') # path_lst = [f'{data_args.roc_train}/roc_valid.json'] # else: # assert False, "invalid split for ROC dataset" # print(path_lst) # for path in path_lst: # if path.endswith('txt'): # with open(path, 'r') as roc_reader: # for row in roc_reader: # sentences = row.strip() # word_lst = [x.text for x in tokenizer(sentences)] # sentence_lst.append(word_lst) # else: # with open(path, 'r') as roc_reader: # for row in roc_reader: # sentences = json.loads(row)[0].strip() # word_lst = [x.text for x in tokenizer(sentences)] # sentence_lst.append(word_lst) # print(sentence_lst[:2],sentence_lst[-2:], 'dataset size=',len(sentence_lst)) elif data_args.modality == "simple-wiki": pass # print('loading dataset from simple wikipedia') # sentence_lst = [] # with open(data_args.wiki_train, 'r') as ff: # for row in ff: # word_lst = row.lower().split() # sentence_lst.append(word_lst) # print(sentence_lst[:2]) elif data_args.modality == "e2e-tgt": print("loading dataset from simple e2e dataset") sentence_lst = [] nlp = English() tokenizer = nlp.tokenizer if split == "train": print("loading form the TRAIN set") path = ( "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_train.txt" ) # path = f'../{data_args.e2e_train}/src1_train.txt' elif split == "valid": print("loading form the VALID set") path = f"../{data_args.e2e_train}/src1_valid.txt" path = ( "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_valid.txt" ) elif split == "test": print("loading form the TEST set") path = f"../{data_args.e2e_train}/src1_test.txt" path = "/data0/gonghaisong/Diffusion-LM/datasets/e2e_data/src1_test.txt" elif split == "debug": print("loading form the DEBUG set") path = data_args.debug_path import json with open(path, "r") as ff: for line in ff: sentence_lst.append(json.loads(line)[0].split(" ")) sentence_lst = sentence_lst + sentence_lst if split in ["train", "valid", "test"]: with open(path, "r") as ff: for row in ff: word_lst = row.split("||")[1] word_lst = [x.text for x in tokenizer(word_lst)] sentence_lst.append(word_lst) print(sentence_lst[:2]) elif data_args.modality == "yelp": print("loading dataset from simple YelpNLG dataset") sentence_lst = [] nlp = English() tokenizer = nlp.tokenizer if split == "train": print("loading form the TRAIN set") path = f"{data_args.yelp_train}/yelpnlg-train.csv" elif split == "valid": print("loading form the VALID set") path = f"{data_args.yelp_train}/yelpnlg-dev.csv" elif split == "test": print("loading form the TEST set") path = f"{data_args.yelp_train}/yelpnlg-test.csv" if split in ["train", "valid", "test"]: with open(path, "r") as csvfile: yelp_reader = csv.reader(csvfile) # delimiter=' ', quotechar='|') for row in yelp_reader: sentences = row[1] word_lst = [x.text for x in tokenizer(sentences)] sentence_lst.append(word_lst) sentence_lst = sentence_lst[1:] print(sentence_lst[:2]) elif data_args.modality == "commonGen": print("loading dataset from simple YelpNLG dataset") sentence_lst = [] nlp = English() tokenizer = nlp.tokenizer if split == "train": print("loading form the TRAIN set") path = f"{data_args.commonGen_train}/commongen.train.jsonl" elif split == "valid": print("loading form the VALID set") path = f"{data_args.commonGen_train}/commongen.dev.jsonl" elif split == "test": print("loading form the TEST set") path = f"{data_args.commonGen_train}/commongen.test.jsonl" if split in ["train", "valid", "test"]: with open(path, "r") as ff: for line in ff: line = json.loads(line) for sentences in line["scene"]: word_lst = [x.text for x in tokenizer(sentences)] sentence_lst.append(word_lst) print(sentence_lst[:2]) elif data_args.modality == "commonGen-aug": print("loading dataset from simple YelpNLG dataset") sentence_lst = [] nlp = English() tokenizer = nlp.tokenizer if split == "train": print("loading form the TRAIN set") path = f"{data_args.commonGen_train}/commongen.train.jsonl" path_lst = [f"{data_args.roc_train}/roc_train.json"] path_lst.append( "diffusion_lm/improved-diffusion/diff_models/rocstories_gptj.txt" ) elif split == "valid": print("loading form the VALID set") path = f"{data_args.commonGen_train}/commongen.dev.jsonl" path_lst = [] elif split == "test": print("loading form the TEST set") path = f"{data_args.commonGen_train}/commongen.test.jsonl" path_lst = [] if split in ["train", "valid", "test"]: with open(path, "r") as ff: for line in ff: line = json.loads(line) for sentences in line["scene"]: word_lst = [x.text for x in tokenizer(sentences)] sentence_lst.append(word_lst) print(sentence_lst[:2]) import itertools for path in path_lst: if path.endswith("txt"): with open(path, "r") as roc_reader: for row in roc_reader: sentences = row.strip() word_lst = [x.text for x in tokenizer(sentences)] spl = [[]] for x, y in itertools.groupby(word_lst, lambda z: z == "."): spl[-1].extend(y) if x: spl.append([]) sentence_lst.extend(spl[:-1]) else: with open(path, "r") as roc_reader: for row in roc_reader: sentences = json.loads(row)[0].strip() word_lst = [x.text for x in tokenizer(sentences)] spl = [[]] for x, y in itertools.groupby(word_lst, lambda z: z == "."): spl[-1].extend(y) if x: spl.append([]) sentence_lst.extend(spl[:-1]) print(sentence_lst[-2:]) # get tokenizer. if load_vocab is None: counter = Counter() for input_ids in sentence_lst: counter.update(input_ids) if data_args.experiment_mode == "conditional_gen": if data_args.modality == "e2e": print("loading dataset from simple e2e dataset") sentence_lst = [] nlp = English() tokenizer = nlp.tokenizer if split == "train": path = f"{data_args.e2e_train}/src1_train.txt" with open(path, "r") as ff: for row in ff: src_lst, word_lst = row.split("||") word_lst = [x.text for x in tokenizer(word_lst)] src_lst = [x.text for x in tokenizer(src_lst)] sentence_lst.append((src_lst, word_lst)) elif split == "valid": path = f"{data_args.e2e_train}/src1_valid.txt" sentence_lst = read_e2e_files(path, data_args, tokenizer) print(sentence_lst[:2]) # get tokenizer. if load_vocab is None: counter = Counter() for src_ids, input_ids in sentence_lst: counter.update(input_ids) counter.update(src_ids) if load_vocab is None: vocab_dict = {"START": 0, "END": 1, "UNK": 2, "PAD": 3} for k, v in counter.items(): if v > 10: vocab_dict[k] = len(vocab_dict) print(len(counter), len(vocab_dict)) path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json" print(f"save the vocab to {path_save_vocab}") with open(path_save_vocab, "w") as f: json.dump(vocab_dict, f) else: vocab_dict = load_vocab path_save_vocab = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/vocab.json" if not os.path.exists(path_save_vocab): print(f"save the vocab to {path_save_vocab}") if isinstance(vocab_dict, dict): with open(path_save_vocab, "w") as f: json.dump(vocab_dict, f) assert vocab_dict["START"] == 0 elif isinstance(vocab_dict, PreTrainedTokenizerFast): vocab_dict.save_pretrained(data_args.checkpoint_path) else: assert False, "invalid type of vocab_dict" if model is None and data_args.experiment == "random": model = torch.nn.Embedding(len(vocab_dict), data_args.in_channel) print("initializing the random embeddings", model) torch.nn.init.normal_(model.weight) path_save = "/data0/gonghaisong/Diffusion-LM/improved-diffusion/diffusion_models/diff_e2e-tgt_block_rand16_transformer_lr0.0001_0.0_2000_sqrt_Lsimple_h128_s2_d0.1_sd102_xstart_e2e/random_emb.torch" print( f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch" ) torch.save(model.state_dict(), path_save) # path_save = f'{data_args.checkpoint_path}/random_emb.torch' # if not os.path.exists(path_save) and data_args.experiment == 'random': # torch.save(model.state_dict(), path_save) if ( data_args.experiment_mode == "lm" and data_args.modality in ["roc-aug", "roc", "yelp", "commonGen", "commonGen-aug"] and data_args.cache_mode == "no" ): train_dataset = helper_tokenize_stream( sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode ) return train_dataset, model elif data_args.experiment_mode == "lm": result_train_lst = helper_tokenize_encode( sentence_lst, vocab_dict, model, image_size**2, data_args, padding_mode ) elif data_args.experiment_mode == "conditional_gen": result_train_lst = helper_tokenize_encode_cond( sentence_lst, vocab_dict, model, image_size**2, data_args ) return {"train": result_train_lst}, model def write_e2e_corr(prompt_lst, file_dict, corr_path): print(len(prompt_lst)) with open(corr_path, "w") as f: for x in prompt_lst: for line in file_dict[x]: print(" ".join(line), file=f) print("", file=f) def write_e2e_src(prompt_lst, corr_path): with open(corr_path, "w") as f: for x in prompt_lst: print(" ".join(x), file=f) return def read_e2e_files(path, args, tokenizer): file_dict = {} with open(path, "r") as f: for line in f: src_lst, word_lst = line.strip().split("||") tgt = tuple([x.text for x in tokenizer(word_lst)]) src = tuple([x.text for x in tokenizer(src_lst)]) if src not in file_dict: file_dict[src] = [] file_dict[src].append(tgt) temp = "1" prompt_text_dict = file_dict prompt_text_lst = list(prompt_text_dict.keys()) gold_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "gold")) print("gold dir", gold_dir) write_e2e_corr(prompt_text_lst, prompt_text_dict, gold_dir) src_dir = os.path.join(args.out_dir, "{}_{}_{}".format(temp, args.split, "src")) write_e2e_src(prompt_text_lst, src_dir) final_lst = [(xx, prompt_text_dict[xx][0]) for xx in prompt_text_lst] return final_lst def get_corpus_book( data_args, tokenizer, model, image_size, padding_mode="block", split="train", ): max_length = image_size**2 import os assert padding_mode == "block" raw_datasets = load_dataset("bookcorpus") if "validation" not in raw_datasets.keys(): raw_datasets["validation"] = load_dataset( "bookcorpus", split=f"train[:1%]", ) raw_datasets["train"] = load_dataset( "bookcorpus", split=f"train[1%:]", ) print(raw_datasets) column_names = raw_datasets["train"].column_names def tokenize_function(examples): output = tokenizer(examples["text"], add_special_tokens=False) return output tokenized_datasets = raw_datasets.map( tokenize_function, batched=True, num_proc=data_args.preprocessing_num_workers, remove_columns=column_names, load_from_cache_file=True, ) print(tokenized_datasets) block_size = max_length # Main data processing function that will concatenate all texts from our dataset and generate chunks of block_size. def group_texts(examples): # Concatenate all texts. concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) if total_length >= block_size: total_length = (total_length // block_size) * block_size result = { k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } return result lm_datasets = tokenized_datasets.map( group_texts, batched=True, num_proc=4, load_from_cache_file=True, desc=f"Grouping texts in chunks of {block_size}", ) print(lm_datasets) if model is None: if data_args.training_mode.startswith("e2e"): print("since its e2e, initialize a dummy embedding") model = torch.nn.Embedding(len(tokenizer), 1) else: model = torch.nn.Embedding(len(tokenizer), data_args.in_channel) print("initializing the random embeddings", model) torch.nn.init.normal_(model.weight) path_save = f"{data_args.checkpoint_path}/random_emb.torch" print( f"save the random encoder to {data_args.checkpoint_path}/random_emb.torch" ) torch.save(model.state_dict(), path_save) if split == "train": return lm_datasets, model else: lm_datasets["train"] = lm_datasets["validation"] return lm_datasets, model class TextDataset(Dataset): def __init__( self, text_datasets, resolution, data_args, model_arch="conv-unet", classes=None, shard=0, num_shards=1, eigen_transform=None, mapping_func=None, model_emb=None, ): super().__init__() self.resolution = resolution self.text_datasets = text_datasets self.length = len(self.text_datasets["train"]) self.model_arch = model_arch self.data_args = data_args print(self.resolution) self.eigen_transform = eigen_transform self.mapping_func = mapping_func self.model_emb = model_emb # self.local_images = image_paths[shard:][::num_shards] # self.local_classes = None if classes is None else classes[shard:][::num_shards] def __len__(self): return self.length def __getitem__(self, idx): # We are not on a new enough PIL to support the `reducing_gap` # argument, which uses BOX downsampling at powers of two first. # Thus, we do it by hand to improve downsample quality. if self.model_arch == "conv-unet": pass # arr = np.array(self.text_datasets['train'][idx]['hidden_states'], # dtype=np.float32).reshape(self.resolution, self.resolution, -1) # # print(self.eigen_transform.shape) # if self.eigen_transform is not None: # old_shape = arr.shape # arr = arr.reshape(1, -1) - self.eigen_transform['mean'] # arr = arr @ self.eigen_transform['map'] # arr = arr.reshape(old_shape) # if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0: # arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype) # out_dict = {} # out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids']) # # if self.local_classes is not None: # # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) # # print(out_dict.keys()) # return np.transpose(arr, [2, 0, 1]), out_dict elif self.model_arch == "1d-unet": pass # arr = np.array(self.text_datasets['train'][idx]['hidden_states'], # dtype=np.float32) # seqlen, dim # if self.eigen_transform is not None: # old_shape = arr.shape # arr = arr.reshape(1, -1) - self.eigen_transform['mean'] # arr = arr @ self.eigen_transform['map'] # arr = arr.reshape(old_shape) # if hasattr(self.data_args, 'noise_level') and self.data_args.noise_level > 0: # arr = arr + self.data_args.noise_level * np.random.randn(*arr.shape).astype(arr.dtype) # arr = np.transpose(arr, [1, 0]) # out_dict = {} # out_dict['input_ids'] = np.array(self.text_datasets['train'][idx]['input_ids']) # # out_dict['mapping_func'] = self.mapping_func # # if self.local_classes is not None: # # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) # # print(arr.shape) # return arr, out_dict else: arr = np.array( self.text_datasets["train"][idx]["hidden_states"], dtype=np.float32 ) if self.eigen_transform is not None: old_shape = arr.shape # arr = arr.reshape(1, -1) @ self.eigen_transform arr = arr.reshape(1, -1) - self.eigen_transform["mean"] arr = arr @ self.eigen_transform["map"] arr = arr.reshape(old_shape) if ( hasattr(self.data_args, "noise_level") and self.data_args.noise_level > 0 ): # print(arr.dtype) # print(self.data_args.noise_level, 'using the noise level.') arr = arr + self.data_args.noise_level * np.random.randn( *arr.shape ).astype(arr.dtype) # print(arr.dtype) out_dict = {} out_dict["input_ids"] = np.array( self.text_datasets["train"][idx]["input_ids"] ) # out_dict['mapping_func'] = self.mapping_func if self.data_args.experiment_mode == "conditional_gen": out_dict["src_ids"] = np.array( self.text_datasets["train"][idx]["src_ids"] ) out_dict["src_mask"] = np.array( self.text_datasets["train"][idx]["src_mask"] ) # if self.local_classes is not None: # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) return arr, out_dict # print(arr.dtype) # arr = arr.float() # print(arr.shape) class TextDataset_NoCache(Dataset): def __init__( self, text_datasets, resolution, data_args, model_arch="conv-unet", classes=None, shard=0, num_shards=1, eigen_transform=None, mapping_func=None, model_emb=None, ): super().__init__() self.resolution = resolution self.text_datasets = text_datasets self.length = len(self.text_datasets["train"]) self.model_arch = model_arch self.data_args = data_args print(self.resolution) self.eigen_transform = eigen_transform self.mapping_func = mapping_func self.model_emb = model_emb # self.local_images = image_paths[shard:][::num_shards] # self.local_classes = None if classes is None else classes[shard:][::num_shards] def __len__(self): return self.length def __getitem__(self, idx): # We are not on a new enough PIL to support the `reducing_gap` # argument, which uses BOX downsampling at powers of two first. # Thus, we do it by hand to improve downsample quality. with torch.no_grad(): input_ids = self.text_datasets["train"][idx]["input_ids"] model = self.model_emb if self.data_args.experiment.startswith("random"): hidden_state = model(torch.tensor(input_ids)) elif self.data_args.experiment == "gpt2_pre_compress": input_ids2 = torch.tensor(input_ids).to(model.device) input_embs = model.transformer.wte(input_ids2) # input_embs hidden_state = model.down_proj(input_embs) hidden_state = hidden_state * data_args.emb_scale_factor if self.model_arch == "conv-unet": arr = np.array(hidden_state, dtype=np.float32).reshape( self.resolution, self.resolution, -1 ) # print(self.eigen_transform.shape) if self.eigen_transform is not None: old_shape = arr.shape arr = arr.reshape(1, -1) - self.eigen_transform["mean"] arr = arr @ self.eigen_transform["map"] arr = arr.reshape(old_shape) if ( hasattr(self.data_args, "noise_level") and self.data_args.noise_level > 0 ): arr = arr + self.data_args.noise_level * np.random.randn( *arr.shape ).astype(arr.dtype) out_dict = {} out_dict["input_ids"] = np.array( self.text_datasets["train"][idx]["input_ids"] ) # if self.local_classes is not None: # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) # print(out_dict.keys()) return np.transpose(arr, [2, 0, 1]), out_dict elif self.model_arch == "1d-unet": arr = np.array(hidden_state, dtype=np.float32) # seqlen, dim if self.eigen_transform is not None: old_shape = arr.shape arr = arr.reshape(1, -1) - self.eigen_transform["mean"] arr = arr @ self.eigen_transform["map"] arr = arr.reshape(old_shape) if ( hasattr(self.data_args, "noise_level") and self.data_args.noise_level > 0 ): arr = arr + self.data_args.noise_level * np.random.randn( *arr.shape ).astype(arr.dtype) arr = np.transpose(arr, [1, 0]) out_dict = {} out_dict["input_ids"] = np.array( self.text_datasets["train"][idx]["input_ids"] ) # out_dict['mapping_func'] = self.mapping_func # if self.local_classes is not None: # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) # print(arr.shape) return arr, out_dict else: arr = np.array(hidden_state, dtype=np.float32) if self.eigen_transform is not None: old_shape = arr.shape # arr = arr.reshape(1, -1) @ self.eigen_transform arr = arr.reshape(1, -1) - self.eigen_transform["mean"] arr = arr @ self.eigen_transform["map"] arr = arr.reshape(old_shape) if ( hasattr(self.data_args, "noise_level") and self.data_args.noise_level > 0 ): # print(arr.dtype) # print(self.data_args.noise_level, 'using the noise level.') arr = arr + self.data_args.noise_level * np.random.randn( *arr.shape ).astype(arr.dtype) # print(arr.dtype) out_dict = {} out_dict["input_ids"] = np.array( self.text_datasets["train"][idx]["input_ids"] ) # out_dict['mapping_func'] = self.mapping_func if self.data_args.experiment_mode == "conditional_gen": out_dict["src_ids"] = np.array( self.text_datasets["train"][idx]["src_ids"] ) out_dict["src_mask"] = np.array( self.text_datasets["train"][idx]["src_mask"] ) # if self.local_classes is not None: # out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64) return arr, out_dict def _collate_batch_helper(examples, pad_token_id, max_length, return_mask=False): result = torch.full( [len(examples), max_length], pad_token_id, dtype=torch.int64 ).tolist() mask_ = torch.full( [len(examples), max_length], pad_token_id, dtype=torch.int64 ).tolist() for i, example in enumerate(examples): curr_len = min(len(example), max_length) result[i][:curr_len] = example[:curr_len] mask_[i][:curr_len] = [1] * curr_len if return_mask: return result, mask_ return result def _torch_collate_batch(examples, pad_token_id, max_length): """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary.""" import numpy as np import torch # Tensorize if necessary. if isinstance(examples[0], (list, tuple, np.ndarray)): examples = [torch.tensor(e, dtype=torch.long) for e in examples] # length_of_first = examples[0].size(0) # Check if padding is necessary. # are_tensors_same_length = all(x.size(0) == length_of_first for x in examples) # if are_tensors_same_length and (pad_to_multiple_of is None or length_of_first % pad_to_multiple_of == 0): # return torch.stack(examples, dim=0) # Creating the full tensor and filling it with our data. # max_length = max(x.size(0) for x in examples) # if pad_to_multiple_of is not None and (max_length % pad_to_multiple_of != 0): # max_length = ((max_length // pad_to_multiple_of) + 1) * pad_to_multiple_of result = examples[0].new_full([len(examples), max_length], pad_token_id) for i, example in enumerate(examples): if True: result[i, : example.shape[0]] = example else: result[i, -example.shape[0] :] = example return result