marriage_law_retrieval / preprocessors.py
luciusssss's picture
Upload 22 files
a48216a verified
raw
history blame contribute delete
No virus
1.67 kB
import os
import json
import pickle as pkl
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer
class BasicPreprocessor(object):
    """Load a labelled dataset from disk, shuffle it, and split it into train/test sets.

    Records are read from ``os.path.join(args.data_dir, args.data_file)`` (pickle
    or JSON) and shuffled once at construction time. ``process`` later hands
    slices of them to ``data_generator`` to build the actual datasets. A
    ``MultiLabelBinarizer`` fitted on ``args.labels`` is passed along to the
    generator.
    """

    def __init__(self, data_generator, tokenizer, args):
        """Read and shuffle the raw records and fit the label binarizer.

        Args:
            data_generator: callable invoked by ``process`` as
                ``data_generator(records, tokenizer, mlb, mode, args)`` with
                ``mode`` equal to ``'train'`` or ``'test'``; its return value is
                passed through unchanged.
            tokenizer: forwarded to ``data_generator`` untouched.
            args: config object; this constructor reads ``data_dir``,
                ``data_file`` and ``labels`` (``process`` also reads ``test_only``).

        Raises:
            ValueError: if the data file name ends in neither ``pkl`` nor ``json``.
        """
        self.data_generator = data_generator
        self.tokenizer = tokenizer
        self.args = args
        file_path = os.path.join(args.data_dir, args.data_file)
        if file_path.endswith("pkl"):
            # NOTE(review): unpickling can execute arbitrary code -- only point
            # this at trusted data files.
            with open(file_path, "rb") as f:
                self.raw_data = pkl.load(f)
        elif file_path.endswith("json"):
            with open(file_path, "r", encoding="utf-8") as f:
                self.raw_data = json.load(f)
        else:
            # Previously fell through silently and failed later in shuffle()
            # with an AttributeError on the missing raw_data.
            raise ValueError(
                f"Unsupported data file {file_path!r}: expected a pkl or json file"
            )
        self.shuffle()
        self.mlb = MultiLabelBinarizer()
        # fit() takes an iterable of label collections; wrapping args.labels in
        # a list registers every entry of args.labels as a known class.
        self.mlb.fit([args.labels])

    def shuffle(self):
        """Randomly permute ``self.raw_data`` using NumPy's global RNG.

        The shuffled records are stored back on ``self.raw_data`` as a NumPy
        array. Seed ``np.random`` beforehand if a reproducible order is needed.
        """
        idx = np.arange(len(self.raw_data))
        np.random.shuffle(idx)
        self.raw_data = np.array(self.raw_data)[idx]

    def process(self, train_ratio=0.9):
        """Build the ``(train_data, test_data)`` pair via ``data_generator``.

        Args:
            train_ratio: fraction of the (already shuffled) records used for
                training; the remainder becomes the test set. No validation set
                is produced. Defaults to the original 90/10 split.

        Returns:
            ``(train_data, test_data)`` exactly as returned by ``data_generator``.
            When ``args.test_only`` is truthy, every record goes to the test set
            and the "train" set is only the first record, generated in ``'test'``
            mode (presumably a placeholder so callers always receive a pair --
            confirm with callers).

        Raises:
            ValueError: if ``train_ratio`` is outside ``[0, 1]``.
        """
        args = self.args
        data_generator = self.data_generator
        raw_data = self.raw_data
        tokenizer = self.tokenizer
        mlb = self.mlb
        if args.test_only:
            train_data = data_generator(raw_data[:1], tokenizer, mlb, 'test', args)
            test_data = data_generator(raw_data, tokenizer, mlb, 'test', args)
            return train_data, test_data
        if not 0.0 <= train_ratio <= 1.0:
            raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio!r}")
        # Use train_ratio (90% by default) for training and the rest for
        # testing; no validation set.
        split = int(len(raw_data) * train_ratio)
        train_data = data_generator(raw_data[:split], tokenizer, mlb, 'train', args)
        test_data = data_generator(raw_data[split:], tokenizer, mlb, 'test', args)
        return train_data, test_data