# summary/fengshen/utils/convert_py_to_npy.py
# Uploaded by fclong ("Upload 396 files", commit 8ebda9e).
import argparse
import torch
import glob
import os
import numpy as np
class MMapIndexDataset:
    """Read-only, memory-mapped view over a flat value file plus an offset index.

    ``datapath`` is a path prefix (a str):
      * ``<datapath>.npy`` - an index array; column 0 of row ``i`` is the start
        offset and column 1 the end offset (exclusive) of item ``i``.
      * ``<datapath>.bin`` - the concatenated values, stored with numpy dtype
        ``'long'`` (must match the dtype used when the file was written).
    Both files are opened read-only without loading them into memory.
    """

    def __init__(self, datapath):
        index_file = datapath + '.npy'
        values_file = datapath + '.bin'
        self.idxfp = np.load(index_file, mmap_mode='r')
        self.binfp = np.memmap(values_file, dtype='long', mode='r')

    def __len__(self):
        """Return the number of rows in the offset index."""
        num_rows = self.idxfp.shape[0]
        return num_rows

    def __getitem__(self, idx):
        """Return the slice of the value file described by index row ``idx``."""
        lo = self.idxfp[idx, 0]
        hi = self.idxfp[idx, 1]
        return self.binfp[lo:hi]
def convert_py_to_npy(input_tensor, bin_out, idx_out):
    """Flatten a collection of sequences into a flat ``.bin`` file plus an offset index.

    Args:
        input_tensor: a collection of integer sequences (e.g. lists of token
            ids or 1-D arrays/tensors) - presumably the lists loaded from the
            ``.pt`` dict; verify against caller. Each element must support
            ``len()`` and be assignable into a numpy slice.
        bin_out: path of the flat value file, written with numpy dtype
            ``'long'`` (the dtype ``MMapIndexDataset`` reads back).
        idx_out: path passed to ``np.save``; numpy appends ``.npy`` if missing.
            Row ``i`` of the saved int64 array is ``[start, end)`` of item ``i``
            inside ``bin_out``.

    If every sequence is empty (or there are none), ``bin_out`` is created as
    an empty file, because ``np.memmap`` cannot map a zero-length file.
    """
    sequences = list(input_tensor)
    lengths = np.array([len(seq) for seq in sequences], dtype=np.int64)
    ends = np.cumsum(lengths)
    starts = ends - lengths
    # (n, 2) int64 array of [start, end) offsets - same layout as before.
    np.save(idx_out, np.stack((starts, ends), axis=1))

    total = int(ends[-1]) if ends.size else 0
    if total == 0:
        # np.memmap(mode='w+') raises ValueError for a zero-length mapping.
        with open(bin_out, 'wb'):
            pass
        return

    binfp = np.memmap(bin_out, dtype='long', mode='w+', shape=(total,))
    # Copy each sequence with one slice assignment instead of per element.
    for seq, begin, end in zip(sequences, starts.tolist(), ends.tolist()):
        binfp[begin:end] = seq
    binfp.flush()
    del binfp
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="Text infilling.")
parser.add_argument('--data_path', type=str,
default='/cognitive_comp/gaoxinyu/data/wudao')
args = parser.parse_args()
process_key = [
'incorrect_input_ids_list',
'label_ids_list',
'target_ids_list',
]
if os.path.exists(args.data_path):
print(f'''Loading data from {args.data_path}''')
data_dict = torch.load(args.data_path)
for k in process_key:
bin_out = ('_' + k + '.bin').join(args.data_path.rsplit('.pt', 1))
idx_out = ('_' + k).join(args.data_path.rsplit('.pt', 1))
convert_py_to_npy(data_dict[k], bin_out, idx_out)
else:
print(
f'Please create the synthetic datafile {args.data_path} with create_synthetic_data.py.')