File size: 5,367 Bytes
8ebda9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
# coding=utf8
import ast
import os

import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer


class GPT2QADataset(Dataset):
    """Dataset for the Yuyuan medical QA task.

    Reads a text file with one Python dict literal per line (keys
    'Question' and 'answer') and tokenizes question+answer pairs for
    causal-LM training. Everything is held in memory, so this only
    suits small datasets; large corpora should use a memory-mapped
    dataset instead.
    """

    def __init__(self, data_path, name, args):
        """
        Args:
            data_path: path to the data file (one dict literal per line).
            name: human-readable dataset name shown in the progress bar.
            args: namespace providing ``pretrained_model_path`` and
                ``max_seq_length``.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'})
        # File size in GiB; decides eager vs. streaming read in load_data.
        self.data_size = os.path.getsize(data_path)/1024/1024/1024
        self.data_type_name = name
        # Set all configuration before loading so load_data/encode never
        # observe a partially initialized instance.
        self.max_seq_length = args.max_seq_length
        self.data = self.load_data(data_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.encode(self.data[index])

    def load_data(self, data_path):
        """Read and parse every line of ``data_path`` with a progress bar.

        Files up to 5 GiB are read eagerly so tqdm can display a total;
        larger files are streamed line by line (total unknown).
        """
        data = []
        # Single context manager guarantees the handle is closed even if
        # parsing raises (the original leaked it in the streaming branch).
        with open(data_path, "rt", encoding='utf8') as f:
            if self.data_size <= 5:
                lines = f.readlines()
                total_num = len(lines)
                source = lines
            else:
                total_num = None
                source = f
            with tqdm(total=total_num, desc=f'{self.data_type_name}处理进度',
                      mininterval=0.3) as bar:
                for line in source:
                    data.append(self.data_parse(line))
                    bar.update()
        return data

    def data_parse(self, line):
        """Parse one line (a Python dict literal) into a dict.

        Uses ``ast.literal_eval`` instead of ``eval`` so arbitrary code
        embedded in the data file cannot be executed.
        """
        return ast.literal_eval(line.strip())

    def encode(self, item):
        """Convert one QA item into model inputs.

        Question and answer are concatenated, tokenized to a fixed
        ``max_seq_length`` (padded/truncated), and padding positions in
        the labels are masked with -100 so they are ignored by the loss.
        """
        inputs_dict = self.tokenizer.encode_plus(item['Question']+item['answer'],
                                                 max_length=self.max_seq_length, padding='max_length',
                                                 truncation=True, return_tensors='pt')
        target = inputs_dict['input_ids']
        labels = target.clone().detach()
        labels[target == self.tokenizer.pad_token_id] = -100
        return {
            "input_ids": inputs_dict['input_ids'].squeeze(),
            "attention_mask": inputs_dict['attention_mask'].squeeze(),
            "labels": labels.squeeze(),
            "question": item['Question'],
            "answer": item['answer']
        }


class GPT2QADataModel(pl.LightningDataModule):
    """Lightning data module wrapping the GPT2 QA train/valid/test datasets."""

    @staticmethod
    def add_data_specific_args(parent_args):
        """Register the data-related command line arguments on ``parent_args``."""
        group = parent_args.add_argument_group('GPT2QADataModel')
        group.add_argument('--data_dir', type=str, required=True)
        group.add_argument('--num_workers', default=2, type=int)
        group.add_argument('--train_data', default='train.txt', type=str)
        group.add_argument('--valid_data', default='valid.txt', type=str)
        group.add_argument('--test_data', default='test.txt', type=str)
        group.add_argument('--train_batchsize', type=int, required=True)
        group.add_argument('--valid_batchsize', type=int, required=True)
        group.add_argument('--max_seq_length', default=1024, type=int)
        return parent_args

    def __init__(self, args):
        super().__init__()
        self.args = args
        self.train_batchsize = args.train_batchsize
        self.valid_batchsize = args.valid_batchsize
        # Train/valid sets are only materialized when training; the test
        # set is always built.
        if not args.do_eval_only:
            self.train_data = GPT2QADataset(os.path.join(
                args.data_dir, args.train_data), '训练集', args)
            self.valid_data = GPT2QADataset(os.path.join(
                args.data_dir, args.valid_data), '验证集', args)
        self.test_data = GPT2QADataset(os.path.join(
            args.data_dir, args.test_data), '测试集', args)

    def _build_loader(self, dataset, shuffle, batch_size):
        # Shared DataLoader construction; pin_memory is deliberately off.
        return DataLoader(dataset, shuffle=shuffle, batch_size=batch_size,
                          pin_memory=False, num_workers=self.args.num_workers)

    def train_dataloader(self):
        return self._build_loader(self.train_data, True, self.train_batchsize)

    def val_dataloader(self):
        return self._build_loader(self.valid_data, False, self.valid_batchsize)

    def predict_dataloader(self):
        return self._build_loader(self.test_data, False, self.valid_batchsize)


if __name__ == '__main__':
    # Manual smoke test: build the dataset from hard-coded paths and
    # print one encoded sample.
    import argparse
    modelfile = '/cognitive_comp/wuziwei/pretrained_model_hf/medical_v2'
    datafile = '/cognitive_comp/wuziwei/task-data/medical_qa/medical_qa_train.txt'
    parser = argparse.ArgumentParser(description='hf test', allow_abbrev=False)
    group = parser.add_argument_group(title='test args')
    # Help text fixed: it previously said "Number of transformer layers.",
    # a copy-paste error unrelated to this argument.
    group.add_argument('--pretrained-model-path', type=str, default=modelfile,
                       help='Path to the pretrained HF model/tokenizer directory.')
    group.add_argument('--max-seq-length', type=int, default=1024)
    args = parser.parse_args()

    testml = GPT2QADataset(datafile, 'medical_qa', args=args)

    print(testml[10])