File size: 1,671 Bytes
a48216a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import json
import pickle as pkl
import numpy as np
from sklearn.preprocessing import MultiLabelBinarizer

class BasicPreprocessor(object):
    """Load raw data from a pickle or JSON file, shuffle it, and split it
    into train/test sets via a caller-supplied data generator.

    Parameters
    ----------
    data_generator : callable
        Called as ``data_generator(data, tokenizer, mlb, split, args)`` to
        build a dataset object for one split.
    tokenizer : object
        Tokenizer forwarded verbatim to ``data_generator``.
    args : argparse.Namespace-like
        Must provide ``data_dir``, ``data_file``, ``labels`` and
        ``test_only``.
    """

    def __init__(self, data_generator, tokenizer, args):
        self.data_generator = data_generator
        self.tokenizer = tokenizer
        self.args = args

        file_path = os.path.join(args.data_dir, args.data_file)
        if file_path.endswith("pkl"):
            # NOTE(review): the original code printed the first record and
            # called exit() right after loading (debug leftover), which made
            # pickle input unusable; pickle files now load normally.
            with open(file_path, "rb") as f:
                self.raw_data = pkl.load(f)
        elif file_path.endswith("json"):
            # Context manager closes the handle (the original leaked it via
            # json.load(open(...))).
            with open(file_path, "r", encoding="utf-8") as f:
                self.raw_data = json.load(f)
        else:
            # Previously an unrecognized extension silently left raw_data
            # unset, causing a confusing AttributeError later on.
            raise ValueError(
                "Unsupported data file (expected .pkl or .json): %s" % file_path
            )
        self.shuffle()

        self.mlb = MultiLabelBinarizer()
        self.mlb.fit([args.labels])

    def shuffle(self):
        """Shuffle ``self.raw_data`` into a random order (stored as a numpy array)."""
        # permutation() returns the shuffled index array in a single call.
        perm = np.random.permutation(len(self.raw_data))
        self.raw_data = np.array(self.raw_data)[perm]

    def process(self):
        """Build and return ``(train_data, test_data)`` via ``self.data_generator``.

        When ``args.test_only`` is set, the full dataset becomes the test
        split and a one-record dummy train split is returned so downstream
        code expecting both splits keeps working. Otherwise the first 90% of
        the (already shuffled) data is the train split and the remaining 10%
        the test split; no validation split is produced.
        """
        args = self.args
        data_generator = self.data_generator
        raw_data = self.raw_data
        tokenizer = self.tokenizer
        mlb = self.mlb

        if args.test_only:
            train_data = data_generator(raw_data[:1], tokenizer, mlb, 'test', args)
            test_data = data_generator(raw_data, tokenizer, mlb, 'test', args)
            return train_data, test_data

        # Use 90% of the data for training and 10% for testing; no validation set.
        split = int(len(raw_data) * 0.9)
        train_data = data_generator(raw_data[:split], tokenizer, mlb, 'train', args)
        test_data = data_generator(raw_data[split:], tokenizer, mlb, 'test', args)

        return train_data, test_data