DeepLearning101 committed on
Commit e95b4e9
1 Parent(s): 08f4077

Upload 3 files

models/multiple_choice/duma.py ADDED
@@ -0,0 +1,355 @@
+ # -*- coding: utf-8 -*-
+ # @Time : 2022/4/12 12:12 PM
+ # @Author : JianingWang
+ # @File : duma.py
+ import math
+
+ import torch
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+
+ from transformers.models.bert.modeling_bert import BertModel, BertPreTrainedModel
+ from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaPreTrainedModel
+ from transformers.models.albert.modeling_albert import AlbertModel, AlbertPreTrainedModel
+ from transformers.models.megatron_bert.modeling_megatron_bert import MegatronBertModel, MegatronBertPreTrainedModel
+ from transformers.modeling_outputs import MultipleChoiceModelOutput
+
+
+ def split_context_query(sequence_output, pq_end_pos, input_ids):  # split the joint encoding back into passage and query parts
+     context_max_len = sequence_output.size(1)
+     query_max_len = sequence_output.size(1)
+     sep_tok_len = 1  # [SEP]
+     context_sequence_output = sequence_output.new(
+         torch.Size((sequence_output.size(0), context_max_len, sequence_output.size(2)))).zero_()
+     query_sequence_output = sequence_output.new_zeros(
+         (sequence_output.size(0), query_max_len, sequence_output.size(2)))
+     query_attention_mask = sequence_output.new_zeros((sequence_output.size(0), query_max_len))
+     context_attention_mask = sequence_output.new_zeros((sequence_output.size(0), context_max_len))
+     for i in range(0, sequence_output.size(0)):
+         p_end = pq_end_pos[i][0]
+         q_end = pq_end_pos[i][1]
+         context_sequence_output[i, :min(context_max_len, p_end)] = sequence_output[i, 1: 1 + min(context_max_len, p_end)]
+         idx = min(query_max_len, q_end - p_end - sep_tok_len)
+         query_sequence_output[i, :idx] = sequence_output[i, p_end + sep_tok_len + 1: p_end + sep_tok_len + 1 + min(q_end - p_end - sep_tok_len, query_max_len)]
+         query_attention_mask[i, :idx] = sequence_output.new_ones((1, query_max_len))[0, :idx]
+         context_attention_mask[i, : min(context_max_len, p_end)] = sequence_output.new_ones((1, context_max_len))[0, : min(context_max_len, p_end)]
+     return context_sequence_output, query_sequence_output, context_attention_mask, query_attention_mask
+
+
+ class BertCoAttention(nn.Module):
+     def __init__(self, config):
+         super(BertCoAttention, self).__init__()
+         if config.hidden_size % config.num_attention_heads != 0:
+             raise ValueError(
+                 "The hidden size (%d) is not a multiple of the number of attention "
+                 "heads (%d)" % (config.hidden_size, config.num_attention_heads))
+         self.output_attentions = config.output_attentions
+
+         self.num_attention_heads = config.num_attention_heads
+         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+         self.all_head_size = self.num_attention_heads * self.attention_head_size
+
+         self.query = nn.Linear(config.hidden_size, self.all_head_size)
+         self.key = nn.Linear(config.hidden_size, self.all_head_size)
+         self.value = nn.Linear(config.hidden_size, self.all_head_size)
+
+         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
+
+     def transpose_for_scores(self, x):
+         new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
+         x = x.view(*new_x_shape)
+         return x.permute(0, 2, 1, 3)
+
+     def forward(self, context_states, query_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
+         mixed_query_layer = self.query(query_states)
+
+         extended_attention_mask = attention_mask[:, None, None, :]
+         # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
+         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
+         attention_mask = extended_attention_mask
+
+         # If this is instantiated as a cross-attention module, the keys
+         # and values come from an encoder; the attention mask needs to be
+         # such that the encoder's padding tokens are not attended to.
+         if encoder_hidden_states is not None:
+             mixed_key_layer = self.key(encoder_hidden_states)
+             mixed_value_layer = self.value(encoder_hidden_states)
+             attention_mask = encoder_attention_mask
+         else:
+             mixed_key_layer = self.key(context_states)
+             mixed_value_layer = self.value(context_states)
+
+         query_layer = self.transpose_for_scores(mixed_query_layer)
+         key_layer = self.transpose_for_scores(mixed_key_layer)
+         value_layer = self.transpose_for_scores(mixed_value_layer)
+
+         # Take the dot product between "query" and "key" to get the raw attention scores.
+         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
+         attention_scores = attention_scores / math.sqrt(self.attention_head_size)
+         if attention_mask is not None:
+             # Apply the attention mask (precomputed for all layers in the BertModel forward() function)
+             attention_scores = attention_scores + attention_mask
+
+         # Normalize the attention scores to probabilities.
+         attention_probs = nn.Softmax(dim=-1)(attention_scores)
+
+         # This is actually dropping out entire tokens to attend to, which might
+         # seem a bit unusual, but is taken from the original Transformer paper.
+         attention_probs = self.dropout(attention_probs)
+
+         # Mask heads if we want to
+         if head_mask is not None:
+             attention_probs = attention_probs * head_mask
+
+         context_layer = torch.matmul(attention_probs, value_layer)
+
+         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
+         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+         context_layer = context_layer.view(*new_context_layer_shape)
+
+         # outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,)
+         outputs = context_layer
+         return outputs
+
+
+ class BertDUMAForMultipleChoice(BertPreTrainedModel):
+
+     def __init__(self, config):
+         super(BertDUMAForMultipleChoice, self).__init__(config)
+
+         self.bert = BertModel(config)
+         self.classifier_2 = nn.Linear(2 * config.hidden_size, 1)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.bert_att = BertCoAttention(config)
+
+         self.init_weights()
+
+     def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                 inputs_embeds=None, labels=None, pq_end_pos=None, iter=1):
+         num_choices = input_ids.shape[1]
+
+         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+         flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+         flat_head_mask = head_mask.view(-1, head_mask.size(-1)) if head_mask is not None else None
+         flat_inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) if inputs_embeds is not None else None
+
+         outputs = self.bert(
+             input_ids=flat_input_ids,
+             attention_mask=flat_attention_mask,
+             token_type_ids=flat_token_type_ids,
+             position_ids=flat_position_ids,
+             head_mask=flat_head_mask,
+             inputs_embeds=flat_inputs_embeds
+         )
+
+         sequence_output = outputs[0]
+
+         pq_end_pos = pq_end_pos.view(-1, pq_end_pos.size(-1))  # (batch * num_choices, 2): [passage_end, query_end]
+
+         context_sequence_output, query_sequence_output, context_attention_mask, query_attention_mask = \
+             split_context_query(sequence_output, pq_end_pos, input_ids)
+         for _ in range(0, iter):
+             cq_biatt_output = self.bert_att(context_sequence_output, query_sequence_output, context_attention_mask)
+             qc_biatt_output = self.bert_att(query_sequence_output, context_sequence_output, query_attention_mask)
+
+             query_sequence_output = cq_biatt_output
+             context_sequence_output = qc_biatt_output
+
+         cat_output = torch.cat([torch.mean(qc_biatt_output, 1), torch.mean(cq_biatt_output, 1)], 1)  # mean-pool both co-attention directions, then concatenate
+         pooled_output = self.dropout(cat_output)
+         logits = self.classifier_2(pooled_output)
+
+         reshaped_logits = logits.view(-1, num_choices)
+
+         outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(reshaped_logits, labels)
+             outputs = (loss,) + outputs
+
+         return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
+
+
+ class RobertaDUMAForMultipleChoice(RobertaPreTrainedModel):
+
+     def __init__(self, config):
+         super(RobertaDUMAForMultipleChoice, self).__init__(config)
+
+         self.roberta = RobertaModel(config)
+         self.classifier_2 = nn.Linear(2 * config.hidden_size, 1)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.bert_att = BertCoAttention(config)
+
+         self.init_weights()
+
+     def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                 inputs_embeds=None, labels=None, pq_end_pos=None, iter=1):
+         num_choices = input_ids.shape[1]
+
+         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+         flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+         flat_head_mask = head_mask.view(-1, head_mask.size(-1)) if head_mask is not None else None
+         flat_inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) if inputs_embeds is not None else None
+
+         outputs = self.roberta(
+             input_ids=flat_input_ids,
+             attention_mask=flat_attention_mask,
+             token_type_ids=flat_token_type_ids,
+             position_ids=flat_position_ids,
+             head_mask=flat_head_mask,
+             inputs_embeds=flat_inputs_embeds
+         )
+
+         sequence_output = outputs[0]
+
+         pq_end_pos = pq_end_pos.view(-1, pq_end_pos.size(-1))
+
+         context_sequence_output, query_sequence_output, context_attention_mask, query_attention_mask = \
+             split_context_query(sequence_output, pq_end_pos, input_ids)
+         for _ in range(0, iter):
+             cq_biatt_output = self.bert_att(context_sequence_output, query_sequence_output, context_attention_mask)
+             qc_biatt_output = self.bert_att(query_sequence_output, context_sequence_output, query_attention_mask)
+
+             query_sequence_output = cq_biatt_output
+             context_sequence_output = qc_biatt_output
+
+         cat_output = torch.cat([torch.mean(qc_biatt_output, 1), torch.mean(cq_biatt_output, 1)], 1)
+         pooled_output = self.dropout(cat_output)
+         logits = self.classifier_2(pooled_output)
+
+         reshaped_logits = logits.view(-1, num_choices)
+
+         outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(reshaped_logits, labels)
+             outputs = (loss,) + outputs
+
+         return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
+
+ class AlbertDUMAForMultipleChoice(AlbertPreTrainedModel):
+
+     def __init__(self, config):
+         super(AlbertDUMAForMultipleChoice, self).__init__(config)
+
+         self.albert = AlbertModel(config)
+         self.classifier_2 = nn.Linear(2 * config.hidden_size, 1)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.bert_att = BertCoAttention(config)
+
+         self.init_weights()
+
+     def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                 inputs_embeds=None, labels=None, pq_end_pos=None, iter=1):
+         num_choices = input_ids.shape[1]
+
+         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+         flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+         flat_head_mask = head_mask.view(-1, head_mask.size(-1)) if head_mask is not None else None
+         flat_inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) if inputs_embeds is not None else None
+
+         outputs = self.albert(
+             input_ids=flat_input_ids,
+             attention_mask=flat_attention_mask,
+             token_type_ids=flat_token_type_ids,
+             position_ids=flat_position_ids,
+             head_mask=flat_head_mask,
+             inputs_embeds=flat_inputs_embeds
+         )
+
+         sequence_output = outputs[0]
+
+         pq_end_pos = pq_end_pos.view(-1, pq_end_pos.size(-1))
+
+         context_sequence_output, query_sequence_output, context_attention_mask, query_attention_mask = \
+             split_context_query(sequence_output, pq_end_pos, input_ids)
+         for _ in range(0, iter):
+             cq_biatt_output = self.bert_att(context_sequence_output, query_sequence_output, context_attention_mask)
+             qc_biatt_output = self.bert_att(query_sequence_output, context_sequence_output, query_attention_mask)
+
+             query_sequence_output = cq_biatt_output
+             context_sequence_output = qc_biatt_output
+
+         cat_output = torch.cat([torch.mean(qc_biatt_output, 1), torch.mean(cq_biatt_output, 1)], 1)
+         pooled_output = self.dropout(cat_output)
+         logits = self.classifier_2(pooled_output)
+
+         reshaped_logits = logits.view(-1, num_choices)
+
+         outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(reshaped_logits, labels)
+             outputs = (loss,) + outputs
+
+         return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
+
+
+ class MegatronDumaForMultipleChoice(MegatronBertPreTrainedModel):
+
+     def __init__(self, config):
+         super(MegatronDumaForMultipleChoice, self).__init__(config)
+
+         self.bert = MegatronBertModel(config)
+         self.classifier_2 = nn.Linear(2 * config.hidden_size, 1)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.bert_att = BertCoAttention(config)
+
+         self.init_weights()
+
+     def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None,
+                 inputs_embeds=None, labels=None, pq_end_pos=None, iter=1):
+         num_choices = input_ids.shape[1]
+
+         flat_input_ids = input_ids.view(-1, input_ids.size(-1))
+         flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+         flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+         flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+         flat_head_mask = head_mask.view(-1, head_mask.size(-1)) if head_mask is not None else None
+         flat_inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1)) if inputs_embeds is not None else None
+
+         outputs = self.bert(
+             input_ids=flat_input_ids,
+             attention_mask=flat_attention_mask,
+             token_type_ids=flat_token_type_ids,
+             position_ids=flat_position_ids,
+             head_mask=flat_head_mask,
+             inputs_embeds=flat_inputs_embeds
+         )
+
+         sequence_output = outputs[0]
+
+         pq_end_pos = pq_end_pos.view(-1, pq_end_pos.size(-1))
+
+         context_sequence_output, query_sequence_output, context_attention_mask, query_attention_mask = \
+             split_context_query(sequence_output, pq_end_pos, input_ids)
+         for _ in range(0, iter):
+             cq_biatt_output = self.bert_att(context_sequence_output, query_sequence_output, context_attention_mask)
+             qc_biatt_output = self.bert_att(query_sequence_output, context_sequence_output, query_attention_mask)
+
+             query_sequence_output = cq_biatt_output
+             context_sequence_output = qc_biatt_output
+
+         cat_output = torch.cat([torch.mean(qc_biatt_output, 1), torch.mean(cq_biatt_output, 1)], 1)
+         pooled_output = self.dropout(cat_output)
+         logits = self.classifier_2(pooled_output)
+
+         reshaped_logits = logits.view(-1, num_choices)
+
+         outputs = (reshaped_logits,) + outputs[2:]  # add hidden states and attention if they are here
+
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(reshaped_logits, labels)
+             outputs = (loss,) + outputs
+
+         return outputs  # (loss), reshaped_logits, (hidden_states), (attentions)
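For reference, a minimal smoke-test sketch of the BertDUMAForMultipleChoice head defined above (not part of the uploaded file); the import path, the tiny BertConfig values, and the toy pq_end_pos values are illustrative assumptions only:

import torch
from transformers import BertConfig
from models.multiple_choice.duma import BertDUMAForMultipleChoice  # assumed import path

config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=4, intermediate_size=64)
model = BertDUMAForMultipleChoice(config)

batch_size, num_choices, seq_len = 2, 4, 16
input_ids = torch.randint(0, config.vocab_size, (batch_size, num_choices, seq_len))
attention_mask = torch.ones_like(input_ids)
pq_end_pos = torch.zeros(batch_size, num_choices, 2, dtype=torch.long)
pq_end_pos[..., 0] = 8   # passage end index (toy value)
pq_end_pos[..., 1] = 14  # query end index (toy value)
labels = torch.tensor([0, 2])

loss, logits = model(input_ids=input_ids, attention_mask=attention_mask,
                     pq_end_pos=pq_end_pos, labels=labels)[:2]
print(logits.shape)  # torch.Size([2, 4])

With labels supplied, the head returns a tuple starting with the loss and the logits, where the logits have shape (batch_size, num_choices).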
models/multiple_choice/multiple_choice.py ADDED
@@ -0,0 +1,196 @@
+ # -*- coding: utf-8 -*-
+ # @Time : 2022/4/16 12:10 PM
+ # @Author : JianingWang
+ # @File : multiple_choice.py
+ import torch
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ import torch.nn.functional as F
+ # from transformers import MegatronBertPreTrainedModel, MegatronBertModel
+ from transformers.models.megatron_bert import MegatronBertPreTrainedModel, MegatronBertModel
+ from transformers.modeling_outputs import MultipleChoiceModelOutput
+
+
+ class MegatronBertForMultipleChoice(MegatronBertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.bert = MegatronBertModel(config)
+         # classifier_dropout = (
+         #     config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+         # )
+         classifier_dropout = 0.2
+         self.dropout = nn.Dropout(classifier_dropout)
+         self.classifier = nn.Linear(config.hidden_size, 1)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         pseudo=None
+     ):
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+             num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+             `input_ids` above)
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+         input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+         attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+         token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+         position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+         inputs_embeds = (
+             inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+             if inputs_embeds is not None
+             else None
+         )
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         pooled_output = outputs[1]  # [batch_size * num_choices, hidden_dim]
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)  # [batch_size * num_choices, 1]
+         reshaped_logits = logits.view(-1, num_choices)  # [batch_size, num_choices]
+
+         loss = None
+         if labels is not None:
+             if pseudo is None:
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(reshaped_logits, labels)
+             else:
+                 loss_fct = CrossEntropyLoss(reduction="none")
+                 loss = loss_fct(reshaped_logits, labels)
+                 weight = 1 - pseudo * 0.9  # down-weight pseudo-labelled examples
+                 loss *= weight
+                 loss = loss.mean()
+
+         if not return_dict:
+             output = (reshaped_logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return MultipleChoiceModelOutput(
+             loss=loss,
+             logits=reshaped_logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ class MegatronBertRDropForMultipleChoice(MegatronBertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.bert = MegatronBertModel(config)
+         # classifier_dropout = (
+         #     config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+         # )
+         classifier_dropout = 0.2
+         self.dropout = nn.Dropout(classifier_dropout)
+         self.classifier = nn.Linear(config.hidden_size, 1)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+             num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+             `input_ids` above)
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+         input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+         attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+         token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+         position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+         inputs_embeds = (
+             inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+             if inputs_embeds is not None
+             else None
+         )
+
+         logits_list = []
+         for i in range(2):  # two stochastic (dropout-enabled) forward passes for R-Drop
+             outputs = self.bert(
+                 input_ids,
+                 attention_mask=attention_mask,
+                 token_type_ids=token_type_ids,
+                 position_ids=position_ids,
+                 head_mask=head_mask,
+                 inputs_embeds=inputs_embeds,
+                 output_attentions=output_attentions,
+                 output_hidden_states=output_hidden_states,
+                 return_dict=return_dict,
+             )
+             pooled_output = outputs[1]
+             pooled_output = self.dropout(pooled_output)
+             logits = self.classifier(pooled_output)
+             logits_list.append(logits.view(-1, num_choices))
+
+         loss = None
+         alpha = 1.0
+         for logits in logits_list:
+             if labels is not None:
+                 loss_fct = CrossEntropyLoss()
+                 l = loss_fct(logits, labels)
+                 if loss is not None:
+                     loss += alpha * l
+                 else:
+                     loss = alpha * l
+
+         if loss is not None:
+             p = torch.log_softmax(logits_list[0], dim=-1)
+             p_tec = torch.exp(p)
+             q = torch.log_softmax(logits_list[-1], dim=-1)
+             q_tec = torch.exp(q)
+
+             kl_loss = F.kl_div(p, q_tec, reduction="none").sum()
+             reverse_kl_loss = F.kl_div(q, p_tec, reduction="none").sum()
+             loss += 0.5 * (kl_loss + reverse_kl_loss) / 2.
+
+         return MultipleChoiceModelOutput(
+             loss=loss,
+             logits=logits_list[0],
+             hidden_states=None,
+             attentions=None
+         )
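As a stand-alone sanity check on the R-Drop regularizer used in MegatronBertRDropForMultipleChoice above, the following sketch reproduces the symmetric KL term on made-up logits (the values are arbitrary):

import torch
import torch.nn.functional as F

# Two sets of logits for the same batch, standing in for the two stochastic
# (dropout-enabled) forward passes performed by the model above (toy values).
logits_1 = torch.tensor([[2.0, 0.5, -1.0]])
logits_2 = torch.tensor([[1.5, 0.8, -0.7]])

p = torch.log_softmax(logits_1, dim=-1)
q = torch.log_softmax(logits_2, dim=-1)
p_tec, q_tec = p.exp(), q.exp()

# F.kl_div expects log-probabilities as input and probabilities as target.
kl_loss = F.kl_div(p, q_tec, reduction="none").sum()
reverse_kl_loss = F.kl_div(q, p_tec, reduction="none").sum()
rdrop_penalty = 0.5 * (kl_loss + reverse_kl_loss) / 2.0
print(float(rdrop_penalty))

The penalty is added on top of the cross-entropy of both passes, pulling the two predictive distributions toward each other.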
models/multiple_choice/multiple_choice_tag.py ADDED
@@ -0,0 +1,271 @@
+ # -*- coding: utf-8 -*-
+ # @Time : 2022/3/3 7:59 PM
+ # @Author : JianingWang
+ # @File : multiple_choice_tag.py
+ import torch
+ from roformer import RoFormerPreTrainedModel, RoFormerModel
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+
+ from transformers import MegatronBertPreTrainedModel, MegatronBertModel
+ from transformers.modeling_outputs import MultipleChoiceModelOutput
+ from transformers.models.bert import BertPreTrainedModel, BertModel
+
+
+ class BertForTagMultipleChoice(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.bert = BertModel(config)
+         classifier_dropout = (
+             config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
+         )
+         self.dropout = nn.Dropout(classifier_dropout)
+         self.classifier = nn.Linear(config.hidden_size * 2, 1)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         pseudo=None
+     ):
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+             num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+             `input_ids` above)
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+         input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+         attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+         token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+         position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+         inputs_embeds = (
+             inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+             if inputs_embeds is not None
+             else None
+         )
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         w = torch.logical_and(input_ids >= min(self.config.start_token_ids), input_ids <= max(self.config.start_token_ids))
+         start_index = w.nonzero()[:, 1].view(-1, 2)  # positions of the two entity-marker tokens in each sequence
+         # feed the <start_entity> and <end_entity> token representations into the classifier
+         pooled_output = torch.cat([torch.cat([x[y[0], :], x[y[1], :]]).unsqueeze(0) for x, y in zip(outputs.last_hidden_state, start_index)])
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+         reshaped_logits = logits.view(-1, num_choices)
+
+         loss = None
+         if labels is not None:
+             if pseudo is None:
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(reshaped_logits, labels)
+             else:
+                 loss_fct = CrossEntropyLoss(reduction="none")
+                 loss = loss_fct(reshaped_logits, labels)
+                 weight = 1 - pseudo * 0.9  # down-weight pseudo-labelled examples
+                 loss *= weight
+                 loss = loss.mean()
+
+         if not return_dict:
+             output = (reshaped_logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return MultipleChoiceModelOutput(
+             loss=loss,
+             logits=reshaped_logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ class RoFormerForTagMultipleChoice(RoFormerPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.roformer = RoFormerModel(config, add_pooling_layer=False)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.classifier = nn.Linear(config.hidden_size * 2, 1)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+     ):
+         r"""
+         labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+             Labels for computing the multiple choice classification loss. Indices should be in ``[0, ...,
+             num_choices-1]`` where :obj:`num_choices` is the size of the second dimension of the input tensors. (See
+             :obj:`input_ids` above)
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+
+         input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+         attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+         token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+
+         inputs_embeds = (
+             inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+             if inputs_embeds is not None
+             else None
+         )
+
+         outputs = self.roformer(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         w = torch.logical_and(input_ids >= min(self.config.start_token_ids), input_ids <= max(self.config.start_token_ids))
+         start_index = w.nonzero()[:, 1].view(-1, 2)
+         # feed the <start_entity> and <end_entity> token representations into the classifier
+         pooled_output = torch.cat([torch.cat([x[y[0], :], x[y[1], :]]).unsqueeze(0) for x, y in zip(outputs.last_hidden_state, start_index)])
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+         reshaped_logits = logits.view(-1, num_choices)
+
+         loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             loss = loss_fct(reshaped_logits, labels)
+
+         if not return_dict:
+             output = (reshaped_logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return MultipleChoiceModelOutput(
+             loss=loss,
+             logits=reshaped_logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ class MegatronBertForTagMultipleChoice(MegatronBertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.bert = MegatronBertModel(config)
+         self.dropout = nn.Dropout(0.2)
+         self.classifier = nn.Linear(config.hidden_size * 2, 1)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         pseudo=None
+     ):
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
+             num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
+             `input_ids` above)
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
+         input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
+         attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
+         token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
+         position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
+         inputs_embeds = (
+             inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
+             if inputs_embeds is not None
+             else None
+         )
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         w = torch.logical_and(input_ids >= min(self.config.start_token_ids), input_ids <= max(self.config.start_token_ids))
+         start_index = w.nonzero()[:, 1].view(-1, 2)
+         # feed the <start_entity> and <end_entity> token representations into the classifier
+         pooled_output = torch.cat([torch.cat([x[y[0], :], x[y[1], :]]).unsqueeze(0) for x, y in zip(outputs.last_hidden_state, start_index)])
+
+         pooled_output = self.dropout(pooled_output)
+         logits = self.classifier(pooled_output)
+         reshaped_logits = logits.view(-1, num_choices)
+
+         loss = None
+         if labels is not None:
+             if pseudo is None:
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(reshaped_logits, labels)
+             else:
+                 loss_fct = CrossEntropyLoss(reduction="none")
+                 loss = loss_fct(reshaped_logits, labels)
+                 weight = 1 - pseudo * 0.9
+                 loss *= weight
+                 loss = loss.mean()
+
+         if not return_dict:
+             output = (reshaped_logits,) + outputs[2:]
+             return ((loss,) + output) if loss is not None else output
+
+         return MultipleChoiceModelOutput(
+             loss=loss,
+             logits=reshaped_logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
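The *ForTagMultipleChoice heads above rely on a non-standard config attribute, start_token_ids, and assume exactly two entity-marker tokens appear in every choice sequence. A hedged setup sketch follows; the <start_entity>/<end_entity> marker strings come from the comments above, while the checkpoint name and import path are assumptions for illustration:

from transformers import BertConfig, BertTokenizerFast
from models.multiple_choice.multiple_choice_tag import BertForTagMultipleChoice  # assumed import path

tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")  # example checkpoint only
tokenizer.add_special_tokens({"additional_special_tokens": ["<start_entity>", "<end_entity>"]})

config = BertConfig.from_pretrained("bert-base-chinese")
# The model masks input_ids against min/max of this list; newly added tokens get
# consecutive ids, so the range covers exactly the two marker tokens.
config.start_token_ids = tokenizer.convert_tokens_to_ids(["<start_entity>", "<end_entity>"])

model = BertForTagMultipleChoice.from_pretrained("bert-base-chinese", config=config)
model.resize_token_embeddings(len(tokenizer))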