ping yang commited on
Commit
9bb46b0
·
1 Parent(s): ab756e2

Add application file

Browse files
Files changed (1) hide show
  1. app.py +659 -0
app.py ADDED
@@ -0,0 +1,659 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2021 The IDEA Authors. All rights reserved.
3
+
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from logging import basicConfig
17
+ import torch
18
+ from torch import nn
19
+ import json
20
+ from tqdm import tqdm
21
+ import os
22
+ import numpy as np
23
+ from transformers import BertTokenizer, AutoTokenizer
24
+ import pytorch_lightning as pl
25
+
26
+ from pytorch_lightning.callbacks import ModelCheckpoint
27
+ from pytorch_lightning import loggers
28
+ from torch.utils.data import Dataset, DataLoader
29
+ from transformers.optimization import get_linear_schedule_with_warmup
30
+ from transformers import BertForMaskedLM, AlbertTokenizer
31
+ from transformers import AutoConfig
32
+ from transformers import MegatronBertForMaskedLM
33
+ import argparse
34
+ import copy
35
+ import streamlit as st
36
+ # os.environ["CUDA_VISIBLE_DEVICES"] = '6'
37
+
38
+
39
+ class UniMCDataset(Dataset):
40
+ def __init__(self, data, yes_token, no_token, tokenizer, args, used_mask=True):
41
+ super().__init__()
42
+
43
+ self.tokenizer = tokenizer
44
+ self.max_length = args.max_length
45
+ self.num_labels = args.num_labels
46
+ self.used_mask = used_mask
47
+ self.data = data
48
+ self.args = args
49
+ self.yes_token = yes_token
50
+ self.no_token = no_token
51
+
52
+ def __len__(self):
53
+ return len(self.data)
54
+
55
+ def __getitem__(self, index):
56
+ return self.encode(self.data[index], self.used_mask)
57
+
58
+ def get_token_type(self, sep_idx, max_length):
59
+ token_type_ids = np.zeros(shape=(max_length,))
60
+ for i in range(len(sep_idx)-1):
61
+ if i % 2 == 0:
62
+ ty = np.ones(shape=(sep_idx[i+1]-sep_idx[i],))
63
+ else:
64
+ ty = np.zeros(shape=(sep_idx[i+1]-sep_idx[i],))
65
+ token_type_ids[sep_idx[i]:sep_idx[i+1]] = ty
66
+
67
+ return token_type_ids
68
+
69
+ def get_position_ids(self, label_idx, max_length, question_len):
70
+ question_position_ids = np.arange(question_len)
71
+ label_position_ids = np.arange(question_len, label_idx[-1])
72
+ for i in range(len(label_idx)-1):
73
+ label_position_ids[label_idx[i]-question_len:label_idx[i+1]-question_len] = np.arange(
74
+ question_len, question_len+label_idx[i+1]-label_idx[i])
75
+ max_len_label = max(label_position_ids)
76
+ text_position_ids = np.arange(
77
+ max_len_label+1, max_length+max_len_label+1-label_idx[-1])
78
+ position_ids = list(question_position_ids) + \
79
+ list(label_position_ids)+list(text_position_ids)
80
+ if max_length <= 512:
81
+ return position_ids[:max_length]
82
+ else:
83
+ for i in range(512, max_length):
84
+ if position_ids[i] > 511:
85
+ position_ids[i] = 511
86
+ return position_ids[:max_length]
87
+
88
+ def get_att_mask(self, attention_mask, label_idx, question_len):
89
+ max_length = len(attention_mask)
90
+ attention_mask = np.array(attention_mask)
91
+ attention_mask = np.tile(attention_mask[None, :], (max_length, 1))
92
+
93
+ zeros = np.zeros(
94
+ shape=(label_idx[-1]-question_len, label_idx[-1]-question_len))
95
+ attention_mask[question_len:label_idx[-1],
96
+ question_len:label_idx[-1]] = zeros
97
+
98
+ for i in range(len(label_idx)-1):
99
+ label_token_length = label_idx[i+1]-label_idx[i]
100
+ if label_token_length <= 0:
101
+ print('label_idx', label_idx)
102
+ print('question_len', question_len)
103
+ continue
104
+ ones = np.ones(shape=(label_token_length, label_token_length))
105
+ attention_mask[label_idx[i]:label_idx[i+1],
106
+ label_idx[i]:label_idx[i+1]] = ones
107
+
108
+ return attention_mask
109
+
110
+ def random_masking(self, token_ids, maks_rate, mask_start_idx, max_length, mask_id, tokenizer):
111
+ rands = np.random.random(len(token_ids))
112
+ source, target = [], []
113
+ for i, (r, t) in enumerate(zip(rands, token_ids)):
114
+ if i < mask_start_idx:
115
+ source.append(t)
116
+ target.append(-100)
117
+ continue
118
+ if r < maks_rate * 0.8:
119
+ source.append(mask_id)
120
+ target.append(t)
121
+ elif r < maks_rate * 0.9:
122
+ source.append(t)
123
+ target.append(t)
124
+ elif r < maks_rate:
125
+ source.append(np.random.choice(tokenizer.vocab_size - 1) + 1)
126
+ target.append(t)
127
+ else:
128
+ source.append(t)
129
+ target.append(-100)
130
+ while len(source) < max_length:
131
+ source.append(0)
132
+ target.append(-100)
133
+ return source[:max_length], target[:max_length]
134
+
135
+ def encode(self, item, used_mask=False):
136
+
137
+ while len(self.tokenizer.encode('[MASK]'.join(item['choice']))) > self.max_length-32:
138
+ item['choice'] = [c[:int(len(c)/2)] for c in item['choice']]
139
+
140
+ if 'textb' in item.keys() and item['textb'] != '':
141
+ if 'question' in item.keys() and item['question'] != '':
142
+ texta = '[MASK]' + '[MASK]'.join(item['choice']) + '[SEP]' + \
143
+ item['question'] + '[SEP]' + \
144
+ item['texta']+'[SEP]'+item['textb']
145
+ else:
146
+ texta = '[MASK]' + '[MASK]'.join(item['choice']) + '[SEP]' + \
147
+ item['texta']+'[SEP]'+item['textb']
148
+
149
+ else:
150
+ if 'question' in item.keys() and item['question'] != '':
151
+ texta = '[MASK]' + '[MASK]'.join(item['choice']) + '[SEP]' + \
152
+ item['question'] + '[SEP]' + item['texta']
153
+ else:
154
+ texta = '[MASK]' + '[MASK]'.join(item['choice']) + \
155
+ '[SEP]' + item['texta']
156
+
157
+ encode_dict = self.tokenizer.encode_plus(texta,
158
+ max_length=self.max_length,
159
+ padding='max_length',
160
+ truncation='longest_first')
161
+
162
+ encode_sent = encode_dict['input_ids']
163
+ token_type_ids = encode_dict['token_type_ids']
164
+ attention_mask = encode_dict['attention_mask']
165
+ sample_max_length = sum(encode_dict['attention_mask'])
166
+
167
+ if 'label' not in item.keys():
168
+ item['label'] = 0
169
+ item['answer'] = ''
170
+
171
+ question_len = 1
172
+ label_idx = [question_len]
173
+ for choice in item['choice']:
174
+ cur_mask_idx = label_idx[-1] + \
175
+ len(self.tokenizer.encode(choice, add_special_tokens=False))+1
176
+ label_idx.append(cur_mask_idx)
177
+
178
+ token_type_ids = [0]*question_len+[1] * \
179
+ (label_idx[-1]-label_idx[0]+1)+[0]*self.max_length
180
+ token_type_ids = token_type_ids[:self.max_length]
181
+
182
+ attention_mask = self.get_att_mask(
183
+ attention_mask, label_idx, question_len)
184
+
185
+ position_ids = self.get_position_ids(
186
+ label_idx, self.max_length, question_len)
187
+
188
+ clslabels_mask = np.zeros(shape=(len(encode_sent),))
189
+ clslabels_mask[label_idx[:-1]] = 10000
190
+ clslabels_mask = clslabels_mask-10000
191
+
192
+ mlmlabels_mask = np.zeros(shape=(len(encode_sent),))
193
+ mlmlabels_mask[label_idx[0]] = 1
194
+
195
+ used_mask = False
196
+ if used_mask:
197
+ mask_rate = 0.1*np.random.choice(4, p=[0.3, 0.3, 0.25, 0.15])
198
+ source, target = self.random_masking(token_ids=encode_sent, maks_rate=mask_rate,
199
+ mask_start_idx=label_idx[-1], max_length=self.max_length,
200
+ mask_id=self.tokenizer.mask_token_id, tokenizer=self.tokenizer)
201
+ else:
202
+ source, target = encode_sent[:], encode_sent[:]
203
+
204
+ source = np.array(source)
205
+ target = np.array(target)
206
+ source[label_idx[:-1]] = self.tokenizer.mask_token_id
207
+ target[label_idx[:-1]] = self.no_token
208
+ target[label_idx[item['label']]] = self.yes_token
209
+
210
+ input_ids = source[:sample_max_length]
211
+ token_type_ids = token_type_ids[:sample_max_length]
212
+ attention_mask = attention_mask[:sample_max_length, :sample_max_length]
213
+ position_ids = position_ids[:sample_max_length]
214
+ mlmlabels = target[:sample_max_length]
215
+ clslabels = label_idx[item['label']]
216
+ clslabels_mask = clslabels_mask[:sample_max_length]
217
+ mlmlabels_mask = mlmlabels_mask[:sample_max_length]
218
+
219
+ return {
220
+ "input_ids": torch.tensor(input_ids).long(),
221
+ "token_type_ids": torch.tensor(token_type_ids).long(),
222
+ "attention_mask": torch.tensor(attention_mask).float(),
223
+ "position_ids": torch.tensor(position_ids).long(),
224
+ "mlmlabels": torch.tensor(mlmlabels).long(),
225
+ "clslabels": torch.tensor(clslabels).long(),
226
+ "clslabels_mask": torch.tensor(clslabels_mask).float(),
227
+ "mlmlabels_mask": torch.tensor(mlmlabels_mask).float(),
228
+ }
229
+
230
+
231
+ class UniMCDataModel(pl.LightningDataModule):
232
+ @staticmethod
233
+ def add_data_specific_args(parent_args):
234
+ parser = parent_args.add_argument_group('TASK NAME DataModel')
235
+ parser.add_argument('--num_workers', default=8, type=int)
236
+ parser.add_argument('--batchsize', default=16, type=int)
237
+ parser.add_argument('--max_length', default=512, type=int)
238
+ return parent_args
239
+
240
+ def __init__(self, train_data, val_data, yes_token, no_token, tokenizer, args):
241
+ super().__init__()
242
+ self.batchsize = args.batchsize
243
+
244
+ self.train_data = UniMCDataset(
245
+ train_data, yes_token, no_token, tokenizer, args, True)
246
+ self.valid_data = UniMCDataset(
247
+ val_data, yes_token, no_token, tokenizer, args, False)
248
+
249
+ def train_dataloader(self):
250
+ return DataLoader(self.train_data, shuffle=True, collate_fn=self.collate_fn, batch_size=self.batchsize, pin_memory=False)
251
+
252
+ def val_dataloader(self):
253
+ return DataLoader(self.valid_data, shuffle=False, collate_fn=self.collate_fn, batch_size=self.batchsize, pin_memory=False)
254
+
255
+ def collate_fn(self, batch):
256
+ '''
257
+ Aggregate a batch data.
258
+ batch = [ins1_dict, ins2_dict, ..., insN_dict]
259
+ batch_data = {'sentence':[ins1_sentence, ins2_sentence...], 'input_ids':[ins1_input_ids, ins2_input_ids...], ...}
260
+ '''
261
+ batch_data = {}
262
+ for key in batch[0]:
263
+ batch_data[key] = [example[key] for example in batch]
264
+
265
+ batch_data['input_ids'] = nn.utils.rnn.pad_sequence(batch_data['input_ids'],
266
+ batch_first=True,
267
+ padding_value=0)
268
+ batch_data['clslabels_mask'] = nn.utils.rnn.pad_sequence(batch_data['clslabels_mask'],
269
+ batch_first=True,
270
+ padding_value=-10000)
271
+
272
+ batch_size, batch_max_length = batch_data['input_ids'].shape
273
+ for k, v in batch_data.items():
274
+ if k == 'input_ids' or k == 'clslabels_mask':
275
+ continue
276
+ if k == 'clslabels':
277
+ batch_data[k] = torch.tensor(v).long()
278
+ continue
279
+ if k != 'attention_mask':
280
+ batch_data[k] = nn.utils.rnn.pad_sequence(v,
281
+ batch_first=True,
282
+ padding_value=0)
283
+ else:
284
+ attention_mask = torch.zeros(
285
+ (batch_size, batch_max_length, batch_max_length))
286
+ for i, att in enumerate(v):
287
+ sample_length, _ = att.shape
288
+ attention_mask[i, :sample_length, :sample_length] = att
289
+ batch_data[k] = attention_mask
290
+ return batch_data
291
+
292
+
293
+ class UniMCModel(nn.Module):
294
+ def __init__(self, pre_train_dir, yes_token):
295
+ super().__init__()
296
+ self.config = AutoConfig.from_pretrained(pre_train_dir)
297
+ if self.config.model_type == 'megatron-bert':
298
+ self.bert = MegatronBertForMaskedLM.from_pretrained(pre_train_dir)
299
+ else:
300
+ self.bert = BertForMaskedLM.from_pretrained(pre_train_dir)
301
+
302
+ self.loss_func = torch.nn.CrossEntropyLoss()
303
+ self.yes_token = yes_token
304
+
305
+ def forward(self, input_ids, attention_mask, token_type_ids, position_ids=None, mlmlabels=None, clslabels=None, clslabels_mask=None, mlmlabels_mask=None):
306
+
307
+ batch_size, seq_len = input_ids.shape
308
+ outputs = self.bert(input_ids=input_ids,
309
+ attention_mask=attention_mask,
310
+ position_ids=position_ids,
311
+ token_type_ids=token_type_ids,
312
+ labels=mlmlabels) # (bsz, seq, dim)
313
+ mask_loss = outputs.loss
314
+ mlm_logits = outputs.logits
315
+ cls_logits = mlm_logits[:, :,
316
+ self.yes_token].view(-1, seq_len)+clslabels_mask
317
+
318
+ if mlmlabels == None:
319
+ return 0, mlm_logits, cls_logits
320
+ else:
321
+ cls_loss = self.loss_func(cls_logits, clslabels)
322
+ all_loss = mask_loss+cls_loss
323
+ return all_loss, mlm_logits, cls_logits
324
+
325
+
326
+ class UniMCLitModel(pl.LightningModule):
327
+
328
+ @staticmethod
329
+ def add_model_specific_args(parent_args):
330
+ parser = parent_args.add_argument_group('BaseModel')
331
+
332
+ parser.add_argument('--learning_rate', default=1e-5, type=float)
333
+ parser.add_argument('--weight_decay', default=0.1, type=float)
334
+ parser.add_argument('--warmup', default=0.01, type=float)
335
+ parser.add_argument('--num_labels', default=2, type=int)
336
+
337
+ return parent_args
338
+
339
+ def __init__(self, args, yes_token, num_data=100):
340
+ super().__init__()
341
+ self.args = args
342
+ self.num_data = num_data
343
+ self.model = UniMCModel(self.args.pretrained_model_path, yes_token)
344
+
345
+ def setup(self, stage) -> None:
346
+ if stage == 'fit':
347
+ num_gpus = self.trainer.gpus if self.trainer.gpus is not None else 0
348
+ self.total_step = int(self.trainer.max_epochs * self.num_data /
349
+ (max(1, num_gpus) * self.trainer.accumulate_grad_batches))
350
+ print('Total training step:', self.total_step)
351
+
352
+ def training_step(self, batch, batch_idx):
353
+ loss, logits, cls_logits = self.model(**batch)
354
+ cls_acc = self.comput_metrix(
355
+ cls_logits, batch['clslabels'], batch['mlmlabels_mask'])
356
+ self.log('train_loss', loss)
357
+ self.log('train_acc', cls_acc)
358
+ return loss
359
+
360
+ def validation_step(self, batch, batch_idx):
361
+ loss, logits, cls_logits = self.model(**batch)
362
+ cls_acc = self.comput_metrix(
363
+ cls_logits, batch['clslabels'], batch['mlmlabels_mask'])
364
+ self.log('val_loss', loss)
365
+ self.log('val_acc', cls_acc)
366
+
367
+ def configure_optimizers(self):
368
+
369
+ no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
370
+ paras = list(
371
+ filter(lambda p: p[1].requires_grad, self.named_parameters()))
372
+ paras = [{
373
+ 'params':
374
+ [p for n, p in paras if not any(nd in n for nd in no_decay)],
375
+ 'weight_decay': self.args.weight_decay
376
+ }, {
377
+ 'params': [p for n, p in paras if any(nd in n for nd in no_decay)],
378
+ 'weight_decay': 0.0
379
+ }]
380
+ optimizer = torch.optim.AdamW(paras, lr=self.args.learning_rate)
381
+ scheduler = get_linear_schedule_with_warmup(
382
+ optimizer, int(self.total_step * self.args.warmup),
383
+ self.total_step)
384
+
385
+ return [{
386
+ 'optimizer': optimizer,
387
+ 'lr_scheduler': {
388
+ 'scheduler': scheduler,
389
+ 'interval': 'step',
390
+ 'frequency': 1
391
+ }
392
+ }]
393
+
394
+ def comput_metrix(self, logits, labels, mlmlabels_mask):
395
+ logits = torch.nn.functional.softmax(logits, dim=-1)
396
+ logits = torch.argmax(logits, dim=-1)
397
+ y_pred = logits.view(size=(-1,))
398
+ y_true = labels.view(size=(-1,))
399
+ corr = torch.eq(y_pred, y_true).float()
400
+ return torch.sum(corr.float())/labels.size(0)
401
+
402
+
403
+ class TaskModelCheckpoint:
404
+ @staticmethod
405
+ def add_argparse_args(parent_args):
406
+ parser = parent_args.add_argument_group('BaseModel')
407
+
408
+ parser.add_argument('--monitor', default='val_acc', type=str)
409
+ parser.add_argument('--mode', default='max', type=str)
410
+ parser.add_argument('--dirpath', default='./log/', type=str)
411
+ parser.add_argument(
412
+ '--filename', default='model-{epoch:02d}-{val_acc:.4f}', type=str)
413
+ parser.add_argument('--save_top_k', default=3, type=float)
414
+ parser.add_argument('--every_n_epochs', default=1, type=float)
415
+ parser.add_argument('--every_n_train_steps', default=100, type=float)
416
+ parser.add_argument('--save_weights_only', default=True, type=bool)
417
+ return parent_args
418
+
419
+ def __init__(self, args):
420
+ self.callbacks = ModelCheckpoint(monitor=args.monitor,
421
+ save_top_k=args.save_top_k,
422
+ mode=args.mode,
423
+ save_last=True,
424
+ every_n_train_steps=args.every_n_train_steps,
425
+ save_weights_only=args.save_weights_only,
426
+ dirpath=args.dirpath,
427
+ filename=args.filename)
428
+
429
+
430
+ class UniMCPredict:
431
+ def __init__(self, yes_token, no_token, model, tokenizer, args):
432
+ self.tokenizer = tokenizer
433
+ self.args = args
434
+ self.data_model = UniMCDataModel(
435
+ [], [], yes_token, no_token, tokenizer, args)
436
+ self.model = model
437
+
438
+ def predict(self, batch_data):
439
+ batch = [self.data_model.train_data.encode(
440
+ sample) for sample in batch_data]
441
+ batch = self.data_model.collate_fn(batch)
442
+ batch = {k: v.cuda() for k, v in batch.items()}
443
+ _, _, logits = self.model.model(**batch)
444
+ soft_logits = torch.nn.functional.softmax(logits, dim=-1)
445
+ logits = torch.argmax(soft_logits, dim=-1).detach().cpu().numpy()
446
+
447
+ soft_logits = soft_logits.detach().cpu().numpy()
448
+ clslabels_mask = batch['clslabels_mask'].detach(
449
+ ).cpu().numpy().tolist()
450
+ clslabels = batch['clslabels'].detach().cpu().numpy().tolist()
451
+ for i, v in enumerate(batch_data):
452
+ label_idx = [idx for idx, v in enumerate(
453
+ clslabels_mask[i]) if v == 0.]
454
+ label = label_idx.index(logits[i])
455
+ answer = batch_data[i]['choice'][label]
456
+ score = {}
457
+ for c in range(len(batch_data[i]['choice'])):
458
+ score[batch_data[i]['choice'][c]] = float(
459
+ soft_logits[i][label_idx[c]])
460
+
461
+ batch_data[i]['label_ori'] = copy.deepcopy(batch_data[i]['label'])
462
+ batch_data[i]['label'] = label
463
+ batch_data[i]['answer'] = answer
464
+ batch_data[i]['score'] = score
465
+
466
+ return batch_data
467
+
468
+
469
+ class UniMCPipelines:
470
+ @staticmethod
471
+ def pipelines_args(parent_args):
472
+ total_parser = parent_args.add_argument_group("pipelines args")
473
+ total_parser.add_argument(
474
+ '--pretrained_model_path', default='', type=str)
475
+ total_parser.add_argument('--load_checkpoints_path',
476
+ default='', type=str)
477
+ total_parser.add_argument('--train', action='store_true')
478
+ total_parser.add_argument('--language',
479
+ default='chinese', type=str)
480
+
481
+ total_parser = UniMCDataModel.add_data_specific_args(total_parser)
482
+ total_parser = TaskModelCheckpoint.add_argparse_args(total_parser)
483
+ total_parser = UniMCLitModel.add_model_specific_args(total_parser)
484
+ total_parser = pl.Trainer.add_argparse_args(parent_args)
485
+ return parent_args
486
+
487
+ def __init__(self, args):
488
+ self.args = args
489
+ self.checkpoint_callback = TaskModelCheckpoint(args).callbacks
490
+ self.logger = loggers.TensorBoardLogger(save_dir=args.default_root_dir)
491
+ self.trainer = pl.Trainer.from_argparse_args(args,
492
+ logger=self.logger,
493
+ callbacks=[self.checkpoint_callback])
494
+ self.config = AutoConfig.from_pretrained(args.pretrained_model_path)
495
+ if self.config.model_type == 'albert':
496
+ self.tokenizer = AlbertTokenizer.from_pretrained(
497
+ args.pretrained_model_path)
498
+ else:
499
+ if args.language == 'chinese':
500
+ self.tokenizer = BertTokenizer.from_pretrained(
501
+ args.pretrained_model_path)
502
+ else:
503
+ self.tokenizer = AutoTokenizer.from_pretrained(
504
+ args.pretrained_model_path, is_split_into_words=True, add_prefix_space=True)
505
+
506
+ if args.language == 'chinese':
507
+ self.yes_token = self.tokenizer.encode('是')[1]
508
+ self.no_token = self.tokenizer.encode('非')[1]
509
+ else:
510
+ self.yes_token = self.tokenizer.encode('yes')[1]
511
+ self.no_token = self.tokenizer.encode('no')[1]
512
+
513
+ if args.load_checkpoints_path != '':
514
+ self.model = UniMCLitModel.load_from_checkpoint(
515
+ args.load_checkpoints_path, args=args, yes_token=self.yes_token)
516
+ print('load model from: ', args.load_checkpoints_path)
517
+ else:
518
+ self.model = UniMCLitModel(args, yes_token=self.yes_token)
519
+
520
+ def fit(self, train_data, dev_data, process=True):
521
+ if process:
522
+ train_data = self.preprocess(train_data)
523
+ dev_data = self.preprocess(dev_data)
524
+ data_model = UniMCDataModel(
525
+ train_data, dev_data, self.yes_token, self.no_token, self.tokenizer, self.args)
526
+ self.model.num_data = len(train_data)
527
+ self.trainer.fit(self.model, data_model)
528
+
529
+ def predict(self, test_data, cuda=True, process=True):
530
+ if process:
531
+ test_data = self.preprocess(test_data)
532
+
533
+ result = []
534
+ start = 0
535
+ if cuda:
536
+ self.model = self.model.cuda()
537
+ self.model.model.eval()
538
+ predict_model = UniMCPredict(
539
+ self.yes_token, self.no_token, self.model, self.tokenizer, self.args)
540
+ while start < len(test_data):
541
+ batch_data = test_data[start:start+self.args.batchsize]
542
+ start += self.args.batchsize
543
+ batch_result = predict_model.predict(batch_data)
544
+ result.extend(batch_result)
545
+ if process:
546
+ result = self.postprocess(result)
547
+ return result
548
+
549
+ def preprocess(self, data):
550
+
551
+ for i, line in enumerate(data):
552
+ if 'task_type' in line.keys() and line['task_type'] == '语义匹配':
553
+ data[i]['choice'] = ['不能理解为:'+data[i]
554
+ ['textb'], '可以理解为:'+data[i]['textb']]
555
+ # data[i]['question']='怎么理解这段话?'
556
+ data[i]['textb'] = ''
557
+
558
+ if 'task_type' in line.keys() and line['task_type'] == '自然语言推理':
559
+ data[i]['choice'] = ['不能推断出:'+data[i]['textb'],
560
+ '很难推断出:'+data[i]['textb'], '可以推断出:'+data[i]['textb']]
561
+ # data[i]['question']='根据这段话'
562
+ data[i]['textb'] = ''
563
+
564
+ return data
565
+
566
+ def postprocess(self, data):
567
+ for i, line in enumerate(data):
568
+ if 'task_type' in line.keys() and line['task_type'] == '语义匹配':
569
+ data[i]['textb'] = data[i]['choice'][0].replace('不能理解为:', '')
570
+ data[i]['choice'] = ['不相似', '相似']
571
+ ns = {}
572
+ for k, v in data[i]['score'].items():
573
+ if '不能' in k:
574
+ k = '不相似'
575
+ if '可以' in k:
576
+ k = '相似'
577
+ ns[k] = v
578
+ data[i]['score'] = ns
579
+ data[i]['answer'] = data[i]['choice'][data[i]['label']]
580
+
581
+ if 'task_type' in line.keys() and line['task_type'] == '自然语言推理':
582
+ data[i]['textb'] = data[i]['choice'][0].replace('不能推断出:', '')
583
+ data[i]['choice'] = ['矛盾', '自然', '蕴含']
584
+ ns = {}
585
+ for k, v in data[i]['score'].items():
586
+ if '不能' in k:
587
+ k = '矛盾'
588
+ if '很难' in k:
589
+ k = '自然'
590
+ if '可以' in k:
591
+ k = '蕴含'
592
+ ns[k] = v
593
+ data[i]['score'] = ns
594
+ data[i]['answer'] = data[i]['choice'][data[i]['label']]
595
+
596
+ return data
597
+
598
+
599
+ def load_data(data_path):
600
+ with open(data_path, 'r', encoding='utf8') as f:
601
+ lines = f.readlines()
602
+ samples = [json.loads(line) for line in tqdm(lines)]
603
+ return samples
604
+
605
+
606
+ def comp_acc(pred_data, test_data):
607
+ corr = 0
608
+ for i in range(len(pred_data)):
609
+ if pred_data[i]['label'] == test_data[i]['label']:
610
+ corr += 1
611
+ return corr/len(pred_data)
612
+
613
+
614
+ @st.experimental_memo()
615
+ def load_model():
616
+ total_parser = argparse.ArgumentParser("TASK NAME")
617
+ total_parser = UniMCPipelines.pipelines_args(total_parser)
618
+ args = total_parser.parse_args()
619
+
620
+ args.pretrained_model_path = 'IDEA-CCNL/Erlangshen-UniMC-RoBERTa-110M-Chinese'
621
+ args.max_length = 512
622
+ args.batchsize = 8
623
+ args.default_root_dir = './'
624
+
625
+ model = UniMCPipelines(args)
626
+ return model
627
+
628
+
629
+ def main():
630
+
631
+ model = load_model()
632
+
633
+ st.subheader("UniMC Zero-shot 体验")
634
+ st.info("请输入以下信息...")
635
+
636
+ sentences = st.text_area("请输入句子:", """彭于晏不着急,胡歌也不着急,他俩都不着急,那我也不着急""")
637
+ question = st.text_input("请输入问题(不输入问题也可以):", "请问下面的新闻属于哪个类别?")
638
+ choice = st.text_input("输入标签(以中文;分割):", "娱乐;军事;体育;财经")
639
+ choice = choice.split(';')
640
+
641
+ data = [{"texta": sentences,
642
+ "textb": "",
643
+ "question": question,
644
+ "choice": choice,
645
+ "answer": "", "label": 0,
646
+ "id": 0}]
647
+
648
+ if st.button("点击一下,开始预测!"):
649
+ result = model.predict(data, cuda=False)
650
+ st.success("预测成功!")
651
+ st.json(result[0])
652
+ else:
653
+ st.info(
654
+ "**Enter a text** above and **press the button** to predict the category."
655
+ )
656
+
657
+
658
+ if __name__ == "__main__":
659
+ main()