DeepLearning101 committed on
Commit 4cda815
1 Parent(s): a2fef5f

Upload fusion_siamese.py

models/sequence_matching/fusion_siamese.py ADDED
@@ -0,0 +1,225 @@
+ # -*- coding: utf-8 -*-
+ # @Time : 2022/4/21 5:30 PM
+ # @Author : JianingWang
+ # @File : fusion_siamese.py
+ from typing import Optional
+ import torch
+ import numpy as np
+ import torch.nn as nn
+ from dataclasses import dataclass
+ from torch.nn import BCEWithLogitsLoss
+ from transformers import MegatronBertModel, MegatronBertPreTrainedModel
+ from transformers.file_utils import ModelOutput
+ from transformers.models.bert import BertPreTrainedModel, BertModel
+ from transformers.activations import ACT2FN
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
+ from transformers.modeling_outputs import SequenceClassifierOutput
+ from loss.focal_loss import FocalLoss
+ # from roformer import RoFormerPreTrainedModel, RoFormerModel
+
+
+ class BertPooler(nn.Module):
+     def __init__(self, hidden_size, hidden_act):
+         super().__init__()
+         self.dense = nn.Linear(hidden_size, hidden_size)
+         # self.activation = nn.Tanh()
+         self.activation = ACT2FN[hidden_act]
+         # self.dropout = nn.Dropout(hidden_dropout_prob)
+
+     def forward(self, features):
+         x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+         # x = self.dropout(x)
+         x = self.dense(x)
+         x = self.activation(x)
+         return x
+
+
+ class BertForFusionSiamese(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.bert = BertModel(config)
+         self.hidden_size = config.hidden_size
+         self.hidden_act = config.hidden_act
+         self.bert_poor = BertPooler(self.hidden_size, self.hidden_act)
+         self.dense_1 = nn.Linear(self.hidden_size, self.hidden_size)
+         self.dense_2 = nn.Linear(self.hidden_size, self.hidden_size)
+
+         if hasattr(config, "cls_dropout_rate"):
+             cls_dropout_rate = config.cls_dropout_rate
+         else:
+             cls_dropout_rate = config.hidden_dropout_prob
+         self.dropout = nn.Dropout(cls_dropout_rate)
+         self.classifier = nn.Linear(3 * self.hidden_size, config.num_labels)
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         pseudo_label=None,
+         segment_spans=None,
+         pseuso_proba=None
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         logits, outputs = None, None
+         inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids,
+                   "position_ids": position_ids,
+                   "head_mask": head_mask, "inputs_embeds": inputs_embeds, "output_attentions": output_attentions,
+                   "output_hidden_states": output_hidden_states, "return_dict": return_dict}
+         inputs = {k: v for k, v in inputs.items() if v is not None}
+         outputs = self.bert(**inputs)
+         if "sequence_output" in outputs:
+             sequence_output = outputs.sequence_output  # [bz, seq_len, dim]
+         else:
+             sequence_output = outputs[0]  # [bz, seq_len, dim]
+
+         cls_output = self.bert_poor(sequence_output)  # [bz, dim]
+
+         if segment_spans is not None:
+             # If the input contains two segments, mean-pool each segment separately
+             seg1_embeddings, seg2_embeddings = list(), list()
+             for ei, sentence_embeddings in enumerate(sequence_output):
+                 # sentence_embeddings: [seq_len, dim]
+                 seg1_start, seg1_end, seg2_start, seg2_end = segment_spans[ei]
+                 # print("sentence_embeddings[seg1_start, seg1_end].shape=", sentence_embeddings[seg1_start, seg1_end].shape)
+                 # print("torch.mean(sentence_embeddings[seg1_start, seg1_end], 0).shape=", torch.mean(sentence_embeddings[seg1_start, seg1_end], 0).shape)
+                 seg1_embeddings.append(torch.mean(sentence_embeddings[seg1_start: seg1_end], 0))  # [dim]
+                 seg2_embeddings.append(torch.mean(sentence_embeddings[seg2_start: seg2_end], 0))  # [dim]
+             seg1_embeddings, seg2_embeddings = torch.stack(seg1_embeddings), torch.stack(seg2_embeddings)  # [bz, dim]
+             # print("seg1_embeddings.shape=", seg1_embeddings.shape)
+             # both segments go through the same projection (dense_1), i.e. siamese weight sharing
+             seg1_embeddings = self.bert_poor.activation(self.dense_1(seg1_embeddings))
+             seg2_embeddings = self.bert_poor.activation(self.dense_1(seg2_embeddings))
+             cls_output = torch.cat([cls_output, seg1_embeddings, seg2_embeddings], dim=-1)  # [bz, 3*dim]
+             # cls_output = cls_output + seg1_embeddings + seg2_embeddings  # [bz, dim]
+
+         pooler_output = self.dropout(cls_output)
+         # pooler_output = self.LayerNorm(pooler_output)
+         logits = self.classifier(pooler_output)
+
+         loss = None
+         if labels is not None:
+
+             # loss_fct = FocalLoss()
+             loss_fct = CrossEntropyLoss()
+             # pseudo labels: mix the loss on labeled and pseudo-labeled examples with fixed weights
+             if pseudo_label is not None:
+                 train_logits, pseudo_logits = logits[pseudo_label > 0.9], logits[pseudo_label < 0.1]
+                 train_labels, pseudo_labels = labels[pseudo_label > 0.9], labels[pseudo_label < 0.1]
+                 train_loss = loss_fct(train_logits.view(-1, self.num_labels),
+                                       train_labels.view(-1)) if train_labels.nelement() else 0
+                 pseudo_loss = loss_fct(pseudo_logits.view(-1, self.num_labels),
+                                        pseudo_labels.view(-1)) if pseudo_labels.nelement() else 0
+                 loss = 0.9 * train_loss + 0.1 * pseudo_loss
+             else:
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+
+ class BertForWSC(BertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+         self.bert = BertModel(config)
+         self.hidden_size = config.hidden_size
+         self.hidden_act = config.hidden_act
+         self.bert_poor = BertPooler(self.hidden_size, self.hidden_act)
+         self.dense_1 = nn.Linear(self.hidden_size, self.hidden_size)
+         self.dense_2 = nn.Linear(self.hidden_size, self.hidden_size)
+
+         if hasattr(config, "cls_dropout_rate"):
+             cls_dropout_rate = config.cls_dropout_rate
+         else:
+             cls_dropout_rate = config.hidden_dropout_prob
+         self.dropout = nn.Dropout(cls_dropout_rate)
+         self.classifier = nn.Linear(2 * self.hidden_size, config.num_labels)
+         self.init_weights()
+
+     def forward(
+         self,
+         input_ids=None,
+         attention_mask=None,
+         token_type_ids=None,
+         position_ids=None,
+         head_mask=None,
+         inputs_embeds=None,
+         labels=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         pseudo_label=None,
+         span=None,
+         pseuso_proba=None
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         logits, outputs = None, None
+         inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "token_type_ids": token_type_ids,
+                   "position_ids": position_ids,
+                   "head_mask": head_mask, "inputs_embeds": inputs_embeds, "output_attentions": output_attentions,
+                   "output_hidden_states": output_hidden_states, "return_dict": return_dict}
+         inputs = {k: v for k, v in inputs.items() if v is not None}
+         outputs = self.bert(**inputs)
+         if "sequence_output" in outputs:
+             sequence_output = outputs.sequence_output  # [bz, seq_len, dim]
+         else:
+             sequence_output = outputs[0]  # [bz, seq_len, dim]
+
+         # cls_output = self.bert_poor(sequence_output)  # [bz, dim]
+
+         # If the input contains two spans, mean-pool each span separately
+         seg1_embeddings, seg2_embeddings = list(), list()
+         # print("span=", span)
+         for ei, sentence_embeddings in enumerate(sequence_output):
+             # sentence_embeddings: [seq_len, dim]
+             seg1_start, seg1_end, seg2_start, seg2_end = span[ei]
+             # print("sentence_embeddings[seg1_start, seg1_end].shape=", sentence_embeddings[seg1_start, seg1_end].shape)
+             # print("torch.mean(sentence_embeddings[seg1_start, seg1_end], 0).shape=", torch.mean(sentence_embeddings[seg1_start, seg1_end], 0).shape)
+             seg1_embeddings.append(torch.mean(sentence_embeddings[seg1_start+1: seg1_end], 0))  # [dim]
+             seg2_embeddings.append(torch.mean(sentence_embeddings[seg2_start+1: seg2_end], 0))  # [dim]
+         seg1_embeddings, seg2_embeddings = torch.stack(seg1_embeddings), torch.stack(seg2_embeddings)  # [bz, dim]
+         # print("seg1_embeddings.shape=", seg1_embeddings.shape)
+         # seg1_embeddings = self.bert_poor.activation(self.dense_1(seg1_embeddings))
+         # seg2_embeddings = self.bert_poor.activation(self.dense_1(seg2_embeddings))
+         cls_output = torch.cat([seg1_embeddings, seg2_embeddings], dim=-1)  # [bz, 2*dim]
+         # cls_output = cls_output + seg1_embeddings + seg2_embeddings  # [bz, dim]
+
+         pooler_output = self.dropout(cls_output)
+         # pooler_output = self.LayerNorm(pooler_output)
+         logits = self.classifier(pooler_output)
+
+         loss = None
+         if labels is not None:
+
+             # loss_fct = FocalLoss()
+             loss_fct = CrossEntropyLoss()
+             # pseudo labels: mix the loss on labeled and pseudo-labeled examples with fixed weights
+             if pseudo_label is not None:
+                 train_logits, pseudo_logits = logits[pseudo_label > 0.9], logits[pseudo_label < 0.1]
+                 train_labels, pseudo_labels = labels[pseudo_label > 0.9], labels[pseudo_label < 0.1]
+                 train_loss = loss_fct(train_logits.view(-1, self.num_labels),
+                                       train_labels.view(-1)) if train_labels.nelement() else 0
+                 pseudo_loss = loss_fct(pseudo_logits.view(-1, self.num_labels),
+                                        pseudo_labels.view(-1)) if pseudo_labels.nelement() else 0
+                 loss = 0.9 * train_loss + 0.1 * pseudo_loss
+             else:
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+         return SequenceClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
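
Usage note (not part of the commit): the snippet below is a minimal sketch of how BertForFusionSiamese might be called, assuming the script runs inside this repository so that loss.focal_loss resolves. The checkpoint name "bert-base-chinese", the example sentence pair, and the hard-coded span indices in segment_spans are illustrative assumptions, not values taken from the upload; segment_spans is expected to carry (seg1_start, seg1_end, seg2_start, seg2_end) token indices per example.

import torch
from transformers import BertConfig, BertTokenizerFast

checkpoint = "bert-base-chinese"  # hypothetical checkpoint; any BERT-style model should work
tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
config = BertConfig.from_pretrained(checkpoint, num_labels=2)
model = BertForFusionSiamese.from_pretrained(checkpoint, config=config)

# Encode a sentence pair; input_ids / token_type_ids / attention_mask come from the tokenizer.
enc = tokenizer("今天天气很好", "今天天气不错", return_tensors="pt")

# Token spans of the two segments ([CLS] at 0, [SEP] at 7 and 14 for this pair); illustrative only.
segment_spans = torch.tensor([[1, 7, 8, 14]])
labels = torch.tensor([1])

out = model(**enc, segment_spans=segment_spans, labels=labels)
print(out.loss, out.logits.shape)  # logits: [1, num_labels]

The classifier expects 3 * hidden_size features, so segment_spans must be provided; without it the concatenated [CLS] + segment representation is never built.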