DeepLearning101 committed on
Commit
5894ace
1 Parent(s): 3169cc9

Upload 3 files

models/language_modeling/causal_lm.py ADDED
@@ -0,0 +1,278 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2023/2/16 3:35 PM
3
+ # @Author : JianingWang
4
+ # @File : causal_lm.py
5
+ import logging
6
+ from typing import Union, Tuple, Optional
7
+ import torch
8
+ import torch.nn as nn
9
+ from tqdm import tqdm
10
+ from typing import Optional, Tuple
11
+ from torch.nn import CrossEntropyLoss
12
+ from transformers import AutoModelForCausalLM
13
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
14
+ from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel, GPT2Model, GPT2PreTrainedModel
15
+
16
+ """
17
+ Function: Use Causal LM to pre-train GPT-2
18
+ Notes:
19
+ - By default, the causal LM trains on all tokens: the label of each token is the next token, so the model learns autoregressively.
20
+ - If you only want to train on selected tokens, or mask some tokens (as in MLM), set the labels of the tokens to be ignored to -100 so that the cross-entropy loss skips them (loss is only computed where the label is not -100)
21
+ """
22
+ class GPT2ForCausalLM(GPT2PreTrainedModel):
23
+ _keys_to_ignore_on_load_missing = [r"attn.masked_bias", r"attn.bias", r"lm_head.weight"]
24
+
25
+ def __init__(self, config):
26
+ super().__init__(config)
27
+ self.transformer = GPT2Model(config)
28
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
29
+
30
+ # Model parallel
31
+ self.model_parallel = False
32
+ self.device_map = None
33
+
34
+ # Initialize weights and apply final processing
35
+ self.post_init()
36
+
37
+ def get_output_embeddings(self):
38
+ return self.lm_head
39
+
40
+ def set_output_embeddings(self, new_embeddings):
41
+ self.lm_head = new_embeddings
42
+
43
+ def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
44
+ token_type_ids = kwargs.get("token_type_ids", None)
45
+ # only last token for inputs_ids if past is defined in kwargs
46
+ if past:
47
+ input_ids = input_ids[:, -1].unsqueeze(-1)
48
+ if token_type_ids is not None:
49
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
50
+
51
+ attention_mask = kwargs.get("attention_mask", None)
52
+ position_ids = kwargs.get("position_ids", None)
53
+
54
+ if attention_mask is not None and position_ids is None:
55
+ # create position_ids on the fly for batch generation
56
+ position_ids = attention_mask.long().cumsum(-1) - 1
57
+ position_ids.masked_fill_(attention_mask == 0, 1)
58
+ if past:
59
+ position_ids = position_ids[:, -1].unsqueeze(-1)
60
+ else:
61
+ position_ids = None
62
+ return {
63
+ "input_ids": input_ids,
64
+ "past_key_values": past,
65
+ "use_cache": kwargs.get("use_cache"),
66
+ "position_ids": position_ids,
67
+ "attention_mask": attention_mask,
68
+ "token_type_ids": token_type_ids,
69
+ }
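# Sketch (not part of the original file): what the attention-mask cumsum above
# produces for a left-padded batch. Padding positions (attention_mask == 0) are
# filled with 1, and real tokens are numbered from 0 upward:
#   attention_mask = torch.tensor([[0, 0, 1, 1, 1]])
#   position_ids = attention_mask.long().cumsum(-1) - 1   # [[-1, -1, 0, 1, 2]]
#   position_ids.masked_fill_(attention_mask == 0, 1)     # [[ 1,  1, 0, 1, 2]]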
70
+
71
+ def forward(
72
+ self,
73
+ input_ids: Optional[torch.LongTensor] = None,
74
+ past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
75
+ attention_mask: Optional[torch.FloatTensor] = None,
76
+ token_type_ids: Optional[torch.LongTensor] = None,
77
+ position_ids: Optional[torch.LongTensor] = None,
78
+ head_mask: Optional[torch.FloatTensor] = None,
79
+ inputs_embeds: Optional[torch.FloatTensor] = None,
80
+ encoder_hidden_states: Optional[torch.Tensor] = None,
81
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
82
+ labels: Optional[torch.LongTensor] = None,
83
+ use_cache: Optional[bool] = None,
84
+ output_attentions: Optional[bool] = None,
85
+ output_hidden_states: Optional[bool] = None,
86
+ return_dict: Optional[bool] = None,
87
+ ) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
88
+ r"""
89
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
90
+ Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
91
+ `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
92
+ are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
93
+ """
94
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
95
+
96
+ transformer_outputs = self.transformer(
97
+ input_ids,
98
+ past_key_values=past_key_values,
99
+ attention_mask=attention_mask,
100
+ token_type_ids=token_type_ids,
101
+ position_ids=position_ids,
102
+ head_mask=head_mask,
103
+ inputs_embeds=inputs_embeds,
104
+ encoder_hidden_states=encoder_hidden_states,
105
+ encoder_attention_mask=encoder_attention_mask,
106
+ use_cache=use_cache,
107
+ output_attentions=output_attentions,
108
+ output_hidden_states=output_hidden_states,
109
+ return_dict=return_dict,
110
+ )
111
+ hidden_states = transformer_outputs[0]
112
+
113
+ # Set device for model parallelism
114
+ if self.model_parallel:
115
+ torch.cuda.set_device(self.transformer.first_device)
116
+ hidden_states = hidden_states.to(self.lm_head.weight.device)
117
+
118
+ lm_logits = self.lm_head(hidden_states)
119
+
120
+ loss = None
121
+ if labels is not None:
122
+ # Shift so that tokens < n predict n
123
+ shift_logits = lm_logits[..., :-1, :].contiguous()
124
+ shift_labels = labels[..., 1:].contiguous()
125
+ # print("shift_labels=", shift_labels)
126
+ # Flatten the tokens
127
+ loss_fct = CrossEntropyLoss()
128
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
129
+
130
+ if not return_dict:
131
+ output = (lm_logits,) + transformer_outputs[1:]
132
+ return ((loss,) + output) if loss is not None else output
133
+
134
+ return CausalLMOutputWithCrossAttentions(
135
+ loss=loss,
136
+ logits=lm_logits,
137
+ past_key_values=transformer_outputs.past_key_values,
138
+ hidden_states=transformer_outputs.hidden_states,
139
+ attentions=transformer_outputs.attentions,
140
+ cross_attentions=transformer_outputs.cross_attentions,
141
+ )
142
+
143
+ @staticmethod
144
+ def _reorder_cache(past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
145
+ """
146
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
147
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
148
+ beam_idx at every generation step.
149
+ """
150
+ return tuple(
151
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past)
152
+ for layer_past in past
153
+ )
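# Sketch (not part of the original file): during beam search every layer's
# cached key/value tensor has the beam dimension first, so
#   layer_past[0].index_select(0, beam_idx)
# with e.g. beam_idx = torch.tensor([1, 1, 0]) keeps beam 1's cache for the
# first two surviving hypotheses and beam 0's cache for the third.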
154
+
155
+
156
+
157
+ # class GPT2ForCausalLM(GPT2LMHeadModel):
158
+
159
+ # def __init__(self, config):
160
+ # super().__init__(config)
161
+ # self.transformer = GPT2Model(config)
162
+ # self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
163
+
164
+ # # Model parallel
165
+ # self.model_parallel = False
166
+ # self.device_map = None
167
+
168
+ # # Initialize weights and apply final processing
169
+ # self.post_init()
170
+
171
+ # def forward(
172
+ # self,
173
+ # input_ids: Optional[torch.LongTensor] = None, # input token id
174
+ # past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
175
+ # attention_mask: Optional[torch.FloatTensor] = None,
176
+ # token_type_ids: Optional[torch.LongTensor] = None,
177
+ # labels: Optional[torch.LongTensor] = None,
178
+ # label_masks: Optional[torch.LongTensor] = None, # mask=1 means it should be calculated loss
179
+ # output_attentions=None,
180
+ # output_hidden_states=None,
181
+ # return_dict=None,
182
+ # ):
183
+ # transformer_outputs = self.transformer(
184
+ # input_ids,
185
+ # past_key_values=past_key_values,
186
+ # attention_mask=attention_mask,
187
+ # token_type_ids=token_type_ids,
188
+ # output_attentions=output_attentions,
189
+ # output_hidden_states=output_hidden_states,
190
+ # return_dict=return_dict,
191
+ # )
192
+ # hidden_states = transformer_outputs[0]
193
+ # lm_logits = self.lm_head(hidden_states)
194
+
195
+ # # print("len(input_ids)=", len(input_ids[0]))
196
+ # # print("input_ids[-1]=", input_ids[0][-1])
197
+
198
+ # loss = None
199
+ # if labels is not None:
200
+ # shift_logits = lm_logits[..., :-1, :].contiguous()
201
+ # # print("shift_logits.shape=", shift_logits.shape)
202
+ # if labels is None:
203
+ # labels = input_ids
204
+ # shift_labels = labels[..., 1:].contiguous()
205
+ # # print("shift_labels=", shift_labels)
206
+ # # print("shift_labels.shape=", shift_labels.shape)
207
+ # # Flatten the tokens
208
+ # loss_fct = CrossEntropyLoss(reduction="none")
209
+ # loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) # [batch_size, lngth]
210
+ # label_masks = label_masks[..., 1:].contiguous()
211
+ # # print("loss.shape=", loss.shape)
212
+ # # print("shift_logits.shape=", shift_logits.shape)
213
+ # # print("label_masks.shape=", label_masks.shape)
214
+ # loss = loss.view(shift_logits.size(0), shift_logits.size(1)) * label_masks # [batch_size, length]
215
+ # loss = torch.sum(loss, axis=1) / torch.sum(label_masks, axis=1) # [batch_size]
216
+ # # print("loss=", loss)
217
+ # if not return_dict:
218
+ # output = (lm_logits,) + transformer_outputs[1:]
219
+ # return ((loss,) + output) if loss is not None else output
220
+
221
+ # return CausalLMOutputWithCrossAttentions(
222
+ # loss=loss,
223
+ # logits=lm_logits,
224
+ # past_key_values=transformer_outputs.past_key_values,
225
+ # hidden_states=transformer_outputs.hidden_states,
226
+ # attentions=transformer_outputs.attentions,
227
+ # cross_attentions=transformer_outputs.cross_attentions,
228
+ # )
229
+
230
+
231
+ if __name__ == "__main__":
232
+ from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
233
+ # model_path = "/Users/wangjianing/Desktop/开源代码与数据模型/模型/gpt2"
234
+ model_path = "/wjn/pre-trained-lm/gpt2"
235
+ tokenizer = GPT2Tokenizer.from_pretrained(model_path)
236
+ tokenizer.pad_token_id = tokenizer.eos_token_id
237
+ # print("tokenizer.eos_token_id=", tokenizer.eos_token_id) # 50256
238
+ model = GPT2LMHeadModel.from_pretrained(model_path)
239
+ input_text = "My friend Jack invites me to play computer games with him, but my girl friend doesn't agree. I think"
240
+ inputs = tokenizer(input_text, add_special_tokens=True, return_tensors="pt")
241
+ inputs["labels"] = inputs["input_ids"]
242
+ print("inputs=", inputs)
243
+ """
244
+ inputs= {"input_ids": tensor([[ 3666, 1545, 3619, 27671, 502, 284, 711, 3644, 1830, 351,
245
+ 683, 11, 475, 616, 2576, 1545, 1595, 470, 4236, 13,
246
+ 314, 892, 220]]), "attention_mask": tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]), "labels": tensor([[ 3666, 1545, 3619, 27671, 502, 284, 711, 3644, 1830, 351,
247
+ 683, 11, 475, 616, 2576, 1545, 1595, 470, 4236, 13,
248
+ 314, 892, 220]])}
249
+
250
+ """
251
+ outputs = model(**inputs)
252
+ print("loss=", outputs[0])
253
+ """
254
+ loss= tensor(3.9444, grad_fn=<NllLossBackward0>)
255
+ """
256
+ output_sequences = model.generate(
257
+ **inputs,
258
+ emb_match=None,
259
+ control_code=None,
260
+ past_key_values=None,
261
+ max_length=len(inputs["input_ids"][0]) + 10,
262
+ min_length=5,
263
+ temperature=1.0,
264
+ top_k=1,
265
+ top_p=0.5, #top_p=0.5,
266
+ repetition_penalty=1.0, # repetition penalty, used to control the diversity of the generated text
267
+ do_sample=False,
268
+ num_beams=5,
269
+ # bad_words_ids=[[628], [198]] if True else None,
270
+ num_return_sequences=3,
271
+ )
272
+ print("output_sequences=", output_sequences)
273
+ # print("output_sequences=", output_sequences)
274
+ results = tokenizer.decode(output_sequences[0])
275
+ print("results=", results)
276
+ """
277
+ results= My friend Jack invites me to play computer games with him, but my girl friend doesn't agree. I think  it's a good idea to play computer games
278
+ """
models/language_modeling/kpplm.py ADDED
@@ -0,0 +1,752 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2022/3/15 21:26
3
+ # @Author : ruihan.wjn
4
+ # @File : pk-plm.py
5
+
6
+ """
7
+ This code is implemented for the paper "Knowledge Prompting in Pre-trained Language Models for Natural Language Understanding".
8
+ """
9
+
10
+ from time import time
11
+ import torch
12
+ from torch import nn
13
+ import torch.nn.functional as F
14
+ from torch.nn import CrossEntropyLoss
15
+ from collections import OrderedDict
16
+ from transformers.models.bert import BertPreTrainedModel, BertModel
17
+ from transformers.models.roberta import RobertaModel, RobertaPreTrainedModel, RobertaTokenizer, RobertaForMaskedLM
18
+ from transformers.models.deberta import DebertaModel, DebertaPreTrainedModel, DebertaTokenizer, DebertaForMaskedLM
19
+ from transformers.models.bert.modeling_bert import BertOnlyMLMHead, BertPreTrainingHeads
20
+ from transformers.models.roberta.modeling_roberta import RobertaModel, RobertaLMHead
21
+ from transformers.models.deberta.modeling_deberta import DebertaModel, DebertaLMPredictionHead
22
+
23
+ """
24
+ kg enhanced corpus structure example:
25
+ {
26
+ "token_ids": [20, 46098, 3277, 680, 10, 4066, 278, 9, 11129, 4063, 877, 579, 8, 8750, 14720, 8, 22498, 548,
27
+ 19231, 46098, 3277, 6, 25, 157, 25, 130, 3753, 46098, 3277, 4, 3684, 19809, 10960, 9, 5, 30731, 2788, 914, 5,
28
+ 1675, 8151, 35], "entity_pos": [[8, 11], [13, 15], [26, 27]],
29
+ "entity_qid": ["Q17582", "Q231978", "Q427013"],
30
+ "relation_pos": null,
31
+ "relation_pid": null
32
+ }
33
+ """
34
+
35
+
36
+ from enum import Enum
37
+ class SiameseDistanceMetric(Enum):
38
+ """
39
+ The metric for the contrastive loss
40
+ """
41
+ EUCLIDEAN = lambda x, y: F.pairwise_distance(x, y, p=2)
42
+ MANHATTAN = lambda x, y: F.pairwise_distance(x, y, p=1)
43
+ COSINE_DISTANCE = lambda x, y: 1-F.cosine_similarity(x, y)
44
+
45
+
46
+ class ContrastiveLoss(nn.Module):
47
+ """
48
+ Contrastive loss. Expects as input two texts and a label of either 0 or 1. If the label == 1, then the distance between the
49
+ two embeddings is reduced. If the label == 0, then the distance between the embeddings is increased.
50
+ Further information: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
51
+ :param model: SentenceTransformer model
52
+ :param distance_metric: Function that returns a distance between two embeddings. The class SiameseDistanceMetric contains pre-defined metrics that can be used
53
+ :param margin: Negative samples (label == 0) should have a distance of at least the margin value.
54
+ :param size_average: Average by the size of the mini-batch.
55
+ Example::
56
+ from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses
57
+ from sentence_transformers.readers import InputExample
58
+ model = SentenceTransformer("distilbert-base-nli-mean-tokens")
59
+ train_examples = [InputExample(texts=["This is a positive pair", "Where the distance will be minimized"], label=1),
60
+ InputExample(texts=["This is a negative pair", "Their distance will be increased"], label=0)]
61
+ train_dataset = SentencesDataset(train_examples, model)
62
+ train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)
63
+ train_loss = losses.ContrastiveLoss(model=model)
64
+ """
65
+
66
+ def __init__(self, distance_metric=SiameseDistanceMetric.COSINE_DISTANCE, margin: float = 0.5, size_average:bool = True):
67
+ super(ContrastiveLoss, self).__init__()
68
+ self.distance_metric = distance_metric
69
+ self.margin = margin
70
+ self.size_average = size_average
71
+
72
+ def forward(self, sent_embs1, sent_embs2, labels: torch.Tensor):
73
+ rep_anchor, rep_other = sent_embs1, sent_embs2
74
+ distances = self.distance_metric(rep_anchor, rep_other)
75
+ losses = 0.5 * (labels.float() * distances.pow(2) + (1 - labels).float() * F.relu(self.margin - distances).pow(2))
76
+ return losses.mean() if self.size_average else losses.sum()
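# --- Illustrative sketch (not part of the original file) ----------------------
# Minimal usage of ContrastiveLoss: label 1 pulls a pair of embeddings together,
# label 0 pushes them at least `margin` apart (cosine distance by default).
# The embeddings below are random and purely illustrative.
def _contrastive_loss_example() -> torch.Tensor:
    loss_fn = ContrastiveLoss(margin=0.5)
    emb_a, emb_b = torch.randn(4, 768), torch.randn(4, 768)
    labels = torch.tensor([1, 0, 1, 0])
    return loss_fn(emb_a, emb_b, labels)
# ------------------------------------------------------------------------------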
77
+
78
+
79
+
80
+ class NSPHead(nn.Module):
81
+ def __init__(self, config):
82
+ super().__init__()
83
+ self.seq_relationship = nn.Linear(config.hidden_size, 2)
84
+
85
+ def forward(self, pooled_output):
86
+ seq_relationship_score = self.seq_relationship(pooled_output)
87
+ return seq_relationship_score
88
+
89
+
90
+
91
+ class RoBertaKPPLMForProcessedWikiKGPLM(RobertaForMaskedLM):
92
+
93
+ def __init__(self, config):
94
+ super().__init__(config)
95
+ self.num_labels = config.num_labels
96
+ self.config = config
97
+ # self.roberta = RobertaModel(config)
98
+ try:
99
+ classifier_dropout = (
100
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
101
+ )
102
+ except:
103
+ classifier_dropout = (config.hidden_dropout_prob)
104
+ self.dropout = nn.Dropout(classifier_dropout)
105
+ # self.cls = BertOnlyMLMHead(config)
106
+ # self.lm_head = RobertaLMHead(config) # Masked Language Modeling head
107
+ self.detector = NSPHead(config) # Knowledge Noise Detection head
108
+ self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
109
+ self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
110
+ # self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, config.num_ner_labels) for _ in range(config.entity_type_num)])
111
+
112
+ self.contrastive_loss_fn = ContrastiveLoss()
113
+ self.post_init()
114
+
115
+ def forward(
116
+ self,
117
+ input_ids=None,
118
+ attention_mask=None,
119
+ token_type_ids=None,
120
+ position_ids=None,
121
+ head_mask=None,
122
+ inputs_embeds=None,
123
+ encoder_hidden_states=None,
124
+ encoder_attention_mask=None,
125
+ labels=None,
126
+ # entity_label=None,
127
+ entity_candidate=None,
128
+ # relation_label=None,
129
+ relation_candidate=None,
130
+ noise_detect_label=None,
131
+ task_id=None,
132
+ mask_id=None,
133
+ output_attentions=None,
134
+ output_hidden_states=None,
135
+ return_dict=None,
136
+ ):
137
+ # start_time = time()
138
+ mlm_labels = labels
139
+
140
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
141
+ # print("attention_mask.shape=", attention_mask.shape)
142
+ # print("input_ids[0]=", input_ids[0])
143
+ # print("token_type_ids[0]=", token_type_ids[0])
144
+ # attention_mask = None
145
+
146
+ outputs = self.roberta(
147
+ input_ids,
148
+ attention_mask=attention_mask,
149
+ token_type_ids=token_type_ids,
150
+ position_ids=position_ids,
151
+ head_mask=head_mask,
152
+ inputs_embeds=inputs_embeds,
153
+ encoder_hidden_states=encoder_hidden_states,
154
+ encoder_attention_mask=encoder_attention_mask,
155
+ output_attentions=output_attentions,
156
+ output_hidden_states=output_hidden_states,
157
+ return_dict=return_dict,
158
+ )
159
+
160
+ sequence_output = outputs[0]
161
+ prediction_scores = self.lm_head(sequence_output) # mlm head
162
+ # noise_detect_scores = self.detector(pooled_output) # knowledge noise detector use pool output
163
+ noise_detect_scores = self.detector(sequence_output[:, 0, :]) # knowledge noise detector use cls embedding
164
+
165
+ # ner
166
+ # sequence_output = self.dropout(sequence_output)
167
+ # ner_logits = torch.stack([classifier(sequence_output) for classifier in self.classifiers]).movedim(1, 0)
168
+
169
+ # mlm
170
+ masked_lm_loss, noise_detect_loss, entity_loss, total_loss = None, None, None, None
171
+ total_loss = list()
172
+ if mlm_labels is not None:
173
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
174
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1))
175
+ total_loss.append(masked_lm_loss)
176
+
177
+ # if noise_detect_label is not None:
178
+ # noise_detect_scores = noise_detect_scores[task_id == 1]
179
+ # noise_detect_label = noise_detect_label[task_id == 1]
180
+ #
181
+ # if len(noise_detect_label) > 0:
182
+ # loss_fct = CrossEntropyLoss()
183
+ # noise_detect_loss = loss_fct(noise_detect_scores.view(-1, 2), noise_detect_label.view(-1))
184
+ # total_loss.append(noise_detect_loss)
185
+
186
+ entity_candidate = entity_candidate[task_id == 2]
187
+ if len(entity_candidate) > 0:
188
+ batch_size = entity_candidate.shape[0]
189
+ candidate_num = entity_candidate.shape[1]
190
+ # print("negative_num=", negative_num)
191
+ # Get the embedding of the masked entity (mean over the [MASK] positions)
192
+ batch_entity_query_embedding = list()
193
+ for ei, input_id in enumerate(input_ids[task_id == 2]):
194
+ batch_entity_query_embedding.append(
195
+ torch.mean(sequence_output[task_id == 2][ei][input_id == mask_id[task_id == 2][ei]], 0)) # [hidden_dim]
196
+ batch_entity_query_embedding = torch.stack(batch_entity_query_embedding) # [bz, dim]
197
+ # print("batch_entity_query_embedding.shape=", batch_entity_query_embedding.shape)
198
+ batch_entity_query_embedding = self.entity_mlp(batch_entity_query_embedding) # [bz, dim]
199
+ batch_entity_query_embedding = batch_entity_query_embedding.unsqueeze(1).repeat((1, candidate_num, 1)) # [bz, 11, dim]
200
+ batch_entity_query_embedding = batch_entity_query_embedding.view(-1, batch_entity_query_embedding.shape[-1]) # [bz * 11, dim]
201
+ # print("batch_entity_query_embedding.shape=", batch_entity_query_embedding.shape)
202
+
203
+ # Get the encoder representations of the positive and negative candidates
204
+ # entity_candidiate: [bz, 11, len]
205
+ entity_candidate = entity_candidate.view(-1, entity_candidate.shape[-1]) # [bz * 11, len]
206
+ entity_candidate_embedding = self.roberta.embeddings(input_ids=entity_candidate) # [bz * 11, len, dim]
207
+ entity_candidate_embedding = self.entity_mlp(torch.mean(entity_candidate_embedding, 1)) # [bz * 11, dim]
208
+
209
+ contrastive_entity_label = torch.Tensor([0] * (candidate_num - 1) + [1]).float().cuda()
210
+ contrastive_entity_label = contrastive_entity_label.unsqueeze(0).repeat([batch_size, 1]).view(-1) # [bz * 11]
211
+
212
+ entity_loss = self.contrastive_loss_fn(
213
+ batch_entity_query_embedding, entity_candidate_embedding, contrastive_entity_label
214
+ )
215
+ total_loss.append(entity_loss)
216
+
217
+ relation_candidate = relation_candidate[task_id == 3]
218
+ if len(relation_candidate) > 0:
219
+ batch_size = relation_candidate.shape[0]
220
+ candidate_num = relation_candidate.shape[1]
221
+ # print("negative_num=", negative_num)
222
+ # Get the embedding of the masked relation (mean over the [MASK] positions)
223
+ batch_relation_query_embedding = list()
224
+ for ei, input_id in enumerate(input_ids[task_id == 3]):
225
+ batch_relation_query_embedding.append(
226
+ torch.mean(sequence_output[task_id == 3][ei][input_id == mask_id[task_id == 3][ei]], 0)) # [hidden_dim]
227
+ batch_relation_query_embedding = torch.stack(batch_relation_query_embedding) # [bz, dim]
228
+ # print("batch_relation_query_embedding.shape=", batch_relation_query_embedding.shape)
229
+ batch_relation_query_embedding = self.relation_mlp(batch_relation_query_embedding) # [bz, dim]
230
+ batch_relation_query_embedding = batch_relation_query_embedding.unsqueeze(1).repeat(
231
+ (1, candidate_num, 1)) # [bz, 11, dim]
232
+ batch_relation_query_embedding = batch_relation_query_embedding.view(-1, batch_relation_query_embedding.shape[-1]) # [bz * 11, dim]
233
+ # print("batch_relation_query_embedding.shape=", batch_relation_query_embedding.shape)
234
+
235
+ # Get the encoder representations of the positive and negative candidates
236
+ # entity_candidiate: [bz, 11, len]
237
+ relation_candidate = relation_candidate.view(-1, relation_candidate.shape[-1]) # [bz * 11, len]
238
+ relation_candidate_embedding = self.roberta.embeddings(input_ids=relation_candidate) # [bz * 11, len, dim]
239
+ relation_candidate_embedding = self.relation_mlp(torch.mean(relation_candidate_embedding, 1)) # [bz * 11, dim]
240
+
241
+ contrastive_relation_label = torch.Tensor([0] * (candidate_num - 1) + [1]).float().cuda()
242
+ contrastive_relation_label = contrastive_relation_label.unsqueeze(0).repeat([batch_size, 1]).view(-1) # [bz * 11]
243
+
244
+ relation_loss = self.contrastive_loss_fn(
245
+ batch_relation_query_embedding, relation_candidate_embedding, contrastive_relation_label
246
+ )
247
+ total_loss.append(relation_loss)
248
+
249
+ total_loss = torch.sum(torch.stack(total_loss), -1)
250
+
251
+ # end_time = time()
252
+ # print("neural_mode_time: {}".format(end_time - start_time))
253
+ # print("masked_lm_loss.unsqueeze(0)=", masked_lm_loss.unsqueeze(0))
254
+ # print("masked_lm_loss.unsqueeze(0).shape=", masked_lm_loss.unsqueeze(0).shape)
255
+ # print("logits=", prediction_scores.argmax(2))
256
+ # print("logits.shape=", prediction_scores.argmax(2).shape)
257
+
258
+
259
+ return OrderedDict([
260
+ ("loss", total_loss),
261
+ ("mlm_loss", masked_lm_loss.unsqueeze(0)),
262
+ # ("noise_detect_loss", noise_detect_loss.unsqueeze(0) if noise_detect_loss is not None else None),
263
+ # ("entity_loss", entity_loss.unsqueeze(0) if entity_loss is not None else None),
264
+ # ("relation_loss", relation_loss.unsqueeze(0) if relation_loss is not None else None),
265
+ ("logits", prediction_scores.argmax(2)),
266
+ # ("noise_detect_logits", noise_detect_scores.argmax(-1) if noise_detect_scores is not None and len(noise_detect_scores) > 0 else None),
267
+ ])
268
+
269
+
270
+ class DeBertaKPPLMForProcessedWikiKGPLM(DebertaForMaskedLM):
271
+
272
+ def __init__(self, config):
273
+ super().__init__(config)
274
+ self.num_labels = config.num_labels
275
+ self.config = config
276
+ # self.roberta = RobertaModel(config)
277
+ try:
278
+ classifier_dropout = (
279
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
280
+ )
281
+ except:
282
+ classifier_dropout = (config.hidden_dropout_prob)
283
+ self.dropout = nn.Dropout(classifier_dropout)
284
+ # self.cls = BertOnlyMLMHead(config)
285
+ # self.lm_head = RobertaLMHead(config) # Masked Language Modeling head
286
+ self.detector = NSPHead(config) # Knowledge Noise Detection head
287
+ self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
288
+ self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
289
+ # self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, config.num_ner_labels) for _ in range(config.entity_type_num)])
290
+
291
+ self.contrastive_loss_fn = ContrastiveLoss()
292
+ self.post_init()
293
+
294
+ def forward(
295
+ self,
296
+ input_ids=None,
297
+ attention_mask=None,
298
+ token_type_ids=None,
299
+ position_ids=None,
300
+ head_mask=None,
301
+ inputs_embeds=None,
302
+ encoder_hidden_states=None,
303
+ encoder_attention_mask=None,
304
+ labels=None,
305
+ # entity_label=None,
306
+ entity_candidate=None,
307
+ # relation_label=None,
308
+ relation_candidate=None,
309
+ noise_detect_label=None,
310
+ task_id=None,
311
+ mask_id=None,
312
+ output_attentions=None,
313
+ output_hidden_states=None,
314
+ return_dict=None,
315
+ ):
316
+ # start_time = time()
317
+ mlm_labels = labels
318
+
319
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
320
+ # print("attention_mask.shape=", attention_mask.shape)
321
+ # print("input_ids[0]=", input_ids[0])
322
+ # print("token_type_ids[0]=", token_type_ids[0])
323
+ # attention_mask = None
324
+
325
+ outputs = self.deberta(
326
+ input_ids,
327
+ # attention_mask=attention_mask,
328
+ attention_mask=None,
329
+ token_type_ids=token_type_ids,
330
+ position_ids=position_ids,
331
+ inputs_embeds=inputs_embeds,
332
+ output_attentions=output_attentions,
333
+ output_hidden_states=output_hidden_states,
334
+ return_dict=return_dict,
335
+ )
336
+
337
+ sequence_output = outputs[0]
338
+ prediction_scores = self.cls(sequence_output) # mlm head
339
+ # noise_detect_scores = self.detector(pooled_output) # knowledge noise detector use pool output
340
+ noise_detect_scores = self.detector(sequence_output[:, 0, :]) # knowledge noise detector use cls embedding
341
+
342
+ # ner
343
+ # sequence_output = self.dropout(sequence_output)
344
+ # ner_logits = torch.stack([classifier(sequence_output) for classifier in self.classifiers]).movedim(1, 0)
345
+
346
+ # mlm
347
+ masked_lm_loss, noise_detect_loss, entity_loss, total_loss = None, None, None, None
348
+ total_loss = list()
349
+ if mlm_labels is not None:
350
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
351
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1))
352
+ total_loss.append(masked_lm_loss)
353
+
354
+ # if noise_detect_label is not None:
355
+ # noise_detect_scores = noise_detect_scores[task_id == 1]
356
+ # noise_detect_label = noise_detect_label[task_id == 1]
357
+ #
358
+ # if len(noise_detect_label) > 0:
359
+ # loss_fct = CrossEntropyLoss()
360
+ # noise_detect_loss = loss_fct(noise_detect_scores.view(-1, 2), noise_detect_label.view(-1))
361
+ # total_loss.append(noise_detect_loss)
362
+
363
+ entity_candidate = entity_candidate[task_id == 2]
364
+ if len(entity_candidate) > 0:
365
+ batch_size = entity_candidate.shape[0]
366
+ candidate_num = entity_candidate.shape[1]
367
+ # print("negative_num=", negative_num)
368
+ # Get the embedding of the masked entity (mean over the [MASK] positions)
369
+ batch_entity_query_embedding = list()
370
+ for ei, input_id in enumerate(input_ids[task_id == 2]):
371
+ batch_entity_query_embedding.append(
372
+ torch.mean(sequence_output[task_id == 2][ei][input_id == mask_id[task_id == 2][ei]], 0)) # [hidden_dim]
373
+ batch_entity_query_embedding = torch.stack(batch_entity_query_embedding) # [bz, dim]
374
+ # print("batch_entity_query_embedding.shape=", batch_entity_query_embedding.shape)
375
+ batch_entity_query_embedding = self.entity_mlp(batch_entity_query_embedding) # [bz, dim]
376
+ batch_entity_query_embedding = batch_entity_query_embedding.unsqueeze(1).repeat((1, candidate_num, 1)) # [bz, 11, dim]
377
+ batch_entity_query_embedding = batch_entity_query_embedding.view(-1, batch_entity_query_embedding.shape[-1]) # [bz * 11, dim]
378
+ # print("batch_entity_query_embedding.shape=", batch_entity_query_embedding.shape)
379
+
380
+ # Get the encoder representations of the positive and negative candidates
381
+ # entity_candidiate: [bz, 11, len]
382
+ entity_candidate = entity_candidate.view(-1, entity_candidate.shape[-1]) # [bz * 11, len]
383
+ entity_candidate_embedding = self.deberta.embeddings(input_ids=entity_candidate) # [bz * 11, len, dim]
384
+ entity_candidate_embedding = self.entity_mlp(torch.mean(entity_candidate_embedding, 1)) # [bz * 11, dim]
385
+
386
+ contrastive_entity_label = torch.Tensor([0] * (candidate_num - 1) + [1]).float().cuda()
387
+ contrastive_entity_label = contrastive_entity_label.unsqueeze(0).repeat([batch_size, 1]).view(-1) # [bz * 11]
388
+
389
+ entity_loss = self.contrastive_loss_fn(
390
+ batch_entity_query_embedding, entity_candidate_embedding, contrastive_entity_label
391
+ )
392
+ total_loss.append(entity_loss)
393
+
394
+ relation_candidate = relation_candidate[task_id == 3]
395
+ if len(relation_candidate) > 0:
396
+ batch_size = relation_candidate.shape[0]
397
+ candidate_num = relation_candidate.shape[1]
398
+ # print("negative_num=", negative_num)
399
+ # Get the embedding of the masked relation (mean over the [MASK] positions)
400
+ batch_relation_query_embedding = list()
401
+ for ei, input_id in enumerate(input_ids[task_id == 3]):
402
+ batch_relation_query_embedding.append(
403
+ torch.mean(sequence_output[task_id == 3][ei][input_id == mask_id[task_id == 3][ei]], 0)) # [hidden_dim]
404
+ batch_relation_query_embedding = torch.stack(batch_relation_query_embedding) # [bz, dim]
405
+ # print("batch_relation_query_embedding.shape=", batch_relation_query_embedding.shape)
406
+ batch_relation_query_embedding = self.relation_mlp(batch_relation_query_embedding) # [bz, dim]
407
+ batch_relation_query_embedding = batch_relation_query_embedding.unsqueeze(1).repeat(
408
+ (1, candidate_num, 1)) # [bz, 11, dim]
409
+ batch_relation_query_embedding = batch_relation_query_embedding.view(-1, batch_relation_query_embedding.shape[-1]) # [bz * 11, dim]
410
+ # print("batch_relation_query_embedding.shape=", batch_relation_query_embedding.shape)
411
+
412
+ # Get the encoder representations of the positive and negative candidates
413
+ # entity_candidiate: [bz, 11, len]
414
+ relation_candidate = relation_candidate.view(-1, relation_candidate.shape[-1]) # [bz * 11, len]
415
+ relation_candidate_embedding = self.deberta.embeddings(input_ids=relation_candidate) # [bz * 11, len, dim]
416
+ relation_candidate_embedding = self.relation_mlp(torch.mean(relation_candidate_embedding, 1)) # [bz * 11, dim]
417
+
418
+ contrastive_relation_label = torch.Tensor([0] * (candidate_num - 1) + [1]).float().cuda()
419
+ contrastive_relation_label = contrastive_relation_label.unsqueeze(0).repeat([batch_size, 1]).view(-1) # [bz * 11]
420
+
421
+ relation_loss = self.contrastive_loss_fn(
422
+ batch_relation_query_embedding, relation_candidate_embedding, contrastive_relation_label
423
+ )
424
+ total_loss.append(relation_loss)
425
+
426
+ total_loss = torch.sum(torch.stack(total_loss), -1)
427
+
428
+ # end_time = time()
429
+ # print("neural_mode_time: {}".format(end_time - start_time))
430
+ # print("masked_lm_loss.unsqueeze(0)=", masked_lm_loss.unsqueeze(0))
431
+ # print("masked_lm_loss.unsqueeze(0).shape=", masked_lm_loss.unsqueeze(0).shape)
432
+ # print("logits=", prediction_scores.argmax(2))
433
+ # print("logits.shape=", prediction_scores.argmax(2).shape)
434
+
435
+
436
+ return OrderedDict([
437
+ ("loss", total_loss),
438
+ ("mlm_loss", masked_lm_loss.unsqueeze(0)),
439
+ # ("noise_detect_loss", noise_detect_loss.unsqueeze(0) if noise_detect_loss is not None else None),
440
+ # ("entity_loss", entity_loss.unsqueeze(0) if entity_loss is not None else None),
441
+ # ("relation_loss", relation_loss.unsqueeze(0) if relation_loss is not None else None),
442
+ ("logits", prediction_scores.argmax(2)),
443
+ # ("noise_detect_logits", noise_detect_scores.argmax(-1) if noise_detect_scores is not None and len(noise_detect_scores) > 0 else None),
444
+ ])
445
+
446
+
447
+ class RoBertaForWikiKGPLM(RobertaPreTrainedModel):
448
+
449
+ def __init__(self, config):
450
+ super().__init__(config)
451
+ self.num_labels = config.num_labels
452
+ self.config = config
453
+ self.roberta = RobertaModel(config)
454
+ classifier_dropout = (
455
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
456
+ )
457
+ self.dropout = nn.Dropout(classifier_dropout)
458
+ # self.cls = BertOnlyMLMHead(config)
459
+ self.lm_head = RobertaLMHead(config) # Masked Language Modeling head
460
+ self.detector = NSPHead(config) # Knowledge Noise Detection head
461
+ self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
462
+ self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
463
+ # self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, config.num_ner_labels) for _ in range(config.entity_type_num)])
464
+
465
+ self.contrastive_loss_fn = ContrastiveLoss()
466
+ self.post_init()
467
+
468
+ self.tokenizer = RobertaTokenizer.from_pretrained(config.name_or_path)
469
+
470
+ def forward(
471
+ self,
472
+ input_ids=None,
473
+ attention_mask=None,
474
+ token_type_ids=None,
475
+ position_ids=None,
476
+ head_mask=None,
477
+ inputs_embeds=None,
478
+ encoder_hidden_states=None,
479
+ encoder_attention_mask=None,
480
+ mlm_labels=None,
481
+ entity_label=None,
482
+ entity_negative=None,
483
+ relation_label=None,
484
+ relation_negative=None,
485
+ noise_detect_label=None,
486
+ task_id=None,
487
+ mask_id=None,
488
+ output_attentions=None,
489
+ output_hidden_states=None,
490
+ return_dict=None,
491
+ ):
492
+ # start_time = time()
493
+
494
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
495
+ # print("attention_mask.shape=", attention_mask.shape)
496
+ # print("input_ids[0]=", input_ids[0])
497
+ # print("token_type_ids[0]=", token_type_ids[0])
498
+ # attention_mask = None
499
+
500
+
501
+ outputs = self.roberta(
502
+ input_ids,
503
+ attention_mask=attention_mask,
504
+ token_type_ids=token_type_ids,
505
+ position_ids=position_ids,
506
+ head_mask=head_mask,
507
+ inputs_embeds=inputs_embeds,
508
+ encoder_hidden_states=encoder_hidden_states,
509
+ encoder_attention_mask=encoder_attention_mask,
510
+ output_attentions=output_attentions,
511
+ output_hidden_states=output_hidden_states,
512
+ return_dict=return_dict,
513
+ )
514
+
515
+ sequence_output, pooled_output = outputs[:2]
516
+ prediction_scores = self.lm_head(sequence_output) # mlm head
517
+ noise_detect_scores = self.detector(pooled_output) # knowledge noise detector
518
+
519
+
520
+ # ner
521
+ # sequence_output = self.dropout(sequence_output)
522
+ # ner_logits = torch.stack([classifier(sequence_output) for classifier in self.classifiers]).movedim(1, 0)
523
+
524
+ # mlm
525
+ masked_lm_loss, noise_detect_loss, entity_loss, total_loss = None, None, None, None
526
+ if mlm_labels is not None:
527
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
528
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1))
529
+
530
+ if noise_detect_label is not None:
531
+ loss_fct = CrossEntropyLoss()
532
+ noise_detect_loss = loss_fct(noise_detect_scores.view(-1, 2), noise_detect_label.view(-1))
533
+ total_loss = masked_lm_loss + noise_detect_loss
534
+
535
+ if entity_label is not None and entity_negative is not None:
536
+ batch_size = input_ids.shape[0]
537
+ negative_num = entity_negative.shape[1]
538
+ # print("negative_num=", negative_num)
539
+ # Get the embedding of the masked entity (mean over the [MASK] positions)
540
+ batch_query_embedding = list()
541
+ for ei, input_id in enumerate(input_ids):
542
+ batch_query_embedding.append(torch.mean(sequence_output[ei][input_id == mask_id[ei]], 0)) # [hidden_dim]
543
+ batch_query_embedding = torch.stack(batch_query_embedding) # [bz, dim]
544
+ # print("batch_query_embedding.shape=", batch_query_embedding.shape)
545
+ batch_query_embedding = self.entity_mlp(batch_query_embedding) # [bz, dim]
546
+ batch_query_embedding = batch_query_embedding.unsqueeze(1).repeat((1, negative_num + 1, 1)) # [bz, 11, dim]
547
+ batch_query_embedding = batch_query_embedding.view(-1, batch_query_embedding.shape[-1]) # [bz * 11, dim]
548
+ # print("batch_query_embedding.shape=", batch_query_embedding.shape)
549
+
550
+ # Get the encoder representations of the positive and negative candidates
551
+ # entity_label: [bz, len], entity_negative: [bz, 10, len]
552
+ entity_negative = entity_negative.view(-1, entity_negative.shape[-1]) # [bz * 10, len]
553
+ entity_label_embedding = self.roberta.embeddings(input_ids=entity_label) # [bz, len, dim]
554
+ entity_label_embedding = self.entity_mlp(torch.mean(entity_label_embedding, 1)) # [bz, dim]
555
+ entity_label_embedding = entity_label_embedding.unsqueeze(1) # [bz, 1, dim]
556
+
557
+ entity_negative_embedding = self.roberta.embeddings(input_ids=entity_negative) # [bz * 10, len, dim]
558
+ entity_negative_embedding = self.entity_mlp(torch.mean(entity_negative_embedding, 1)) # [bz * 10, dim]
559
+ entity_negative_embedding = entity_negative_embedding \
560
+ .view(input_ids.shape[0], -1, entity_negative_embedding.shape[-1]) # [bz, 10, dim]
561
+
562
+ contrastive_label = torch.Tensor([0] * negative_num + [1]).float().cuda()
563
+ contrastive_label = contrastive_label.unsqueeze(0).repeat([batch_size, 1]).view(-1) # [bz * 11]
564
+ # print("entity_negative_embedding.shape=", entity_negative_embedding.shape)
565
+ # print("entity_label_embedding.shape=", entity_label_embedding.shape)
566
+ candidate_embedding = torch.cat([entity_negative_embedding, entity_label_embedding], 1) # [bz, 11, dim]
567
+ candidate_embedding = candidate_embedding.view(-1, candidate_embedding.shape[-1]) # [bz * 11, dim]
568
+ # print("candidate_embedding.shape=", candidate_embedding.shape)
569
+
570
+ entity_loss = self.contrastive_loss_fn(batch_query_embedding, candidate_embedding, contrastive_label)
571
+ total_loss = masked_lm_loss + entity_loss
572
+
573
+
574
+ # if ner_labels is not None:
575
+ # loss_fct = CrossEntropyLoss()
576
+ # # Only keep active parts of the loss
577
+ #
578
+ # active_loss = attention_mask.repeat(self.config.entity_type_num, 1, 1).view(-1) == 1
579
+ # active_logits = ner_logits.reshape(-1, self.config.num_ner_labels)
580
+ # active_labels = torch.where(
581
+ # active_loss, ner_labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(ner_labels)
582
+ # )
583
+ # ner_loss = loss_fct(active_logits, active_labels)
584
+ #
585
+ # if masked_lm_loss:
586
+ # total_loss = masked_lm_loss + ner_loss * 4
587
+ # print("total_loss=", total_loss)
588
+ # print("mlm_loss=", masked_lm_loss)
589
+
590
+
591
+ # end_time = time()
592
+ # print("neural_mode_time: {}".format(end_time - start_time))
593
+
594
+ return OrderedDict([
595
+ ("loss", total_loss),
596
+ ("mlm_loss", masked_lm_loss.unsqueeze(0)),
597
+ ("noise_detect_loss", noise_detect_loss.unsqueeze(0) if noise_detect_loss is not None else None),
598
+ ("entity_loss", entity_loss.unsqueeze(0) if entity_label is not None else None),
599
+ ("logits", prediction_scores.argmax(2)),
600
+ ("noise_detect_logits", noise_detect_scores.argmax(-1) if noise_detect_scores is not None else None),
601
+ ])
602
+ # MaskedLMOutput(
603
+ # loss=total_loss,
604
+ # logits=prediction_scores.argmax(2),
605
+ # ner_l
606
+ # hidden_states=outputs.hidden_states,
607
+ # attentions=outputs.attentions,
608
+ # )
609
+
610
+
611
+
612
+
613
+ class BertForWikiKGPLM(BertPreTrainedModel):
614
+
615
+ def __init__(self, config):
616
+ super().__init__(config)
617
+ self.num_labels = config.num_labels
618
+ self.config = config
619
+ self.bert = BertModel(config)
620
+ classifier_dropout = (
621
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
622
+ )
623
+ self.dropout = nn.Dropout(classifier_dropout)
624
+ # self.cls = BertOnlyMLMHead(config)
625
+ self.cls = BertPreTrainingHeads(config)
626
+ self.entity_mlp = nn.Linear(config.hidden_size, config.hidden_size)
627
+ self.relation_mlp = nn.Linear(config.hidden_size, config.hidden_size)
628
+ # self.classifiers = nn.ModuleList([nn.Linear(config.hidden_size, config.num_ner_labels) for _ in range(config.entity_type_num)])
629
+
630
+ self.contrastive_loss_fn = ContrastiveLoss()
631
+ self.post_init()
632
+
633
+ def forward(
634
+ self,
635
+ input_ids=None,
636
+ attention_mask=None,
637
+ token_type_ids=None,
638
+ position_ids=None,
639
+ head_mask=None,
640
+ inputs_embeds=None,
641
+ encoder_hidden_states=None,
642
+ encoder_attention_mask=None,
643
+ mlm_labels=None,
644
+ entity_label=None,
645
+ entity_negative=None,
646
+ relation_label=None,
647
+ relation_negative=None,
648
+ noise_detect_label=None,
649
+ task_id=None,
650
+ mask_id=None,
651
+ output_attentions=None,
652
+ output_hidden_states=None,
653
+ return_dict=None,
654
+ ):
655
+
656
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
657
+ print("attention_mask.shape=", attention_mask.shape)
658
+ print("input_ids[0]=", input_ids[0])
659
+ print("token_type_ids[0]=", token_type_ids[0])
660
+ attention_mask = None
661
+ outputs = self.bert(
662
+ input_ids,
663
+ attention_mask=attention_mask,
664
+ token_type_ids=token_type_ids,
665
+ position_ids=position_ids,
666
+ head_mask=head_mask,
667
+ inputs_embeds=inputs_embeds,
668
+ encoder_hidden_states=encoder_hidden_states,
669
+ encoder_attention_mask=encoder_attention_mask,
670
+ output_attentions=output_attentions,
671
+ output_hidden_states=output_hidden_states,
672
+ return_dict=return_dict,
673
+ )
674
+
675
+ sequence_output, pooled_output = outputs[:2]
676
+ prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
677
+
678
+ # ner
679
+ # sequence_output = self.dropout(sequence_output)
680
+ # ner_logits = torch.stack([classifier(sequence_output) for classifier in self.classifiers]).movedim(1, 0)
681
+
682
+ # mlm
683
+ masked_lm_loss, noise_detect_loss, entity_loss, total_loss = None, None, None, None
684
+
685
+ if mlm_labels is not None:
686
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
687
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), mlm_labels.view(-1))
688
+
689
+ if noise_detect_label is not None:
690
+ loss_fct = CrossEntropyLoss()
691
+ noise_detect_loss = loss_fct(seq_relationship_score.view(-1, 2), noise_detect_label.view(-1))
692
+ total_loss = masked_lm_loss + noise_detect_loss
693
+
694
+ if entity_label is not None and entity_negative is not None:
695
+ negative_num = entity_negative.shape[1]
696
+ # Get the embedding of the masked entity (mean over the [MASK] positions)
697
+ batch_query_embedding = list()
698
+ for ei, input_id in enumerate(input_ids):
699
+ batch_query_embedding.append(torch.mean(sequence_output[ei][input_id == mask_id[ei]], 0)) # [hidden_dim]
700
+ batch_query_embedding = torch.stack(batch_query_embedding) # [bz, dim]
701
+ batch_query_embedding = self.entity_mlp(batch_query_embedding) # [bz, dim]
702
+ batch_query_embedding = batch_query_embedding.unsqueeze(1).repeat((1, negative_num + 1, 1)) # [bz, 11, dim]
703
+
704
+ # Get the encoder representations of the positive and negative candidates
705
+ # entity_label: [bz, len], entity_negative: [bz, 10, len]
706
+ entity_negative = entity_negative.view(-1, entity_negative.shape[-1]) # [bz * 10, len]
707
+ entity_label_embedding = self.bert.embeddings(input_ids=entity_label) # [bz, len, dim]
708
+ entity_label_embedding = self.entity_mlp(torch.mean(entity_label_embedding, 1)) # [bz, dim]
709
+ entity_label_embedding = entity_label_embedding.unsqueeze(1) # [bz, 1, dim]
710
+
711
+ entity_negative_embedding = self.bert.embeddings(input_ids=entity_negative) # [bz * 10, len, dim]
712
+ entity_negative_embedding = self.entity_mlp(torch.mean(entity_negative_embedding, 1)) # [bz * 10, dim]
713
+ entity_negative_embedding = entity_negative_embedding \
714
+ .view(input_ids.shape[0], -1, entity_negative_embedding.shape[-1]) # [bz, 10, dim]
715
+
716
+ contrastive_label = torch.Tensor([0] * negative_num + [1]).float().cuda()
717
+ candidate_embedding = torch.cat([entity_negative_embedding, entity_label_embedding], 1) # [bz, 11, dim]
718
+
719
+ entity_loss = self.contrastive_loss_fn(batch_query_embedding, candidate_embedding, contrastive_label)
720
+ total_loss = masked_lm_loss + entity_loss
721
+
722
+
723
+ # if ner_labels is not None:
724
+ # loss_fct = CrossEntropyLoss()
725
+ # # Only keep active parts of the loss
726
+ #
727
+ # active_loss = attention_mask.repeat(self.config.entity_type_num, 1, 1).view(-1) == 1
728
+ # active_logits = ner_logits.reshape(-1, self.config.num_ner_labels)
729
+ # active_labels = torch.where(
730
+ # active_loss, ner_labels.view(-1), torch.tensor(loss_fct.ignore_index).type_as(ner_labels)
731
+ # )
732
+ # ner_loss = loss_fct(active_logits, active_labels)
733
+ #
734
+ # if masked_lm_loss:
735
+ # total_loss = masked_lm_loss + ner_loss * 4
736
+
737
+ return OrderedDict([
738
+ ("loss", total_loss),
739
+ ("mlm_loss", masked_lm_loss.unsqueeze(0)),
740
+ ("noise_detect_loss", noise_detect_loss.unsqueeze(0)),
741
+ ("entity_loss", entity_loss.unsqueeze(0)),
742
+ ("logits", prediction_scores.argmax(2)),
743
+ ("noise_detect_logits", seq_relationship_score.argmax(3)),
744
745
+ ])
746
+ # MaskedLMOutput(
747
+ # loss=total_loss,
748
+ # logits=prediction_scores.argmax(2),
749
+ # ner_l
750
+ # hidden_states=outputs.hidden_states,
751
+ # attentions=outputs.attentions,
752
+ # )
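# Sketch (not part of the original file): the batch fields the KP-PLM forward
# passes above expect, inferred from their signatures (exact shapes are
# assumptions). task_id routes each example to the auxiliary objectives:
#   input_ids           [bz, seq_len]        token ids of the (masked) text
#   labels              [bz, seq_len]        MLM labels, -100 on unmasked tokens
#   entity_candidate    [bz, n_cand, c_len]  candidate entity token ids
#                                            (the last candidate is the positive)
#   relation_candidate  [bz, n_cand, c_len]  candidate relation token ids
#   noise_detect_label  [bz]                 0/1 knowledge-noise label
#   task_id             [bz]                 2 -> entity task, 3 -> relation task
#   mask_id             [bz]                 mask token id used to locate the span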
models/language_modeling/mlm.py ADDED
@@ -0,0 +1,359 @@
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2021/12/30 8:35 PM
3
+ # @Author : JianingWang
4
+ # @File : mlm.py
5
+ import logging
6
+ from typing import Union, Tuple, Optional
7
+ import torch
8
+ from torch.nn import CrossEntropyLoss
9
+ from transformers.modeling_outputs import MaskedLMOutput
10
+ from transformers.models.bert import BertPreTrainedModel
11
+ from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertModel, BertOnlyMLMHead
12
+ from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel, RobertaModel, RobertaLMHead
13
+ from transformers.models.albert.modeling_albert import AlbertPreTrainedModel, AlbertModel, AlbertMLMHead
14
+ from transformers.models.roformer.modeling_roformer import RoFormerPreTrainedModel, RoFormerModel, RoFormerOnlyMLMHead
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+ """
19
+ Function: Use MLM to pre-train BERT
20
+ Notes:
21
+ - The label of non-masked token is -100, which can be used for cross-entropy function (only calculate loss at not -100)
22
+ """
23
+ class BertForMaskedLM(BertPreTrainedModel):
24
+
25
+ def __init__(self, config, *inputs, **kwargs):
26
+ super().__init__(config, *inputs, **kwargs)
27
+
28
+ self.bert = BertModel(config, add_pooling_layer=False)
29
+ self.cls = BertOnlyMLMHead(config)
30
+
31
+ # Initialize weights and apply final processing
32
+ self.post_init()
33
+
34
+ def forward(
35
+ self,
36
+ input_ids: Optional[torch.Tensor] = None,
37
+ attention_mask: Optional[torch.Tensor] = None,
38
+ token_type_ids: Optional[torch.Tensor] = None,
39
+ position_ids: Optional[torch.Tensor] = None,
40
+ head_mask: Optional[torch.Tensor] = None,
41
+ inputs_embeds: Optional[torch.Tensor] = None,
42
+ encoder_hidden_states: Optional[torch.Tensor] = None,
43
+ encoder_attention_mask: Optional[torch.Tensor] = None,
44
+ labels: Optional[torch.Tensor] = None,
45
+ output_attentions: Optional[bool] = None,
46
+ output_hidden_states: Optional[bool] = None,
47
+ return_dict: Optional[bool] = None,
48
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
49
+ r"""
50
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
51
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
52
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
53
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
54
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
55
+ Used to hide legacy arguments that have been deprecated.
56
+ """
57
+
58
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
59
+ outputs = self.bert(
60
+ input_ids,
61
+ attention_mask=attention_mask,
62
+ token_type_ids=token_type_ids,
63
+ position_ids=position_ids,
64
+ head_mask=head_mask,
65
+ inputs_embeds=inputs_embeds,
66
+ encoder_hidden_states=encoder_hidden_states,
67
+ encoder_attention_mask=encoder_attention_mask,
68
+ output_attentions=output_attentions,
69
+ output_hidden_states=output_hidden_states,
70
+ return_dict=return_dict,
71
+ )
72
+
73
+ sequence_output = outputs[0]
74
+ prediction_scores = self.cls(sequence_output)
75
+
76
+ masked_lm_loss = None
77
+ if labels is not None:
78
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
79
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
80
+
81
+ if not return_dict:
82
+ output = (prediction_scores,) + outputs[2:]
83
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
84
+
85
+ return MaskedLMOutput(
86
+ loss=masked_lm_loss, # ()
87
+ logits=prediction_scores, # (batch_size, seq_len, vocab_size)
88
+ hidden_states=outputs.hidden_states, # (batch_size, seq_len, hidden_size)
89
+ attentions=outputs.attentions,
90
+ )
91
+
92
+ """
93
+ Function: Use MLM to pre-train RoBERTa
94
+ Notes:
95
+ - The label of non-masked token is -100, which can be used for cross-entropy function (only calculate loss at not -100)
96
+ """
97
+ class RobertaForMaskedLM(RobertaPreTrainedModel):
98
+
99
+ def __init__(self, config, *inputs, **kwargs):
100
+ super().__init__(config, *inputs, **kwargs)
101
+
102
+ self.roberta = RobertaModel(config, add_pooling_layer=False)
103
+ self.lm_head = RobertaLMHead(config)
104
+
105
+ # Initialize weights and apply final processing
106
+ self.post_init()
107
+
108
+ def forward(
109
+ self,
110
+ input_ids: Optional[torch.LongTensor] = None,
111
+ attention_mask: Optional[torch.FloatTensor] = None,
112
+ token_type_ids: Optional[torch.LongTensor] = None,
113
+ position_ids: Optional[torch.LongTensor] = None,
114
+ head_mask: Optional[torch.FloatTensor] = None,
115
+ inputs_embeds: Optional[torch.FloatTensor] = None,
116
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
117
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
118
+ labels: Optional[torch.LongTensor] = None,
119
+ output_attentions: Optional[bool] = None,
120
+ output_hidden_states: Optional[bool] = None,
121
+ return_dict: Optional[bool] = None,
122
+ ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
123
+ r"""
124
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
125
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
126
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
127
+ loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
128
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
129
+ Used to hide legacy arguments that have been deprecated.
130
+ """
131
+
132
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
133
+ outputs = self.roberta(
134
+ input_ids,
135
+ attention_mask=attention_mask,
136
+ token_type_ids=token_type_ids,
137
+ position_ids=position_ids,
138
+ head_mask=head_mask,
139
+ inputs_embeds=inputs_embeds,
140
+ encoder_hidden_states=encoder_hidden_states,
141
+ encoder_attention_mask=encoder_attention_mask,
142
+ output_attentions=output_attentions,
143
+ output_hidden_states=output_hidden_states,
144
+ return_dict=return_dict,
145
+ )
146
+
147
+ sequence_output = outputs[0]
148
+ prediction_scores = self.lm_head(sequence_output)
149
+
150
+ masked_lm_loss = None
151
+ if labels is not None:
152
+ loss_fct = CrossEntropyLoss() # -100 index = padding token
153
+ masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
154
+
155
+ if not return_dict:
156
+ output = (prediction_scores,) + outputs[2:]
157
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
158
+
159
+ return MaskedLMOutput(
160
+ loss=masked_lm_loss, # ()
161
+ logits=prediction_scores, # (batch_size, seq_len, vocab_size)
162
+ hidden_states=outputs.hidden_states, # (batch_size, seq_len, hidden_size)
163
+ attentions=outputs.attentions,
164
+ )
165
+
166
+ """
167
+ Function: Use MLM to pre-train ALBERT
168
+ Notes:
169
+ - The label of non-masked token is -100, which can be used for cross-entropy function (only calculate loss at not -100)
170
+ """
+ class AlbertForMaskedLM(AlbertPreTrainedModel):
+
+     def __init__(self, config, *inputs, **kwargs):
+         super().__init__(config, *inputs, **kwargs)
+
+         self.albert = AlbertModel(config, add_pooling_layer=False)
+         self.predictions = AlbertMLMHead(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[MaskedLMOutput, Tuple]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+             config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
+             the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+         Returns:
+
+         Example:
+
+         ```python
+         >>> import torch
+         >>> from transformers import AlbertTokenizer, AlbertForMaskedLM
+
+         >>> tokenizer = AlbertTokenizer.from_pretrained("albert-base-v2")
+         >>> model = AlbertForMaskedLM.from_pretrained("albert-base-v2")
+
+         >>> # add mask_token
+         >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt")
+         >>> with torch.no_grad():
+         ...     logits = model(**inputs).logits
+
+         >>> # retrieve index of [MASK]
+         >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
+         >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
+         >>> tokenizer.decode(predicted_token_id)
+         'france'
+         ```
+
+         ```python
+         >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
+         >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
+         >>> outputs = model(**inputs, labels=labels)
+         >>> round(outputs.loss.item(), 2)
+         0.81
+         ```
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.albert(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+         sequence_outputs = outputs[0]
+
+         prediction_scores = self.predictions(sequence_outputs)
+
+         masked_lm_loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()
+             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+         if not return_dict:
+             output = (prediction_scores,) + outputs[2:]
+             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+         return MaskedLMOutput(
+             loss=masked_lm_loss,
+             logits=prediction_scores,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+ """
+ Function: Use MLM to pre-train RoFormer.
+ Notes:
+     - Non-masked tokens are labeled -100, so the cross-entropy loss ignores them and is computed only on the masked positions.
+ """
+ class RoFormerForMaskedLM(RoFormerPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+
+         if config.is_decoder:
+             logger.warning(
+                 "If you want to use `RoFormerForMaskedLM` make sure `config.is_decoder=False` for "
+                 "bi-directional self-attention."
+             )
+
+         self.roformer = RoFormerModel(config)
+         self.cls = RoFormerOnlyMLMHead(config)
+
+         # Initialize weights and apply final processing
+         self.post_init()
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[MaskedLMOutput, Tuple[torch.Tensor]]:
+         r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
+             config.vocab_size]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored (masked);
+             the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+         """
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.roformer(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+         prediction_scores = self.cls(sequence_output)
+
+         masked_lm_loss = None
+         if labels is not None:
+             loss_fct = CrossEntropyLoss()  # the -100 index is ignored
+             masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+
+         if not return_dict:
+             output = (prediction_scores,) + outputs[1:]
+             return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
+
+         return MaskedLMOutput(
+             loss=masked_lm_loss,
+             logits=prediction_scores,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
+
+ if __name__ == "__main__":
+     from transformers.models.bert.tokenization_bert import BertTokenizer
+     tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+     model = BertForMaskedLM.from_pretrained("bert-base-uncased")
+     input_text = "Today is a nice day, I will [MASK] to play [MASK] with my friends."
+     inputs = tokenizer(input_text, return_tensors="pt")
+     masked_positions = inputs["input_ids"] == tokenizer.mask_token_id
+     print("inputs=", inputs)
+     """
+     inputs= {'input_ids': tensor([[ 101, 2651, 2003, 1037, 3835, 2154, 1010, 1045, 2097,  103, 2000, 2377,
+               103, 2007, 2026, 2814, 1012,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])}
+     """
+     outputs = model(**inputs)
+     masked_results = outputs.logits.argmax(-1)[masked_positions]
+     # convert_ids_to_tokens expects Python ints, so convert the tensor to a list first
+     masked_results = tokenizer.convert_ids_to_tokens(masked_results.tolist())
+     print("masked_results=", masked_results)
+     """
+     masked_results= ['have', 'football']
+     """