DeepLearning101 committed
Commit f4b6e70
1 Parent(s): 6c0ee22

Upload 2 files

models/sequence_labeling/head_token_cls.py ADDED
@@ -0,0 +1,431 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel
from transformers.models.albert.modeling_albert import AlbertPreTrainedModel, AlbertModel
from transformers.models.megatron_bert.modeling_megatron_bert import MegatronBertPreTrainedModel, MegatronBertModel
from transformers.modeling_outputs import TokenClassifierOutput
from torch.nn import CrossEntropyLoss
from loss.focal_loss import FocalLoss
from loss.label_smoothing import LabelSmoothingCrossEntropy
from models.basic_modules.crf import CRF
from tools.model_utils.parameter_freeze import ParameterFreeze

from tools.runner_utils.log_util import logging
logger = logging.getLogger(__name__)

freezer = ParameterFreeze()


"""
BERT for token-level classification with a softmax head.
"""
class BertSoftmaxForSequenceLabeling(BertPreTrainedModel):
    def __init__(self, config):
        super(BertSoftmaxForSequenceLabeling, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = BertModel(config)
        if self.config.use_freezing:
            self.bert = freezer.freeze_lm(self.bert)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_type = config.loss_type
        self.init_weights()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        return_dict=False,
    ):
        encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = encoder_outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits,) + encoder_outputs[2:]  # add hidden states and attentions if they are here
        loss = None
        if labels is not None:
            assert self.loss_type in ["lsr", "focal", "ce"]
            if self.loss_type == "lsr":
                loss_fct = LabelSmoothingCrossEntropy(ignore_index=0)
            elif self.loss_type == "focal":
                loss_fct = FocalLoss(ignore_index=0)
            else:
                loss_fct = CrossEntropyLoss(ignore_index=0)
            # Only keep the active parts of the loss (positions where attention_mask == 1).
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            if loss is not None:
                outputs = (loss,) + outputs
            return outputs  # (loss), scores, (hidden_states), (attentions)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(encoder_outputs, "hidden_states", None),
            attentions=getattr(encoder_outputs, "attentions", None),
        )


"""
RoBERTa for token-level classification with a softmax head.
"""
class RobertaSoftmaxForSequenceLabeling(RobertaPreTrainedModel):
    def __init__(self, config):
        super(RobertaSoftmaxForSequenceLabeling, self).__init__(config)
        self.num_labels = config.num_labels
        self.roberta = RobertaModel(config)
        if self.config.use_freezing:
            self.roberta = freezer.freeze_lm(self.roberta)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_type = config.loss_type
        self.init_weights()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        return_dict=False,
    ):
        encoder_outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = encoder_outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits,) + encoder_outputs[2:]  # add hidden states and attentions if they are here
        loss = None
        if labels is not None:
            assert self.loss_type in ["lsr", "focal", "ce"]
            if self.loss_type == "lsr":
                loss_fct = LabelSmoothingCrossEntropy(ignore_index=0)
            elif self.loss_type == "focal":
                loss_fct = FocalLoss(ignore_index=0)
            else:
                loss_fct = CrossEntropyLoss(ignore_index=0)
            # Only keep the active parts of the loss (positions where attention_mask == 1).
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            if loss is not None:
                outputs = (loss,) + outputs
            return outputs  # (loss), scores, (hidden_states), (attentions)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(encoder_outputs, "hidden_states", None),
            attentions=getattr(encoder_outputs, "attentions", None),
        )


"""
ALBERT for token-level classification with a softmax head.
"""
class AlbertSoftmaxForSequenceLabeling(AlbertPreTrainedModel):
    def __init__(self, config):
        super(AlbertSoftmaxForSequenceLabeling, self).__init__(config)
        self.num_labels = config.num_labels
        self.loss_type = config.loss_type
        self.bert = AlbertModel(config)
        if self.config.use_freezing:
            self.bert = freezer.freeze_lm(self.bert)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        return_dict=False,
    ):
        encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids,
                                    position_ids=position_ids, head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits,) + encoder_outputs[2:]  # add hidden states and attentions if they are here
        loss = None
        if labels is not None:
            assert self.loss_type in ["lsr", "focal", "ce"]
            if self.loss_type == "lsr":
                loss_fct = LabelSmoothingCrossEntropy(ignore_index=0)
            elif self.loss_type == "focal":
                loss_fct = FocalLoss(ignore_index=0)
            else:
                loss_fct = CrossEntropyLoss(ignore_index=0)
            # Only keep the active parts of the loss (positions where attention_mask == 1).
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            if loss is not None:
                outputs = (loss,) + outputs
            return outputs  # (loss), scores, (hidden_states), (attentions)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(encoder_outputs, "hidden_states", None),
            attentions=getattr(encoder_outputs, "attentions", None),
        )


"""
MegatronBERT for token-level classification with a softmax head.
"""
class MegatronBertSoftmaxForSequenceLabeling(MegatronBertPreTrainedModel):
    def __init__(self, config):
        super(MegatronBertSoftmaxForSequenceLabeling, self).__init__(config)
        self.num_labels = config.num_labels
        self.bert = MegatronBertModel(config)
        if self.config.use_freezing:
            self.bert = freezer.freeze_lm(self.bert)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.loss_type = config.loss_type
        self.init_weights()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        return_dict=False,
    ):
        encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = encoder_outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits,) + encoder_outputs[2:]  # add hidden states and attentions if they are here
        loss = None
        if labels is not None:
            assert self.loss_type in ["lsr", "focal", "ce"]
            if self.loss_type == "lsr":
                loss_fct = LabelSmoothingCrossEntropy(ignore_index=0)
            elif self.loss_type == "focal":
                loss_fct = FocalLoss(ignore_index=0)
            else:
                loss_fct = CrossEntropyLoss(ignore_index=0)
            # Only keep the active parts of the loss (positions where attention_mask == 1).
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                active_logits = logits.view(-1, self.num_labels)[active_loss]
                active_labels = labels.view(-1)[active_loss]
                loss = loss_fct(active_logits, active_labels)
            else:
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            if loss is not None:
                outputs = (loss,) + outputs
            return outputs  # (loss), scores, (hidden_states), (attentions)

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(encoder_outputs, "hidden_states", None),
            attentions=getattr(encoder_outputs, "attentions", None),
        )


"""
BERT for token-level classification with a CRF head.
"""
class BertCrfForSequenceLabeling(BertPreTrainedModel):
    def __init__(self, config):
        super(BertCrfForSequenceLabeling, self).__init__(config)
        self.bert = BertModel(config)
        if self.config.use_freezing:
            self.bert = freezer.freeze_lm(self.bert)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        self.init_weights()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        return_dict=False,
    ):
        encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = encoder_outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits,)
        loss = None
        if labels is not None:
            # The CRF returns the log-likelihood; negate it to obtain the training loss.
            loss = -1 * self.crf(emissions=logits, tags=labels, mask=attention_mask)
            outputs = (loss,) + outputs

        if not return_dict:
            return outputs  # (loss), scores

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(encoder_outputs, "hidden_states", None),
            attentions=getattr(encoder_outputs, "attentions", None),
        )


"""
RoBERTa for token-level classification with a CRF head.
"""
class RobertaCrfForSequenceLabeling(RobertaPreTrainedModel):
    def __init__(self, config):
        super(RobertaCrfForSequenceLabeling, self).__init__(config)
        self.roberta = RobertaModel(config)
        if self.config.use_freezing:
            self.roberta = freezer.freeze_lm(self.roberta)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        self.init_weights()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        return_dict=False,
    ):
        encoder_outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = encoder_outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits,)
        loss = None
        if labels is not None:
            # The CRF returns the log-likelihood; negate it to obtain the training loss.
            loss = -1 * self.crf(emissions=logits, tags=labels, mask=attention_mask)
            outputs = (loss,) + outputs

        if not return_dict:
            return outputs  # (loss), scores

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(encoder_outputs, "hidden_states", None),
            attentions=getattr(encoder_outputs, "attentions", None),
        )


"""
ALBERT for token-level classification with a CRF head.
"""
class AlbertCrfForSequenceLabeling(AlbertPreTrainedModel):
    def __init__(self, config):
        super(AlbertCrfForSequenceLabeling, self).__init__(config)
        self.bert = AlbertModel(config)
        if self.config.use_freezing:
            self.bert = freezer.freeze_lm(self.bert)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        self.init_weights()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        return_dict=False,
    ):
        encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = encoder_outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits,)
        loss = None
        if labels is not None:
            # The CRF returns the log-likelihood; negate it to obtain the training loss.
            loss = -1 * self.crf(emissions=logits, tags=labels, mask=attention_mask)
            outputs = (loss,) + outputs

        if not return_dict:
            return outputs  # (loss), scores

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(encoder_outputs, "hidden_states", None),
            attentions=getattr(encoder_outputs, "attentions", None),
        )


"""
MegatronBERT for token-level classification with a CRF head.
"""
class MegatronBertCrfForSequenceLabeling(MegatronBertPreTrainedModel):
    def __init__(self, config):
        super(MegatronBertCrfForSequenceLabeling, self).__init__(config)
        self.bert = MegatronBertModel(config)
        if self.config.use_freezing:
            self.bert = freezer.freeze_lm(self.bert)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.crf = CRF(num_tags=config.num_labels, batch_first=True)
        self.init_weights()

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        labels=None,
        return_dict=False,
    ):
        encoder_outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        sequence_output = encoder_outputs[0]
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)
        outputs = (logits,)
        loss = None
        if labels is not None:
            # The CRF returns the log-likelihood; negate it to obtain the training loss.
            loss = -1 * self.crf(emissions=logits, tags=labels, mask=attention_mask)
            outputs = (loss,) + outputs

        if not return_dict:
            return outputs  # (loss), scores

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(encoder_outputs, "hidden_states", None),
            attentions=getattr(encoder_outputs, "attentions", None),
        )
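
A minimal usage sketch for the heads above (not part of the commit): it assumes a stock Hugging Face BertConfig extended with the two extra fields these classes read (loss_type and use_freezing), plus a made-up 9-label tag set and dummy labels.

# Hypothetical usage sketch: instantiate the softmax head and run one forward pass.
import torch
from transformers import BertConfig, BertTokenizerFast

config = BertConfig.from_pretrained("bert-base-chinese", num_labels=9)
config.loss_type = "ce"        # one of "lsr", "focal", "ce"
config.use_freezing = False    # whether to freeze the backbone via ParameterFreeze

model = BertSoftmaxForSequenceLabeling.from_pretrained("bert-base-chinese", config=config)
tokenizer = BertTokenizerFast.from_pretrained("bert-base-chinese")

batch = tokenizer(["北京欢迎你"], return_tensors="pt")
labels = torch.ones_like(batch["input_ids"])   # dummy tag ids (0 is reserved as ignore_index by the losses)
loss, logits = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    token_type_ids=batch["token_type_ids"],
    labels=labels,
)[:2]
print(loss.item(), logits.shape)               # logits: [batch_size, seq_len, num_labels]
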
models/sequence_labeling/lebert.py ADDED
@@ -0,0 +1,325 @@
# from transformers.configuration_bert import BertConfig
# from transformers import BertPreTrainedModel
# from transformers.modeling_bert import BertEmbeddings, BertEncoder, BertPooler, BertLayer, BaseModelOutput, BaseModelOutputWithPooling
# from transformers.modeling_bert import BERT_INPUTS_DOCSTRING, _TOKENIZER_FOR_DOC, _CONFIG_FOR_DOC

from transformers.models.bert.modeling_bert import BertConfig, BertPreTrainedModel, BertEmbeddings, \
    BertPooler, BertLayer, BaseModelOutputWithPoolingAndCrossAttentions, BaseModelOutputWithPastAndCrossAttentions
from transformers.models.bert.modeling_bert import BERT_INPUTS_DOCSTRING, _TOKENIZER_FOR_DOC, _CONFIG_FOR_DOC
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
import warnings
from dataclasses import dataclass
from typing import Optional, Tuple

import torch.utils.checkpoint
from torch.nn import CrossEntropyLoss, MSELoss

from transformers.file_utils import (
    add_code_sample_docstrings,
    add_start_docstrings_to_model_forward,
)


class WordEmbeddingAdapter(nn.Module):

    def __init__(self, config):
        super(WordEmbeddingAdapter, self).__init__()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.tanh = nn.Tanh()

        self.linear1 = nn.Linear(config.word_embed_dim, config.hidden_size)
        self.linear2 = nn.Linear(config.hidden_size, config.hidden_size)

        attn_W = torch.zeros(config.hidden_size, config.hidden_size)
        self.attn_W = nn.Parameter(attn_W)
        self.attn_W.data.normal_(mean=0.0, std=config.initializer_range)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, layer_output, word_embeddings, word_mask):
        """
        :param layer_output: output of a BERT layer, [b_size, len_input, d_model]
        :param word_embeddings: word vectors matched to each character, [b_size, len_input, num_word, d_word]
        :param word_mask: attention mask over each character's matched word vectors, [b_size, len_input, num_word]
        """

        # Transform: project the word vectors to the same dimension as the character vectors.
        word_outputs = self.linear1(word_embeddings)
        word_outputs = self.tanh(word_outputs)
        word_outputs = self.linear2(word_outputs)
        word_outputs = self.dropout(word_outputs)  # word_outputs: [b_size, len_input, num_word, d_model]
        word_mask = word_mask.bool()

        # For each character vector, compute attention weights over all of its matched word vectors
        # via a bilinear map, then take the weighted sum.
        scores = torch.matmul(layer_output.unsqueeze(2), self.attn_W)  # [b_size, len_input, 1, d_model]
        scores = torch.matmul(scores, torch.transpose(word_outputs, 2, 3))  # [b_size, len_input, 1, num_word]
        scores = scores.squeeze(2)  # [b_size, len_input, num_word]
        scores.masked_fill_(word_mask, -1e9)  # set the attention scores of padded word slots to a very small value
        scores = F.softmax(scores, dim=-1)  # [b_size, len_input, num_word]
        attn = scores.unsqueeze(-1)  # [b_size, len_input, num_word, 1]

        # Weighted sum: the aggregated word representation for each character.
        weighted_word_embedding = torch.sum(word_outputs * attn, dim=2)  # [N, L, D]
        layer_output = layer_output + weighted_word_embedding

        layer_output = self.dropout(layer_output)
        layer_output = self.layer_norm(layer_output)

        return layer_output


class LEBertModel(BertPreTrainedModel):
    """
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in
    `Attention is all you need`_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
    Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as a decoder the model needs to be initialized with the :obj:`is_decoder` argument of the
    configuration set to :obj:`True`. To be used in a Seq2Seq model, the model needs to be initialized with both
    the :obj:`is_decoder` argument and :obj:`add_cross_attention` set to :obj:`True`; an
    :obj:`encoder_hidden_states` is then expected as an input to the forward pass.

    .. _`Attention is all you need`:
        https://arxiv.org/abs/1706.03762
    """

    def __init__(self, config):
        super().__init__(config)
        self.config = config

        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """Prunes heads of the model.
        heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)

    @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
    @add_code_sample_docstrings(
        processor_class=_TOKENIZER_FOR_DOC,
        checkpoint="bert-base-uncased",
        output_type=BaseModelOutputWithPoolingAndCrossAttentions,
        config_class=_CONFIG_FOR_DOC,
    )
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        word_embeddings=None,
        word_mask=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        encoder_hidden_states (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`, defaults to :obj:`None`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the model is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the model is configured as a decoder.
            Mask values selected in ``[0, 1]``: ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves, in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention,
        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length].
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed.
        # 1.0 in head_mask indicates we keep the head.
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        embedding_output = self.embeddings(
            input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
        )
        encoder_outputs = self.encoder(
            embedding_output,
            word_embeddings=word_embeddings,
            word_mask=word_mask,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output)

        if not return_dict:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )


class BertEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
        self.word_embedding_adapter = WordEmbeddingAdapter(config)

    def forward(
        self,
        hidden_states,
        word_embeddings,
        word_mask,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=False,
        output_hidden_states=False,
        return_dict=False,
    ):
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        next_decoder_cache = () if use_cache else None
        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None
            past_key_value = past_key_values[i] if past_key_values is not None else None

            if getattr(self.config, "gradient_checkpointing", False):

                if use_cache:
                    # logger.warning(
                    #     "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                    # )
                    use_cache = False

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs, output_attentions)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer_module),
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                )
            else:
                layer_outputs = layer_module(
                    hidden_states,
                    attention_mask,
                    layer_head_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    past_key_value,
                    output_attentions,
                )
            hidden_states = layer_outputs[0]
            if use_cache:
                next_decoder_cache += (layer_outputs[-1],)

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

            # Fuse the matched word (lexicon) features into the character representations after layer i.
            # if i == self.config.add_layer:
            if i >= int(self.config.add_layer):  # edit by wjn
                hidden_states = self.word_embedding_adapter(hidden_states, word_embeddings, word_mask)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_decoder_cache,
                    all_hidden_states,
                    all_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            past_key_values=next_decoder_cache,
        )
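
A minimal driver sketch for LEBertModel (not part of the commit): the word_embeddings / word_mask tensors follow the shapes documented in WordEmbeddingAdapter.forward, word_embed_dim and add_layer are the extra config fields this file reads, and the assumption that True in word_mask marks padded word slots is inferred from the masked_fill_ call above; all values are made up for illustration.

# Hypothetical driver showing the expected tensor shapes for the lexicon inputs.
import torch
from transformers import BertConfig

config = BertConfig.from_pretrained("bert-base-chinese")
config.word_embed_dim = 200    # dimension of the external word vectors
config.add_layer = 0           # fuse lexicon features from this encoder layer onward

model = LEBertModel(config)    # randomly initialized here; load pretrained weights in practice

batch_size, seq_len, num_word = 2, 16, 3
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
attention_mask = torch.ones(batch_size, seq_len, dtype=torch.long)
word_embeddings = torch.randn(batch_size, seq_len, num_word, config.word_embed_dim)
# Boolean mask over the matched word slots; True marks padded slots.
word_mask = torch.zeros(batch_size, seq_len, num_word, dtype=torch.bool)

sequence_output, pooled_output = model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    word_embeddings=word_embeddings,
    word_mask=word_mask,
    return_dict=False,
)[:2]
print(sequence_output.shape)   # [batch_size, seq_len, hidden_size]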