# coding=utf8
import ast
import os

import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import AutoTokenizer
class GPT2QADataset(Dataset):
    '''
    Dataset used for the yuyuan medical QA task.

    Only supports small datasets: every record is parsed into memory up
    front, so very large files will be slow to load. For large datasets
    use a memory-mapped dataset (work in progress).
    '''

    def __init__(self, data_path, name, args):
        """
        Args:
            data_path: path to a text file with one Python-dict-literal
                record per line (keys 'Question' and 'answer').
            name: human-readable dataset name, shown in the progress bar.
            args: namespace providing `pretrained_model_path` and
                `max_seq_length`.
        """
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_path)
        if self.tokenizer.pad_token is None:
            self.tokenizer.add_special_tokens({'pad_token': '<|endoftext|>'})
        # File size in GiB; selects eager vs. streaming reads in load_data.
        self.data_size = os.path.getsize(data_path)/1024/1024/1024
        self.data_type_name = name
        self.data = self.load_data(data_path)
        self.max_seq_length = args.max_seq_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Records are tokenized lazily, one item at a time.
        return self.encode(self.data[index])

    def load_data(self, data_path):
        # Parse every line of the file, with a progress bar.
        # Files up to 5 GiB are read eagerly so tqdm can display a total;
        # larger files are streamed line by line.
        if self.data_size <= 5:
            with open(data_path, "rt", encoding='utf8') as f:
                lines = f.readlines()
            total_num = len(lines)
            data_gen = lines
        else:
            data_gen = open(data_path, "rt", encoding='utf8')
            total_num = None
        data = []
        try:
            with tqdm(total=total_num, desc=f'{self.data_type_name}处理进度',
                      mininterval=0.3) as bar:
                for line in data_gen:
                    data.append(self.data_parse(line))
                    bar.update()
        finally:
            # Close the streamed handle even if a record fails to parse.
            if self.data_size > 5:
                data_gen.close()
        return data

    def data_parse(self, line):
        """
        Parse one record line (a Python dict literal) into a dict.

        Uses ast.literal_eval rather than eval: it accepts the same
        literal syntax but cannot execute arbitrary code embedded in an
        untrusted data file.
        """
        return ast.literal_eval(line.strip())

    def encode(self, item):
        """
        Convert one QA record into model-training inputs.

        Question and answer are concatenated, tokenized, and padded or
        truncated to max_seq_length; padding positions get label -100 so
        they are ignored by the loss.
        """
        inputs_dict = self.tokenizer.encode_plus(
            item['Question'] + item['answer'],
            max_length=self.max_seq_length, padding='max_length',
            truncation=True, return_tensors='pt')
        target = inputs_dict['input_ids']
        labels = target.clone().detach()
        labels[target == self.tokenizer.pad_token_id] = -100
        return {
            "input_ids": inputs_dict['input_ids'].squeeze(),
            "attention_mask": inputs_dict['attention_mask'].squeeze(),
            "labels": labels.squeeze(),
            "question": item['Question'],
            "answer": item['answer']
        }
class GPT2QADataModel(pl.LightningDataModule):
    """Lightning data module that wraps the yuyuan medical QA datasets."""

    @staticmethod
    def add_data_specific_args(parent_args):
        """Register this module's command-line options on parent_args."""
        parser = parent_args.add_argument_group('GPT2QADataModel')
        arg_specs = (
            ('--data_dir', dict(type=str, required=True)),
            ('--num_workers', dict(default=2, type=int)),
            ('--train_data', dict(default='train.txt', type=str)),
            ('--valid_data', dict(default='valid.txt', type=str)),
            ('--test_data', dict(default='test.txt', type=str)),
            ('--train_batchsize', dict(type=int, required=True)),
            ('--valid_batchsize', dict(type=int, required=True)),
            ('--max_seq_length', dict(default=1024, type=int)),
        )
        for flag, options in arg_specs:
            parser.add_argument(flag, **options)
        return parent_args

    def __init__(self, args):
        """Build the datasets; train/valid are skipped in eval-only mode."""
        super().__init__()
        self.args = args
        self.train_batchsize = args.train_batchsize
        self.valid_batchsize = args.valid_batchsize
        if not args.do_eval_only:
            self.train_data = GPT2QADataset(
                os.path.join(args.data_dir, args.train_data), '训练集', args)
            self.valid_data = GPT2QADataset(
                os.path.join(args.data_dir, args.valid_data), '验证集', args)
        self.test_data = GPT2QADataset(
            os.path.join(args.data_dir, args.test_data), '测试集', args)

    def _make_loader(self, dataset, batch_size, shuffle):
        # Shared DataLoader construction for all three splits.
        return DataLoader(
            dataset, shuffle=shuffle, batch_size=batch_size,
            pin_memory=False, num_workers=self.args.num_workers)

    def train_dataloader(self):
        return self._make_loader(self.train_data, self.train_batchsize, True)

    def val_dataloader(self):
        return self._make_loader(self.valid_data, self.valid_batchsize, False)

    def predict_dataloader(self):
        return self._make_loader(self.test_data, self.valid_batchsize, False)
if __name__ == '__main__':
    import argparse

    # Manual smoke test: build the dataset from a sample file and print one item.
    modelfile = '/cognitive_comp/wuziwei/pretrained_model_hf/medical_v2'
    datafile = '/cognitive_comp/wuziwei/task-data/medical_qa/medical_qa_train.txt'
    parser = argparse.ArgumentParser(description='hf test', allow_abbrev=False)
    group = parser.add_argument_group(title='test args')
    # Fix: the original help text ("Number of transformer layers.") was a
    # copy-paste error and did not describe this argument.
    group.add_argument('--pretrained-model-path', type=str, default=modelfile,
                       help='Path or name of the pretrained tokenizer/model.')
    group.add_argument('--max-seq-length', type=int, default=1024)
    args = parser.parse_args()

    testml = GPT2QADataset(datafile, 'medical_qa', args=args)
    print(testml[10])