# (removed non-Python file-viewer artifacts: file size, commit hash, line-number gutter)
import os
import json
import pickle as pkl
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
class BasicPreprocessor(object):
    """Load raw data from a pickle or JSON file, shuffle it, and split it
    into train/test datasets via a user-supplied data generator."""

    def __init__(self, data_generator, tokenizer, args):
        """
        Args:
            data_generator: callable(raw_slice, tokenizer, mlb, split, args)
                that builds a dataset object for one split ('train'/'test').
            tokenizer: tokenizer instance, passed through to data_generator.
            args: namespace with data_dir, data_file, labels, test_only.

        Raises:
            ValueError: if the data file is neither ``.pkl`` nor ``.json``.
        """
        self.data_generator = data_generator
        self.tokenizer = tokenizer
        self.args = args
        file_path = os.path.join(args.data_dir, args.data_file)
        if file_path.endswith("pkl"):
            # NOTE(review): pickle.load is unsafe on untrusted files; only
            # load data files you control.
            with open(file_path, "rb") as f:
                self.raw_data = pkl.load(f)
        elif file_path.endswith("json"):
            # context manager ensures the file handle is closed promptly
            with open(file_path, "r", encoding="utf-8") as f:
                self.raw_data = json.load(f)
        else:
            # previously fell through silently, leaving raw_data undefined
            # and crashing later in shuffle() with an opaque AttributeError
            raise ValueError(f"Unsupported data file format: {file_path}")
        self.shuffle()
        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([args.labels])

    def shuffle(self):
        """Shuffle raw_data in place (converts it to a numpy object array)."""
        perm = np.random.permutation(len(self.raw_data))
        self.raw_data = np.array(self.raw_data)[perm]

    def process(self):
        """Build and return ``(train_data, test_data)``.

        When ``args.test_only`` is set, all data goes to the test split and
        a one-sample placeholder is returned as the train split. Otherwise
        the first 90% is used for training and the last 10% for testing
        (no validation set).
        """
        args = self.args
        data_generator = self.data_generator
        raw_data = self.raw_data
        tokenizer = self.tokenizer
        mlb = self.mlb
        if args.test_only:
            # minimal train placeholder so callers always get two datasets
            train_data = data_generator(raw_data[:1], tokenizer, mlb, 'test', args)
            test_data = data_generator(raw_data, tokenizer, mlb, 'test', args)
            return train_data, test_data
        # 90% train / 10% test split; no validation set
        split = int(len(raw_data) * 0.9)
        train_data = data_generator(raw_data[:split], tokenizer, mlb, 'train', args)
        test_data = data_generator(raw_data[split:], tokenizer, mlb, 'test', args)
        return train_data, test_data